diff --git a/examples/01_simple_type_checking/01_simple_operations.py b/examples/01_simple_type_checking/01_simple_operations.py index a3ac707..4e767f2 100644 --- a/examples/01_simple_type_checking/01_simple_operations.py +++ b/examples/01_simple_type_checking/01_simple_operations.py @@ -9,3 +9,5 @@ d = True e = d + d f: float = a + +f = -f diff --git a/examples/01_simple_type_checking/02_simple_types.midas b/examples/01_simple_type_checking/02_simple_types.midas index 6a1a6a2..ff4edb1 100644 --- a/examples/01_simple_type_checking/02_simple_types.midas +++ b/examples/01_simple_type_checking/02_simple_types.midas @@ -3,12 +3,12 @@ type Second = float type MeterPerSecond = float extend Meter { - op __add__(Meter) -> Meter - op __sub__(Meter) -> Meter - op __truediv__(Second) -> MeterPerSecond + def __add__: fn(Meter, /) -> Meter + def __sub__: fn(Meter, /) -> Meter + def __truediv__: fn(Second, /) -> MeterPerSecond } extend Second { - op __add__(Second) -> Second - op __sub__(Second) -> Second + def __add__: fn(Second, /) -> Second + def __sub__: fn(Second, /) -> Second } diff --git a/midas/checker/operators.py b/midas/checker/operators.py index e65ab07..58af88c 100644 --- a/midas/checker/operators.py +++ b/midas/checker/operators.py @@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = { # ast.In: "__in__", # ast.NotIn: "__notin__", } + +UNARY_METHODS: dict[Type[ast.unaryop], str] = { + ast.Invert: "__invert__", + # ast.Not: "", + ast.UAdd: "__pos__", + ast.USub: "__neg__", +} diff --git a/midas/checker/python.py b/midas/checker/python.py index 5149bb7..9c788c8 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -6,7 +6,7 @@ from typing import Optional import midas.ast.python as p from midas.ast.location import Location from midas.checker.environment import Environment -from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS +from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter, Reporter from midas.checker.resolver import Resolver @@ -376,8 +376,37 @@ class PythonTyper( return UnknownType() def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: - self.reporter.warning(expr.location, "UnaryExpr not yet supported") - return UnknownType() + method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__) + if method is None: + self.logger.warning(f"Unsupported operator {expr.operator}") + self.reporter.warning( + expr.location, f"Unsupported operator {expr.operator}" + ) + return UnknownType() + + operand: Type = self.type_of(expr.right) + operation: Optional[Type] = self.types.lookup_member(operand, method) + if operation is None: + self.reporter.error( + expr.location, + f"Undefined operation {method} for {operand}", + ) + return UnknownType() + + match operation: + case Function() as function: + if not self._is_unary_function(function): + self.reporter.error( + expr.location, + f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}", + ) + return UnknownType() + return function.returns + case _: + self.reporter.warning( + expr.location, f"Unsupported operation {operation}" + ) + return UnknownType() def visit_call_expr(self, expr: p.CallExpr) -> Type: callee: Type = self.type_of(expr.callee) @@ -633,3 +662,12 @@ class PythonTyper( if len(function.kw_args) != 0: return False return True + + def _is_unary_function(self, function: Function) -> bool: + if len(function.pos_args) != 0: + return False + if len(function.args) != 0: + return False + if len(function.kw_args) != 0: + return False + return True