Compare commits

..

3 Commits

2 changed files with 69 additions and 15 deletions

View File

@@ -228,6 +228,13 @@ class PythonHighlighter(
for item in expr.items: for item in expr.items:
item.accept(self) item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
key.accept(self)
for value in expr.values:
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self) expr.object.accept(self)
expr.index.accept(self) expr.index.accept(self)
@@ -240,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None: if expr.step is not None:
expr.step.accept(self) expr.step.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
class MidasHighlighter( class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None] Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -266,8 +277,9 @@ class MidasHighlighter(
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate") self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name") self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self) for spec in stmt.params:
stmt.condition.accept(self) self._visit_param_spec(spec)
stmt.body.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr") self.wrap(expr, "logical-expr")
@@ -283,6 +295,14 @@ class MidasHighlighter(
self.wrap(expr, "unary-expr") self.wrap(expr, "unary-expr")
expr.right.accept(self) expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.wrap(expr, "call-expr")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr) -> None: def visit_get_expr(self, expr: m.GetExpr) -> None:
self.wrap(expr, "get-expr") self.wrap(expr, "get-expr")
expr.expr.accept(self) expr.expr.accept(self)
@@ -318,8 +338,7 @@ class MidasHighlighter(
def visit_function_type(self, type: m.FunctionType) -> None: def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function") self.wrap(type, "function")
for arg in type.pos_args + type.args + type.kw_args: self._visit_param_spec(type.params)
arg.type.accept(self)
type.returns.accept(self) type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None: def visit_extension_type(self, type: m.ExtensionType) -> None:
@@ -327,6 +346,10 @@ class MidasHighlighter(
type.base.accept(self) type.base.accept(self)
type.extension.accept(self) type.extension.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
class DiagnosticsHighlighter(Highlighter): class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css" EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -1,5 +1,5 @@
import ast import ast
from typing import Optional from typing import Optional, assert_never
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry from midas.checker.registry import Member, TypesRegistry
@@ -8,6 +8,7 @@ from midas.checker.types import (
AppliedType, AppliedType,
BaseType, BaseType,
ComplexType, ComplexType,
ConstraintType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
@@ -17,6 +18,7 @@ from midas.checker.types import (
TypeVar, TypeVar,
UnitType, UnitType,
UnknownType, UnknownType,
Variance,
substitute_typevars, substitute_typevars,
) )
@@ -84,6 +86,7 @@ class StubsGenerator:
match type: match type:
case AliasType(type=base): case AliasType(type=base):
return [self.dump_type(base)], {} return [self.dump_type(base)], {}
case GenericType(params=params, body=body): case GenericType(params=params, body=body):
self.add_typing_import("Generic") self.add_typing_import("Generic")
type_vars: ast.expr type_vars: ast.expr
@@ -111,6 +114,7 @@ class StubsGenerator:
], ],
body_subsitutions | substitutions, body_subsitutions | substitutions,
) )
case _: case _:
return [], {} return [], {}
@@ -148,15 +152,20 @@ class StubsGenerator:
case TopType() | UnknownType(): case TopType() | UnknownType():
self.add_typing_import("Any") self.add_typing_import("Any")
return ast.Name(id="Any") return ast.Name(id="Any")
case BaseType(name=name): case BaseType(name=name):
return ast.Name(id=name) return ast.Name(id=name)
case AliasType(name=name): case AliasType(name=name):
return ast.Name(id=name) return ast.Name(id=name)
case UnitType(): case UnitType():
return ast.Constant(value=None) return ast.Constant(value=None)
case Function(): case Function():
name: str = self.define_protocol(type) name: str = self.define_protocol(type)
return ast.Name(id=name) return ast.Name(id=name)
case OverloadedFunction(overloads=overloads): case OverloadedFunction(overloads=overloads):
if len(overloads) == 1: if len(overloads) == 1:
return self.dump_type(overloads[0]) return self.dump_type(overloads[0])
@@ -176,6 +185,7 @@ class StubsGenerator:
case TypeVar(): case TypeVar():
return ast.Name(id=type.name) return ast.Name(id=type.name)
case GenericType(name=name): case GenericType(name=name):
params: ast.expr params: ast.expr
if len(type.params) == 1: if len(type.params) == 1:
@@ -188,6 +198,7 @@ class StubsGenerator:
value=ast.Name(id=type.name), value=ast.Name(id=type.name),
slice=params, slice=params,
) )
case AppliedType(): case AppliedType():
args: ast.expr args: ast.expr
if len(type.args) == 1: if len(type.args) == 1:
@@ -199,6 +210,12 @@ class StubsGenerator:
slice=args, slice=args,
) )
case ConstraintType():
return self.dump_type(type.type)
case _:
assert_never(type)
def dump_method( def dump_method(
self, name: str, method: Type, overloaded: bool = False self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]: ) -> list[ast.stmt]:
@@ -313,6 +330,29 @@ class StubsGenerator:
def define_type_var(self, var: TypeVar) -> TypeVar: def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name() name: str = self.new_type_var_name()
self.add_typing_import("TypeVar") self.add_typing_import("TypeVar")
kwargs: list[ast.keyword] = []
if var.bound is not None:
kwargs.append(
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
)
if var.variance == Variance.COVARIANT:
kwargs.append(
ast.keyword(
arg="covariant",
value=ast.Constant(value=True),
)
)
elif var.variance == Variance.CONTRAVARIANT:
kwargs.append(
ast.keyword(
arg="contravariant",
value=ast.Constant(value=True),
)
)
self.add_stub( self.add_stub(
ast.Assign( ast.Assign(
targets=[ast.Name(id=name)], targets=[ast.Name(id=name)],
@@ -321,16 +361,7 @@ class StubsGenerator:
args=[ args=[
ast.Constant(value=name), ast.Constant(value=name),
], ],
keywords=( keywords=kwargs,
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
), ),
) )
) )