feat(checker): add same length assertion on frames
safely adding two dataframes is only possible if the sizes are the same, or null values could be added dynamically to pad the shortest dataframe
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
@@ -39,6 +40,7 @@ def frame_method(*names: str):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
frame: DataFrameType
|
||||
frame_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
@@ -174,6 +176,11 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if result.is_valid:
|
||||
self._assert_same_length(
|
||||
call.call_expr, call.frame_expr, call.positional[0][0]
|
||||
)
|
||||
|
||||
return result.result
|
||||
|
||||
@frame_method()
|
||||
@@ -214,3 +221,50 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg="frame1"),
|
||||
ast.arg(arg="frame2"),
|
||||
],
|
||||
kwonlyargs=[],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
),
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="frame1"),
|
||||
attr="size",
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="frame2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
decorator_list=[],
|
||||
),
|
||||
)
|
||||
self.assertions.add(
|
||||
bound_expr=call_expr,
|
||||
inputs=[frame1, frame2],
|
||||
builder=lambda f1, f2: ast.Call(
|
||||
func=ast.Name(id=func_name),
|
||||
args=[f1, f2],
|
||||
keywords=[],
|
||||
),
|
||||
message="DataFrames must have the same length",
|
||||
)
|
||||
|
||||
@@ -141,6 +141,7 @@ class FrameManager:
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
frame: DataFrameType,
|
||||
frame_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
@@ -148,6 +149,7 @@ class FrameManager:
|
||||
) -> Type:
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=frame,
|
||||
frame_expr=frame_expr,
|
||||
positional=positional,
|
||||
|
||||
@@ -212,6 +212,7 @@ class PythonTyper(
|
||||
def call_method(
|
||||
self,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
obj: TypedExpr,
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
@@ -223,6 +224,7 @@ class PythonTyper(
|
||||
return self.frame_mgr.call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
frame=unfolded,
|
||||
frame_expr=obj[0],
|
||||
positional=positional,
|
||||
@@ -503,7 +505,9 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
return self._visit_binary_expr(
|
||||
expr.location, expr, expr.left, expr.right, method
|
||||
)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
@@ -514,10 +518,17 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
return self._visit_binary_expr(
|
||||
expr.location, expr, expr.left, expr.right, method
|
||||
)
|
||||
|
||||
def _visit_binary_expr(
|
||||
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
|
||||
self,
|
||||
location: Location,
|
||||
expr: p.Expr,
|
||||
left_expr: p.Expr,
|
||||
right_expr: p.Expr,
|
||||
method: str,
|
||||
) -> Type:
|
||||
left: Type = self.type_of(left_expr)
|
||||
right: Type = self.type_of(right_expr)
|
||||
@@ -525,7 +536,12 @@ class PythonTyper(
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
location, (left_expr, left), method, [(right_expr, right)], {}
|
||||
location=location,
|
||||
call_expr=expr,
|
||||
obj=(left_expr, left),
|
||||
method_name=method,
|
||||
positional=[(right_expr, right)],
|
||||
keywords={},
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
@@ -550,7 +566,12 @@ class PythonTyper(
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
expr.location, (expr.right, operand), method, [], {}
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(expr.right, operand),
|
||||
method_name=method,
|
||||
positional=[],
|
||||
keywords={},
|
||||
)
|
||||
except UndefinedMethodException:
|
||||
self.reporter.error(
|
||||
@@ -581,6 +602,7 @@ class PythonTyper(
|
||||
return self.frame_mgr.call(
|
||||
method=method,
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
frame=unfolded,
|
||||
frame_expr=obj,
|
||||
positional=positional,
|
||||
|
||||
Reference in New Issue
Block a user