feat(checker): add same length assertion on frames

safely adding two dataframes is only possible if the sizes are the same, or null values could be added dynamically to pad the shortest dataframe
This commit is contained in:
2026-07-02 17:14:05 +02:00
parent 8df01afd8c
commit ff69b65171
3 changed files with 83 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
@@ -39,6 +40,7 @@ def frame_method(*names: str):
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
frame: DataFrameType
frame_expr: p.Expr
positional: list[TypedExpr]
@@ -174,6 +176,11 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
positional=call.positional,
keywords=call.keywords,
)
if result.is_valid:
self._assert_same_length(
call.call_expr, call.frame_expr, call.positional[0][0]
)
return result.result
@frame_method()
@@ -214,3 +221,50 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
keywords=call.keywords,
)
return result.result
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__"
self.assertions.define(
func_name,
ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg="frame1"),
ast.arg(arg="frame2"),
],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Return(
value=ast.Compare(
left=ast.Attribute(
value=ast.Name(id="frame1"),
attr="size",
),
ops=[ast.Eq()],
comparators=[
ast.Attribute(
value=ast.Name(id="frame2"),
attr="size",
)
],
)
)
],
decorator_list=[],
),
)
self.assertions.add(
bound_expr=call_expr,
inputs=[frame1, frame2],
builder=lambda f1, f2: ast.Call(
func=ast.Name(id=func_name),
args=[f1, f2],
keywords=[],
),
message="DataFrames must have the same length",
)

View File

@@ -141,6 +141,7 @@ class FrameManager:
self,
method: str,
location: Location,
call_expr: p.Expr,
frame: DataFrameType,
frame_expr: p.Expr,
positional: list[TypedExpr],
@@ -148,6 +149,7 @@ class FrameManager:
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
frame=frame,
frame_expr=frame_expr,
positional=positional,

View File

@@ -212,6 +212,7 @@ class PythonTyper(
def call_method(
self,
location: Location,
call_expr: p.Expr,
obj: TypedExpr,
method_name: str,
positional: list[TypedExpr],
@@ -223,6 +224,7 @@ class PythonTyper(
return self.frame_mgr.call(
method=method_name,
location=location,
call_expr=call_expr,
frame=unfolded,
frame_expr=obj[0],
positional=positional,
@@ -503,7 +505,9 @@ class PythonTyper(
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
return self._visit_binary_expr(
expr.location, expr, expr.left, expr.right, method
)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
@@ -514,10 +518,17 @@ class PythonTyper(
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
return self._visit_binary_expr(
expr.location, expr, expr.left, expr.right, method
)
def _visit_binary_expr(
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
self,
location: Location,
expr: p.Expr,
left_expr: p.Expr,
right_expr: p.Expr,
method: str,
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
@@ -525,7 +536,12 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(
location, (left_expr, left), method, [(right_expr, right)], {}
location=location,
call_expr=expr,
obj=(left_expr, left),
method_name=method,
positional=[(right_expr, right)],
keywords={},
)
except UndefinedMethodException:
self.reporter.error(
@@ -550,7 +566,12 @@ class PythonTyper(
result: Optional[Type]
try:
result = self.call_method(
expr.location, (expr.right, operand), method, [], {}
location=expr.location,
call_expr=expr,
obj=(expr.right, operand),
method_name=method,
positional=[],
keywords={},
)
except UndefinedMethodException:
self.reporter.error(
@@ -581,6 +602,7 @@ class PythonTyper(
return self.frame_mgr.call(
method=method,
location=expr.location,
call_expr=expr,
frame=unfolded,
frame_expr=obj,
positional=positional,