feat(checker): handle type vars in python functions
This commit is contained in:
@@ -22,8 +22,10 @@ from midas.checker.types import (
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
Variance,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.checker.unifier import Unifier
|
||||
@@ -229,7 +231,8 @@ class PythonTyper(
|
||||
)
|
||||
pos += 1
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
all_args: list[Function.Argument] = pos_args + args + kw_args
|
||||
for arg in all_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
@@ -270,12 +273,25 @@ class PythonTyper(
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
function: Type = Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
generic_params: list[TypeVar] = []
|
||||
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
|
||||
for type in all_types:
|
||||
if isinstance(type, TypeVar):
|
||||
if type not in generic_params:
|
||||
generic_params.append(type)
|
||||
|
||||
if len(generic_params) != 0:
|
||||
function = GenericType(
|
||||
name=stmt.name,
|
||||
params=generic_params,
|
||||
body=function,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
@@ -453,6 +469,10 @@ class PythonTyper(
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
match expr.callee:
|
||||
case p.VariableExpr(name="TypeVar"):
|
||||
return self.define_typevar(expr) or UnknownType()
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
positional: list[TypedExpr] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
@@ -1033,3 +1053,57 @@ class PythonTyper(
|
||||
report_errors=False,
|
||||
)
|
||||
return result
|
||||
|
||||
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
||||
def is_kw_true(name: str) -> bool:
|
||||
match call.keywords.get(name):
|
||||
case p.LiteralExpr(value=True):
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
|
||||
match call:
|
||||
case p.CallExpr(
|
||||
arguments=[p.LiteralExpr(value=str() as name)],
|
||||
):
|
||||
bound: Optional[Type] = None
|
||||
variance: Variance = Variance.INVARIANT
|
||||
if "bound" in call.keywords:
|
||||
bound_type: p.MidasType = self._parse_type_from_expr(
|
||||
call.keywords["bound"]
|
||||
)
|
||||
bound = self.resolve_type_expr(bound_type)
|
||||
|
||||
if is_kw_true("covariant"):
|
||||
variance = Variance.COVARIANT
|
||||
|
||||
if is_kw_true("contravariant"):
|
||||
if variance == Variance.COVARIANT:
|
||||
self.reporter.warning(
|
||||
call.keywords["contravariant"].location,
|
||||
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
|
||||
)
|
||||
variance = Variance.INVARIANT
|
||||
else:
|
||||
variance = Variance.CONTRAVARIANT
|
||||
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
|
||||
self.types.define_type(name, var)
|
||||
return var
|
||||
|
||||
case _:
|
||||
self.reporter.warning(
|
||||
call.location, "Invalid usage of 'TypeVar', skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
|
||||
location: Location = expr.location
|
||||
parser = PythonParser()
|
||||
match expr:
|
||||
case p.LiteralExpr(value=str() as value):
|
||||
node: ast.Expression = ast.parse(value, mode="eval")
|
||||
return parser._parse_type(node.body)
|
||||
case p.VariableExpr(name=name):
|
||||
return p.BaseType(location=location, base=name, param=None)
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -323,7 +323,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
self._make_constraint_assert(src_location, expr, constraint)
|
||||
|
||||
case TypeVar():
|
||||
raise RuntimeError("Unexpected TypeVar")
|
||||
# TODO: check with type from arguments / use call-site context
|
||||
pass
|
||||
|
||||
case (
|
||||
TopType()
|
||||
|
||||
Reference in New Issue
Block a user