From ac620f318b5616232256b9c42cd069a39884a001 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Sat, 13 Jun 2026 18:44:49 +0200 Subject: [PATCH] feat(checker): type check subscripts --- .../04_complex_types.py | 5 ++ midas/checker/python.py | 62 ++++++++++++++----- midas/checker/resolver.py | 4 ++ 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/examples/01_simple_type_checking/04_complex_types.py b/examples/01_simple_type_checking/04_complex_types.py index ebe958f..4f21cc6 100644 --- a/examples/01_simple_type_checking/04_complex_types.py +++ b/examples/01_simple_type_checking/04_complex_types.py @@ -28,3 +28,8 @@ bar: list[list[Meter]] bar.append([p2.x]) foo2 = foo + foo + +a = foo[0] +b = bar[0][1] +c = bar[0][1][2] # invalid, not method __getitem__ on Meter +c = bar[""] # invalid, wrong index type diff --git a/midas/checker/python.py b/midas/checker/python.py index 9c788c8..45b9a33 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -356,7 +356,7 @@ class PythonTyper( match operation: case Function() as function: - if not self._is_binary_function(function): + if not self._check_arity(function, 1, 0, 0): self.reporter.error( location, f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}", @@ -395,7 +395,7 @@ class PythonTyper( match operation: case Function() as function: - if not self._is_unary_function(function): + if not self._check_arity(function, 0, 0, 0): self.reporter.error( expr.location, f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}", @@ -512,6 +512,41 @@ class PythonTyper( ) return self.types.apply_generic(list_type, [UnknownType()]) + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type: + object: Type = self.type_of(expr.object) + operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") + if operation is None: + self.reporter.error( + expr.location, + f"Undefined method __getitem__ on {object}", + ) + return UnknownType() + + index: Type = self.type_of(expr.index) + + match operation: + case Function() as function: + if not self._check_arity(function, 1, 0, 0): + self.reporter.error( + expr.location, + f"Wrong definition of __getitem__. Expected function with 1 positional-only parameters, got {function}", + ) + return UnknownType() + + index_arg: Function.Argument = function.pos_args[0] + if not self.is_subtype(index, index_arg.type): + self.reporter.error( + expr.location, + f"Wrong index type, expected {index_arg.type}, got {index}", + ) + return UnknownType() + return function.returns + case _: + self.reporter.warning( + expr.location, f"Unsupported operation {operation}" + ) + return UnknownType() + def visit_base_type(self, node: p.BaseType) -> Type: base: Type try: @@ -654,20 +689,17 @@ class PythonTyper( return mapped - def _is_binary_function(self, function: Function) -> bool: - if len(function.pos_args) != 1: + def _check_arity( + self, + function: Function, + n_pos: Optional[int] = None, + n_mixed: Optional[int] = None, + n_keyword: Optional[int] = None, + ) -> bool: + if n_pos is not None and len(function.pos_args) != n_pos: return False - if len(function.args) != 0: + if n_mixed is not None and len(function.args) != n_mixed: return False - if len(function.kw_args) != 0: - return False - return True - - def _is_unary_function(self, function: Function) -> bool: - if len(function.pos_args) != 0: - return False - if len(function.args) != 0: - return False - if len(function.kw_args) != 0: + if n_keyword is not None and len(function.kw_args) != n_keyword: return False return True diff --git a/midas/checker/resolver.py b/midas/checker/resolver.py index 02fcbbc..636ccfe 100644 --- a/midas/checker/resolver.py +++ b/midas/checker/resolver.py @@ -196,3 +196,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]): def visit_list_expr(self, expr: p.ListExpr) -> None: for item in expr.items: self.resolve(item) + + def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None: + self.resolve(expr.object) + self.resolve(expr.index)