feat(checker): handle type vars in python functions
This commit is contained in:
@@ -22,8 +22,10 @@ from midas.checker.types import (
|
|||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
Variance,
|
||||||
unfold_type,
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.checker.unifier import Unifier
|
from midas.checker.unifier import Unifier
|
||||||
@@ -229,7 +231,8 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
for arg in pos_args + args + kw_args:
|
all_args: list[Function.Argument] = pos_args + args + kw_args
|
||||||
|
for arg in all_args:
|
||||||
env.define(arg.name, arg.type)
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
returns_hint: Optional[Type] = None
|
||||||
@@ -270,12 +273,25 @@ class PythonTyper(
|
|||||||
returns = inferred_return
|
returns = inferred_return
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
# TODO: handle *args and **kwargs sinks
|
||||||
function: Function = Function(
|
function: Type = Function(
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
args=args,
|
args=args,
|
||||||
kw_args=kw_args,
|
kw_args=kw_args,
|
||||||
returns=returns,
|
returns=returns,
|
||||||
)
|
)
|
||||||
|
generic_params: list[TypeVar] = []
|
||||||
|
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
|
||||||
|
for type in all_types:
|
||||||
|
if isinstance(type, TypeVar):
|
||||||
|
if type not in generic_params:
|
||||||
|
generic_params.append(type)
|
||||||
|
|
||||||
|
if len(generic_params) != 0:
|
||||||
|
function = GenericType(
|
||||||
|
name=stmt.name,
|
||||||
|
params=generic_params,
|
||||||
|
body=function,
|
||||||
|
)
|
||||||
self.env.define(stmt.name, function)
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
@@ -453,6 +469,10 @@ class PythonTyper(
|
|||||||
return result or UnknownType()
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
|
match expr.callee:
|
||||||
|
case p.VariableExpr(name="TypeVar"):
|
||||||
|
return self.define_typevar(expr) or UnknownType()
|
||||||
|
|
||||||
callee: Type = self.type_of(expr.callee)
|
callee: Type = self.type_of(expr.callee)
|
||||||
positional: list[TypedExpr] = [
|
positional: list[TypedExpr] = [
|
||||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
@@ -1033,3 +1053,57 @@ class PythonTyper(
|
|||||||
report_errors=False,
|
report_errors=False,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
||||||
|
def is_kw_true(name: str) -> bool:
|
||||||
|
match call.keywords.get(name):
|
||||||
|
case p.LiteralExpr(value=True):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
|
match call:
|
||||||
|
case p.CallExpr(
|
||||||
|
arguments=[p.LiteralExpr(value=str() as name)],
|
||||||
|
):
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
variance: Variance = Variance.INVARIANT
|
||||||
|
if "bound" in call.keywords:
|
||||||
|
bound_type: p.MidasType = self._parse_type_from_expr(
|
||||||
|
call.keywords["bound"]
|
||||||
|
)
|
||||||
|
bound = self.resolve_type_expr(bound_type)
|
||||||
|
|
||||||
|
if is_kw_true("covariant"):
|
||||||
|
variance = Variance.COVARIANT
|
||||||
|
|
||||||
|
if is_kw_true("contravariant"):
|
||||||
|
if variance == Variance.COVARIANT:
|
||||||
|
self.reporter.warning(
|
||||||
|
call.keywords["contravariant"].location,
|
||||||
|
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
|
||||||
|
)
|
||||||
|
variance = Variance.INVARIANT
|
||||||
|
else:
|
||||||
|
variance = Variance.CONTRAVARIANT
|
||||||
|
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
|
||||||
|
self.types.define_type(name, var)
|
||||||
|
return var
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
call.location, "Invalid usage of 'TypeVar', skipping"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
|
||||||
|
location: Location = expr.location
|
||||||
|
parser = PythonParser()
|
||||||
|
match expr:
|
||||||
|
case p.LiteralExpr(value=str() as value):
|
||||||
|
node: ast.Expression = ast.parse(value, mode="eval")
|
||||||
|
return parser._parse_type(node.body)
|
||||||
|
case p.VariableExpr(name=name):
|
||||||
|
return p.BaseType(location=location, base=name, param=None)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -323,7 +323,8 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._make_constraint_assert(src_location, expr, constraint)
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
|
||||||
case TypeVar():
|
case TypeVar():
|
||||||
raise RuntimeError("Unexpected TypeVar")
|
# TODO: check with type from arguments / use call-site context
|
||||||
|
pass
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
|
|||||||
Reference in New Issue
Block a user