diff --git a/midas/checker/frame_methods.py b/midas/checker/frame_methods.py index c2935f8..7472bf4 100644 --- a/midas/checker/frame_methods.py +++ b/midas/checker/frame_methods.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional @@ -39,6 +40,7 @@ def frame_method(*names: str): @dataclass(frozen=True, kw_only=True) class Call: location: Location + call_expr: p.Expr frame: DataFrameType frame_expr: p.Expr positional: list[TypedExpr] @@ -174,6 +176,11 @@ class MethodRegistry(metaclass=_MethodRegistryMeta): positional=call.positional, keywords=call.keywords, ) + if result.is_valid: + self._assert_same_length( + call.call_expr, call.frame_expr, call.positional[0][0] + ) + return result.result @frame_method() @@ -214,3 +221,50 @@ class MethodRegistry(metaclass=_MethodRegistryMeta): keywords=call.keywords, ) return result.result + + def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr): + func_name: str = "__midas_frame_same_length__" + self.assertions.define( + func_name, + ast.FunctionDef( + name=func_name, + args=ast.arguments( + posonlyargs=[], + args=[ + ast.arg(arg="frame1"), + ast.arg(arg="frame2"), + ], + kwonlyargs=[], + defaults=[], + kw_defaults=[], + ), + body=[ + ast.Return( + value=ast.Compare( + left=ast.Attribute( + value=ast.Name(id="frame1"), + attr="size", + ), + ops=[ast.Eq()], + comparators=[ + ast.Attribute( + value=ast.Name(id="frame2"), + attr="size", + ) + ], + ) + ) + ], + decorator_list=[], + ), + ) + self.assertions.add( + bound_expr=call_expr, + inputs=[frame1, frame2], + builder=lambda f1, f2: ast.Call( + func=ast.Name(id=func_name), + args=[f1, f2], + keywords=[], + ), + message="DataFrames must have the same length", + ) diff --git a/midas/checker/frames.py b/midas/checker/frames.py index 56c34c6..1a1a917 100644 --- a/midas/checker/frames.py +++ b/midas/checker/frames.py @@ -141,6 +141,7 @@ class FrameManager: self, method: str, location: Location, + call_expr: p.Expr, frame: DataFrameType, frame_expr: p.Expr, positional: list[TypedExpr], @@ -148,6 +149,7 @@ class FrameManager: ) -> Type: call: Call = Call( location=location, + call_expr=call_expr, frame=frame, frame_expr=frame_expr, positional=positional, diff --git a/midas/checker/python.py b/midas/checker/python.py index a42115f..2e0a07c 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -212,6 +212,7 @@ class PythonTyper( def call_method( self, location: Location, + call_expr: p.Expr, obj: TypedExpr, method_name: str, positional: list[TypedExpr], @@ -223,6 +224,7 @@ class PythonTyper( return self.frame_mgr.call( method=method_name, location=location, + call_expr=call_expr, frame=unfolded, frame_expr=obj[0], positional=positional, @@ -503,7 +505,9 @@ class PythonTyper( ) return UnknownType() - return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + return self._visit_binary_expr( + expr.location, expr, expr.left, expr.right, method + ) def visit_compare_expr(self, expr: p.CompareExpr) -> Type: method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__) @@ -514,10 +518,17 @@ class PythonTyper( ) return UnknownType() - return self._visit_binary_expr(expr.location, expr.left, expr.right, method) + return self._visit_binary_expr( + expr.location, expr, expr.left, expr.right, method + ) def _visit_binary_expr( - self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str + self, + location: Location, + expr: p.Expr, + left_expr: p.Expr, + right_expr: p.Expr, + method: str, ) -> Type: left: Type = self.type_of(left_expr) right: Type = self.type_of(right_expr) @@ -525,7 +536,12 @@ class PythonTyper( result: Optional[Type] try: result = self.call_method( - location, (left_expr, left), method, [(right_expr, right)], {} + location=location, + call_expr=expr, + obj=(left_expr, left), + method_name=method, + positional=[(right_expr, right)], + keywords={}, ) except UndefinedMethodException: self.reporter.error( @@ -550,7 +566,12 @@ class PythonTyper( result: Optional[Type] try: result = self.call_method( - expr.location, (expr.right, operand), method, [], {} + location=expr.location, + call_expr=expr, + obj=(expr.right, operand), + method_name=method, + positional=[], + keywords={}, ) except UndefinedMethodException: self.reporter.error( @@ -581,6 +602,7 @@ class PythonTyper( return self.frame_mgr.call( method=method, location=expr.location, + call_expr=expr, frame=unfolded, frame_expr=obj, positional=positional,