diff --git a/midas/checker/types.py b/midas/checker/types.py index 309ad0f..82a08ba 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -238,6 +238,58 @@ def unfold_type(type: Type) -> Type: return type +def to_annotation(type: Type) -> str: + def _args_annotation(func: Function) -> str: + if len(func.kw_args) != 0: + return "..." + + args: str = ", ".join( + to_annotation(arg.type) for arg in func.pos_args + func.args + ) + return f"[{args}]" + + match type: + case TopType(): + return "Any" + + case BaseType(name=name): + return name + + case AliasType(name=name): + return name + + case UnknownType(): + return "Any" + + case UnitType(): + return "None" + + case Function(returns=returns): + params_annot: str = _args_annotation(type) + return f"Callable[{params_annot}, {to_annotation(returns)}]" + + case OverloadedFunction(): + return "Callable" + + case ComplexType() | ExtensionType(): + raise NotImplementedError + + case TypeVar(name=name): + return name + + case GenericType(name=name, params=params): + return f"{name}[{', '.join(map(to_annotation, params))}]" + + case AppliedType(name=name, args=args): + return f"{name}[{', '.join(map(to_annotation, args))}]" + + case ConstraintType(): + return str(type) + + case _: + assert_never(type) + + @dataclass(frozen=True, kw_only=True) class Predicate: type: Type diff --git a/midas/generator/constraints.py b/midas/generator/constraints.py index e9f21c1..e840b42 100644 --- a/midas/generator/constraints.py +++ b/midas/generator/constraints.py @@ -3,7 +3,12 @@ from typing import Optional import midas.ast.midas as m from midas.checker.registry import TypesRegistry -from midas.checker.types import Function, Predicate, Type +from midas.checker.types import ( + Function, + Predicate, + Type, + to_annotation, +) from midas.lexer.token import TokenType LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = { @@ -91,9 +96,27 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): def make_args(self, func: Function) -> ast.arguments: return ast.arguments( - posonlyargs=[ast.arg(arg=arg.name) for arg in func.pos_args], - args=[ast.arg(arg=arg.name) for arg in func.args], - kwonlyargs=[ast.arg(arg=arg.name) for arg in func.kw_args], + posonlyargs=[ + ast.arg( + arg=arg.name, + annotation=ast.Constant(value=to_annotation(arg.type)), + ) + for arg in func.pos_args + ], + args=[ + ast.arg( + arg=arg.name, + annotation=ast.Constant(value=to_annotation(arg.type)), + ) + for arg in func.args + ], + kwonlyargs=[ + ast.arg( + arg=arg.name, + annotation=ast.Constant(value=to_annotation(arg.type)), + ) + for arg in func.kw_args + ], defaults=[], kw_defaults=[], ) @@ -111,6 +134,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): self.make_func(inner_name, inner_body, type.returns, level + 1), ast.Return(value=ast.Name(id=inner_name)), ], + returns=ast.Constant(value=to_annotation(type.returns)), decorator_list=[], ) @@ -119,6 +143,7 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]): name=name, args=self.make_args(type), body=inner_body, + returns=ast.Constant(value=to_annotation(type.returns)), decorator_list=[], )