feat(checker): adapt typers to members and extension type

This commit is contained in:
2026-06-11 17:12:50 +02:00
parent bfa0bb3ee0
commit beaa4d95d8
3 changed files with 92 additions and 43 deletions

View File

@@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
ComplexType,
ExtensionType,
Function,
GenericType,
Type,
@@ -76,7 +77,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.types.define_type(name, type)
self._local_variables.clear()
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._resolve_type_params(stmt.params)
@@ -126,16 +127,21 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
# TODO
return UnknownType()
def visit_complex_type(self, type: m.ComplexType) -> Type:
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
return ComplexType(
properties={
prop.name.lexeme: prop.type.accept(self) for prop in type.properties
members={
member.name.lexeme: member.type.accept(self) for member in type.members
}
)
def visit_extension_type(self, type: m.ExtensionType) -> Type:
return ExtensionType(
base=type.base.accept(self),
extension=self.visit_complex_type(type.extension),
)
def visit_function_type(self, type: m.FunctionType) -> Type:
return Function(
name="<anonymous>",
pos_args=[
Function.Argument(
pos=i,

View File

@@ -12,6 +12,7 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
ComplexType,
ExtensionType,
Function,
Operation,
Type,
@@ -192,7 +193,6 @@ class PythonTyper(
returns_hint = stmt.returns.accept(self)
# Early define to handle simple fully-typed recursion
inside_function: Function = Function(
name=stmt.name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
@@ -227,7 +227,6 @@ class PythonTyper(
# TODO: handle *args and **kwargs sinks
function: Function = Function(
name=stmt.name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
@@ -250,8 +249,9 @@ class PythonTyper(
case p.VariableExpr():
self._assign_var(location, target, value_type)
case p.GetExpr():
self._assign_attr(location, target, value_type)
case p.GetExpr(object=object, name=name):
object_type: Type = self.type_of(object)
self._assign_attr(location, object_type, name, value_type)
case _:
if not isinstance(target, p.VariableExpr):
@@ -276,32 +276,43 @@ class PythonTyper(
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
)
def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type):
object: Type = self.type_of(target.object)
def _assign_attr(
self, location: Location, object: Type, name: str, value_type: Type
):
# TODO: improve recursion to have better error messages
base_object: Type = unfold_type(object)
match base_object:
case ComplexType(properties=properties):
if target.name not in properties:
case ComplexType(members=members):
if name not in members:
self.reporter.error(location, f"Unknown member '{object}.{name}'")
return
member_type: Type = members[name]
if not self.is_subtype(value_type, member_type):
self.reporter.error(
target.location, f"Unknown property '{object}.{target.name}'"
location,
f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}",
)
return
prop_type: Type = properties[target.name]
if not self.is_subtype(value_type, prop_type):
self.reporter.error(
location,
f"Cannot assign {value_type} to property '{object}.{target.name}' of type {prop_type}",
)
return
case ExtensionType(base=base, extension=ComplexType(members=members)):
if name in members:
member_type: Type = members[name]
if not self.is_subtype(value_type, member_type):
self.reporter.error(
location,
f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}",
)
return
return self._assign_attr(location, base, name, value_type)
case UnknownType():
pass
case _:
self.reporter.error(
target.location,
f"Cannot assign {value_type} to unknown property '{object}.{target.name}'",
location,
f"Cannot assign {value_type} to unknown property '{object}.{name}'",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
@@ -422,23 +433,37 @@ class PythonTyper(
def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object)
member: Optional[Type] = self._get_member(object, expr.name)
if member is None:
self.reporter.error(
expr.location, f"Unknown property '{expr.name}' on {object}"
)
return UnknownType()
self.logger.debug(f"Property '{expr.name}' on {object} has type {member}")
return member
def _get_member(self, object: Type, name: str) -> Optional[Type]:
base_object: Type = unfold_type(object)
match base_object:
case ComplexType(properties=properties):
if expr.name not in properties:
self.reporter.error(
expr.location, f"Unknown property '{expr.name} on {object}"
)
return UnknownType()
return properties[expr.name]
case ComplexType(members=members):
if name in members:
return members[name]
self.logger.debug(f"No property '{name}' in {base_object}")
return None
case ExtensionType(base=base, extension=ComplexType(members=members)):
if name in members:
return members[name]
self.logger.debug(
f"No property '{name}' on {base_object}, looking up in base"
)
return self._get_member(base, name)
case UnknownType():
return UnknownType()
case _:
self.reporter.error(
expr.location, f"Cannot get property '{expr.name}' on {object}"
)
self.logger.debug(f"Can't get property on {base_object}")
return UnknownType()
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:

View File

@@ -35,7 +35,6 @@ class UnitType:
@dataclass(frozen=True, kw_only=True)
class Function:
name: str
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
@@ -56,7 +55,7 @@ class Function:
args.append("*")
args += list(map(str, self.kw_args))
return f"{self.name}({', '.join(args)}) -> {self.returns}"
return f"({', '.join(args)}) -> {self.returns}"
@dataclass(frozen=True, kw_only=True)
class Argument:
@@ -72,13 +71,22 @@ class Function:
@dataclass(frozen=True, kw_only=True)
class ComplexType:
properties: dict[str, Type]
members: dict[str, Type]
def __str__(self) -> str:
props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()]
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
return f"{{{', '.join(props)}}}"
@dataclass(frozen=True, kw_only=True)
class ExtensionType:
base: Type
extension: ComplexType
def __str__(self) -> str:
return f"{self.base} & {self.extension}"
@dataclass(frozen=True, kw_only=True)
class Operation:
signature: CallSignature
@@ -145,26 +153,35 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
case Function(
name=name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
):
return Function(
name=name,
pos_args=list(map(sub_argument, pos_args)),
args=list(map(sub_argument, args)),
kw_args=list(map(sub_argument, kw_args)),
returns=substitute_typevars(returns, substitutions),
)
case ComplexType(properties=properties):
properties2: dict[str, Type] = {
case ComplexType(members=members):
members2: dict[str, Type] = {
name: substitute_typevars(prop, substitutions)
for name, prop in properties.items()
for name, prop in members.items()
}
return ComplexType(properties=properties2)
return ComplexType(members=members2)
case ExtensionType(base=base, extension=ComplexType(members=members)):
return ExtensionType(
base=substitute_typevars(base, substitutions),
extension=ComplexType(
members={
name: substitute_typevars(prop, substitutions)
for name, prop in members.items()
}
),
)
case TypeVar(name=name):
if name in substitutions:
@@ -193,6 +210,7 @@ Type = (
| UnitType
| Function
| ComplexType
| ExtensionType
| TypeVar
| GenericType
| AppliedType