diff --git a/midas/checker/midas.py b/midas/checker/midas.py index decb40c..6a528a0 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -8,6 +8,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.types import ( AliasType, ComplexType, + ExtensionType, Function, GenericType, Type, @@ -76,7 +77,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type self.types.define_type(name, type) self._local_variables.clear() - def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... + def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ... def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: self._resolve_type_params(stmt.params) @@ -126,16 +127,21 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type # TODO return UnknownType() - def visit_complex_type(self, type: m.ComplexType) -> Type: + def visit_complex_type(self, type: m.ComplexType) -> ComplexType: return ComplexType( - properties={ - prop.name.lexeme: prop.type.accept(self) for prop in type.properties + members={ + member.name.lexeme: member.type.accept(self) for member in type.members } ) + def visit_extension_type(self, type: m.ExtensionType) -> Type: + return ExtensionType( + base=type.base.accept(self), + extension=self.visit_complex_type(type.extension), + ) + def visit_function_type(self, type: m.FunctionType) -> Type: return Function( - name="", pos_args=[ Function.Argument( pos=i, diff --git a/midas/checker/python.py b/midas/checker/python.py index c11ee22..3d44975 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -12,6 +12,7 @@ from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver from midas.checker.types import ( ComplexType, + ExtensionType, Function, Operation, Type, @@ -192,7 +193,6 @@ class PythonTyper( returns_hint = stmt.returns.accept(self) # Early define to handle simple fully-typed recursion inside_function: Function = Function( - name=stmt.name, pos_args=pos_args, args=args, kw_args=kw_args, @@ -227,7 +227,6 @@ class PythonTyper( # TODO: handle *args and **kwargs sinks function: Function = Function( - name=stmt.name, pos_args=pos_args, args=args, kw_args=kw_args, @@ -250,8 +249,9 @@ class PythonTyper( case p.VariableExpr(): self._assign_var(location, target, value_type) - case p.GetExpr(): - self._assign_attr(location, target, value_type) + case p.GetExpr(object=object, name=name): + object_type: Type = self.type_of(object) + self._assign_attr(location, object_type, name, value_type) case _: if not isinstance(target, p.VariableExpr): @@ -276,32 +276,43 @@ class PythonTyper( f"Cannot assign {value_type} to variable '{name}' of type {var_type}", ) - def _assign_attr(self, location: Location, target: p.GetExpr, value_type: Type): - object: Type = self.type_of(target.object) + def _assign_attr( + self, location: Location, object: Type, name: str, value_type: Type + ): + # TODO: improve recursion to have better error messages base_object: Type = unfold_type(object) match base_object: - case ComplexType(properties=properties): - if target.name not in properties: + case ComplexType(members=members): + if name not in members: + self.reporter.error(location, f"Unknown member '{object}.{name}'") + return + + member_type: Type = members[name] + if not self.is_subtype(value_type, member_type): self.reporter.error( - target.location, f"Unknown property '{object}.{target.name}'" + location, + f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", ) return - prop_type: Type = properties[target.name] - if not self.is_subtype(value_type, prop_type): - self.reporter.error( - location, - f"Cannot assign {value_type} to property '{object}.{target.name}' of type {prop_type}", - ) - return + case ExtensionType(base=base, extension=ComplexType(members=members)): + if name in members: + member_type: Type = members[name] + if not self.is_subtype(value_type, member_type): + self.reporter.error( + location, + f"Cannot assign {value_type} to member '{object}.{name}' of type {member_type}", + ) + return + return self._assign_attr(location, base, name, value_type) case UnknownType(): pass case _: self.reporter.error( - target.location, - f"Cannot assign {value_type} to unknown property '{object}.{target.name}'", + location, + f"Cannot assign {value_type} to unknown property '{object}.{name}'", ) def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: @@ -422,23 +433,37 @@ class PythonTyper( def visit_get_expr(self, expr: p.GetExpr) -> Type: object: Type = self.type_of(expr.object) + member: Optional[Type] = self._get_member(object, expr.name) + if member is None: + self.reporter.error( + expr.location, f"Unknown property '{expr.name}' on {object}" + ) + return UnknownType() + self.logger.debug(f"Property '{expr.name}' on {object} has type {member}") + return member + + def _get_member(self, object: Type, name: str) -> Optional[Type]: base_object: Type = unfold_type(object) match base_object: - case ComplexType(properties=properties): - if expr.name not in properties: - self.reporter.error( - expr.location, f"Unknown property '{expr.name} on {object}" - ) - return UnknownType() - return properties[expr.name] + case ComplexType(members=members): + if name in members: + return members[name] + self.logger.debug(f"No property '{name}' in {base_object}") + return None + + case ExtensionType(base=base, extension=ComplexType(members=members)): + if name in members: + return members[name] + self.logger.debug( + f"No property '{name}' on {base_object}, looking up in base" + ) + return self._get_member(base, name) case UnknownType(): return UnknownType() case _: - self.reporter.error( - expr.location, f"Cannot get property '{expr.name}' on {object}" - ) + self.logger.debug(f"Can't get property on {base_object}") return UnknownType() def visit_literal_expr(self, expr: p.LiteralExpr) -> Type: diff --git a/midas/checker/types.py b/midas/checker/types.py index 41ad786..9057e4c 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -35,7 +35,6 @@ class UnitType: @dataclass(frozen=True, kw_only=True) class Function: - name: str pos_args: list[Argument] args: list[Argument] kw_args: list[Argument] @@ -56,7 +55,7 @@ class Function: args.append("*") args += list(map(str, self.kw_args)) - return f"{self.name}({', '.join(args)}) -> {self.returns}" + return f"({', '.join(args)}) -> {self.returns}" @dataclass(frozen=True, kw_only=True) class Argument: @@ -72,13 +71,22 @@ class Function: @dataclass(frozen=True, kw_only=True) class ComplexType: - properties: dict[str, Type] + members: dict[str, Type] def __str__(self) -> str: - props: list[str] = [f"{name}: {type}" for name, type in self.properties.items()] + props: list[str] = [f"{name}: {type}" for name, type in self.members.items()] return f"{{{', '.join(props)}}}" +@dataclass(frozen=True, kw_only=True) +class ExtensionType: + base: Type + extension: ComplexType + + def __str__(self) -> str: + return f"{self.base} & {self.extension}" + + @dataclass(frozen=True, kw_only=True) class Operation: signature: CallSignature @@ -145,26 +153,35 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: return AliasType(name=name, type=substitute_typevars(type2, substitutions)) case Function( - name=name, pos_args=pos_args, args=args, kw_args=kw_args, returns=returns, ): return Function( - name=name, pos_args=list(map(sub_argument, pos_args)), args=list(map(sub_argument, args)), kw_args=list(map(sub_argument, kw_args)), returns=substitute_typevars(returns, substitutions), ) - case ComplexType(properties=properties): - properties2: dict[str, Type] = { + case ComplexType(members=members): + members2: dict[str, Type] = { name: substitute_typevars(prop, substitutions) - for name, prop in properties.items() + for name, prop in members.items() } - return ComplexType(properties=properties2) + return ComplexType(members=members2) + + case ExtensionType(base=base, extension=ComplexType(members=members)): + return ExtensionType( + base=substitute_typevars(base, substitutions), + extension=ComplexType( + members={ + name: substitute_typevars(prop, substitutions) + for name, prop in members.items() + } + ), + ) case TypeVar(name=name): if name in substitutions: @@ -193,6 +210,7 @@ Type = ( | UnitType | Function | ComplexType + | ExtensionType | TypeVar | GenericType | AppliedType