From b14f46d4056707e13d7b1a139c15e8319ca8dc91 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 2 Jul 2026 19:53:58 +0200 Subject: [PATCH] feat(checker): handle calls on group-bys --- .../checker/frames/column_groupby_methods.py | 66 +++++++++++++++++++ midas/checker/frames/column_manager.py | 27 +++++++- midas/checker/frames/frame_groupby_methods.py | 66 +++++++++++++++++++ midas/checker/frames/frame_manager.py | 25 +++++++ midas/checker/python.py | 35 ++++++++-- 5 files changed, 212 insertions(+), 7 deletions(-) create mode 100644 midas/checker/frames/column_groupby_methods.py create mode 100644 midas/checker/frames/frame_groupby_methods.py diff --git a/midas/checker/frames/column_groupby_methods.py b/midas/checker/frames/column_groupby_methods.py new file mode 100644 index 0000000..de10c84 --- /dev/null +++ b/midas/checker/frames/column_groupby_methods.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.dispatcher import CallResult +from midas.checker.frames.utils import MethodRegistry, method +from midas.checker.types import ColumnGroupBy, Function, Type + +if TYPE_CHECKING: + from midas.checker.python import TypedExpr + + +@dataclass(frozen=True, kw_only=True) +class Call: + location: Location + call_expr: p.Expr + groupby: ColumnGroupBy + groupby_expr: p.Expr + positional: list[TypedExpr] + keywords: dict[str, TypedExpr] + + +class ColumnGroupByMethodRegistry(MethodRegistry[Call]): + @method() + def mean(self, call: Call) -> Type: + bool_ = self.types.get_type("bool") + signature = Function( + args=[ + Function.Argument( + pos=0, + name="numeric_only", + type=bool_, + required=False, + ), + Function.Argument( + pos=1, + name="skipna", + type=bool_, + required=False, + ), + Function.Argument( + pos=2, + name="engine", + type=self.types.get_type("str"), + required=False, + ), + Function.Argument( + pos=3, + name="engine_kwargs", + type=self.types.get_type("dict"), + required=False, + ), + ], + returns=call.groupby.column, + ) + + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=signature, + positional=call.positional, + keywords=call.keywords, + ) + return result.result diff --git a/midas/checker/frames/column_manager.py b/midas/checker/frames/column_manager.py index 1ff9799..0793034 100644 --- a/midas/checker/frames/column_manager.py +++ b/midas/checker/frames/column_manager.py @@ -4,8 +4,10 @@ from typing import TYPE_CHECKING import midas.ast.python as p from midas.ast.location import Location +from midas.checker.frames.column_groupby_methods import Call as GroupByCall +from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry from midas.checker.frames.column_methods import Call, ColumnMethodRegistry -from midas.checker.types import ColumnType, Type +from midas.checker.types import ColumnGroupBy, ColumnType, Type if TYPE_CHECKING: from midas.checker.python import PythonTyper, TypedExpr @@ -15,6 +17,9 @@ class ColumnManager: def __init__(self, typer: PythonTyper) -> None: self.typer: PythonTyper = typer self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer) + self.groupby_method_resolver: ColumnGroupByMethodRegistry = ( + ColumnGroupByMethodRegistry(self.typer) + ) def call( self, @@ -35,3 +40,23 @@ class ColumnManager: keywords=keywords, ) return self.method_resolver.call(method, call) + + def groupby_call( + self, + method: str, + location: Location, + call_expr: p.Expr, + groupby: ColumnGroupBy, + groupby_expr: p.Expr, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Type: + call: GroupByCall = GroupByCall( + location=location, + call_expr=call_expr, + groupby=groupby, + groupby_expr=groupby_expr, + positional=positional, + keywords=keywords, + ) + return self.groupby_method_resolver.call(method, call) diff --git a/midas/checker/frames/frame_groupby_methods.py b/midas/checker/frames/frame_groupby_methods.py new file mode 100644 index 0000000..4b0acc8 --- /dev/null +++ b/midas/checker/frames/frame_groupby_methods.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import midas.ast.python as p +from midas.ast.location import Location +from midas.checker.dispatcher import CallResult +from midas.checker.frames.utils import MethodRegistry, method +from midas.checker.types import FrameGroupBy, Function, Type + +if TYPE_CHECKING: + from midas.checker.python import TypedExpr + + +@dataclass(frozen=True, kw_only=True) +class Call: + location: Location + call_expr: p.Expr + groupby: FrameGroupBy + groupby_expr: p.Expr + positional: list[TypedExpr] + keywords: dict[str, TypedExpr] + + +class FrameGroupByMethodRegistry(MethodRegistry[Call]): + @method() + def mean(self, call: Call) -> Type: + bool_ = self.types.get_type("bool") + signature = Function( + args=[ + Function.Argument( + pos=0, + name="numeric_only", + type=bool_, + required=False, + ), + Function.Argument( + pos=1, + name="skipna", + type=bool_, + required=False, + ), + Function.Argument( + pos=2, + name="engine", + type=self.types.get_type("str"), + required=False, + ), + Function.Argument( + pos=3, + name="engine_kwargs", + type=self.types.get_type("dict"), + required=False, + ), + ], + returns=call.groupby.frame, + ) + + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=signature, + positional=call.positional, + keywords=call.keywords, + ) + return result.result diff --git a/midas/checker/frames/frame_manager.py b/midas/checker/frames/frame_manager.py index 36fd84d..8a5794d 100644 --- a/midas/checker/frames/frame_manager.py +++ b/midas/checker/frames/frame_manager.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast import midas.ast.python as p from midas.ast.location import Location +from midas.checker.frames.frame_groupby_methods import Call as GroupByCall +from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry from midas.checker.frames.frame_methods import Call, FrameMethodRegistry from midas.checker.reporter import FileReporter from midas.checker.types import ( @@ -28,6 +30,9 @@ class FrameManager: def __init__(self, typer: PythonTyper) -> None: self.typer: PythonTyper = typer self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer) + self.groupby_method_resolver: FrameGroupByMethodRegistry = ( + FrameGroupByMethodRegistry(self.typer) + ) def assign( self, @@ -184,3 +189,23 @@ class FrameManager: keywords=keywords, ) return self.method_resolver.call(method, call) + + def groupby_call( + self, + method: str, + location: Location, + call_expr: p.Expr, + groupby: FrameGroupBy, + groupby_expr: p.Expr, + positional: list[TypedExpr], + keywords: dict[str, TypedExpr], + ) -> Type: + call: GroupByCall = GroupByCall( + location=location, + call_expr=call_expr, + groupby=groupby, + groupby_expr=groupby_expr, + positional=positional, + keywords=keywords, + ) + return self.groupby_method_resolver.call(method, call) diff --git a/midas/checker/python.py b/midas/checker/python.py index e014681..8e4b59b 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -23,6 +23,7 @@ from midas.checker.resolver import Resolver from midas.checker.types import ( AppliedType, BaseType, + ColumnGroupBy, ColumnType, ConstraintType, DataFrameType, @@ -235,6 +236,17 @@ class PythonTyper( keywords=keywords, ) + case FrameGroupBy(): + return self.frame_mgr.groupby_call( + method=method_name, + location=location, + call_expr=call_expr, + groupby=unfolded, + groupby_expr=obj[0], + positional=positional, + keywords=keywords, + ) + case ColumnType(): return self.column_mgr.call( method=method_name, @@ -246,6 +258,17 @@ class PythonTyper( keywords=keywords, ) + case ColumnGroupBy(): + return self.column_mgr.groupby_call( + method=method_name, + location=location, + call_expr=call_expr, + groupby=unfolded, + groupby_expr=obj[0], + positional=positional, + keywords=keywords, + ) + method: Optional[Type] = self.types.lookup_member(obj[1], method_name) if method is None: raise UndefinedMethodException @@ -612,17 +635,17 @@ class PythonTyper( match expr.callee: case p.GetExpr(object=obj, name=method): obj_type: Type = self.type_of(obj) - unfolded: Type = unfold_type(obj_type) - if isinstance(unfolded, DataFrameType): - return self.frame_mgr.call( - method=method, + return ( + self.call_method( location=expr.location, call_expr=expr, - frame=unfolded, - frame_expr=obj, + obj=(obj, obj_type), + method_name=method, positional=positional, keywords=keywords, ) + or UnknownType() + ) callee: Type = self.type_of(expr.callee) result: CallResult = self.dispatcher.get_result(