diff --git a/midas/checker/python.py b/midas/checker/python.py index e1fb788..6a00249 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -22,8 +22,10 @@ from midas.checker.types import ( GenericType, OverloadedFunction, Type, + TypeVar, UnitType, UnknownType, + Variance, unfold_type, ) from midas.checker.unifier import Unifier @@ -229,7 +231,8 @@ class PythonTyper( ) pos += 1 - for arg in pos_args + args + kw_args: + all_args: list[Function.Argument] = pos_args + args + kw_args + for arg in all_args: env.define(arg.name, arg.type) returns_hint: Optional[Type] = None @@ -270,12 +273,25 @@ class PythonTyper( returns = inferred_return # TODO: handle *args and **kwargs sinks - function: Function = Function( + function: Type = Function( pos_args=pos_args, args=args, kw_args=kw_args, returns=returns, ) + generic_params: list[TypeVar] = [] + all_types: list[Type] = [arg.type for arg in all_args] + [returns] + for type in all_types: + if isinstance(type, TypeVar): + if type not in generic_params: + generic_params.append(type) + + if len(generic_params) != 0: + function = GenericType( + name=stmt.name, + params=generic_params, + body=function, + ) self.env.define(stmt.name, function) def visit_type_assign(self, stmt: p.TypeAssign) -> None: @@ -453,6 +469,10 @@ class PythonTyper( return result or UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: + match expr.callee: + case p.VariableExpr(name="TypeVar"): + return self.define_typevar(expr) or UnknownType() + callee: Type = self.type_of(expr.callee) positional: list[TypedExpr] = [ (arg, self.type_of(arg)) for arg in expr.arguments @@ -1033,3 +1053,57 @@ class PythonTyper( report_errors=False, ) return result + + def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]: + def is_kw_true(name: str) -> bool: + match call.keywords.get(name): + case p.LiteralExpr(value=True): + return True + case _: + return False + + match call: + case p.CallExpr( + arguments=[p.LiteralExpr(value=str() as name)], + ): + bound: Optional[Type] = None + variance: Variance = Variance.INVARIANT + if "bound" in call.keywords: + bound_type: p.MidasType = self._parse_type_from_expr( + call.keywords["bound"] + ) + bound = self.resolve_type_expr(bound_type) + + if is_kw_true("covariant"): + variance = Variance.COVARIANT + + if is_kw_true("contravariant"): + if variance == Variance.COVARIANT: + self.reporter.warning( + call.keywords["contravariant"].location, + "TypeVar cannot be covariant and contravariant at the same time. Marked as invariant", + ) + variance = Variance.INVARIANT + else: + variance = Variance.CONTRAVARIANT + var: TypeVar = TypeVar(name=name, bound=bound, variance=variance) + self.types.define_type(name, var) + return var + + case _: + self.reporter.warning( + call.location, "Invalid usage of 'TypeVar', skipping" + ) + return None + + def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType: + location: Location = expr.location + parser = PythonParser() + match expr: + case p.LiteralExpr(value=str() as value): + node: ast.Expression = ast.parse(value, mode="eval") + return parser._parse_type(node.body) + case p.VariableExpr(name=name): + return p.BaseType(location=location, base=name, param=None) + case _: + raise NotImplementedError diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 22eab41..88065d2 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -323,7 +323,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self._make_constraint_assert(src_location, expr, constraint) case TypeVar(): - raise RuntimeError("Unexpected TypeVar") + # TODO: check with type from arguments / use call-site context + pass case ( TopType()