feat(checker): handle calls on group-bys

This commit is contained in:
2026-07-02 19:53:58 +02:00
parent 8e8ed62266
commit b14f46d405
5 changed files with 212 additions and 7 deletions

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import ColumnGroupBy, Function, Type
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: ColumnGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
@method()
def mean(self, call: Call) -> Type:
bool_ = self.types.get_type("bool")
signature = Function(
args=[
Function.Argument(
pos=0,
name="numeric_only",
type=bool_,
required=False,
),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
returns=call.groupby.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result

View File

@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
from midas.checker.types import ColumnType, Type
from midas.checker.types import ColumnGroupBy, ColumnType, Type
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@@ -15,6 +17,9 @@ class ColumnManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
ColumnGroupByMethodRegistry(self.typer)
)
def call(
self,
@@ -35,3 +40,23 @@ class ColumnManager:
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: ColumnGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import FrameGroupBy, Function, Type
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
groupby: FrameGroupBy
groupby_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
@method()
def mean(self, call: Call) -> Type:
bool_ = self.types.get_type("bool")
signature = Function(
args=[
Function.Argument(
pos=0,
name="numeric_only",
type=bool_,
required=False,
),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
returns=call.groupby.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result

View File

@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
@@ -28,6 +30,9 @@ class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
FrameGroupByMethodRegistry(self.typer)
)
def assign(
self,
@@ -184,3 +189,23 @@ class FrameManager:
keywords=keywords,
)
return self.method_resolver.call(method, call)
def groupby_call(
self,
method: str,
location: Location,
call_expr: p.Expr,
groupby: FrameGroupBy,
groupby_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: GroupByCall = GroupByCall(
location=location,
call_expr=call_expr,
groupby=groupby,
groupby_expr=groupby_expr,
positional=positional,
keywords=keywords,
)
return self.groupby_method_resolver.call(method, call)

View File

@@ -23,6 +23,7 @@ from midas.checker.resolver import Resolver
from midas.checker.types import (
AppliedType,
BaseType,
ColumnGroupBy,
ColumnType,
ConstraintType,
DataFrameType,
@@ -235,6 +236,17 @@ class PythonTyper(
keywords=keywords,
)
case FrameGroupBy():
return self.frame_mgr.groupby_call(
method=method_name,
location=location,
call_expr=call_expr,
groupby=unfolded,
groupby_expr=obj[0],
positional=positional,
keywords=keywords,
)
case ColumnType():
return self.column_mgr.call(
method=method_name,
@@ -246,6 +258,17 @@ class PythonTyper(
keywords=keywords,
)
case ColumnGroupBy():
return self.column_mgr.groupby_call(
method=method_name,
location=location,
call_expr=call_expr,
groupby=unfolded,
groupby_expr=obj[0],
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
if method is None:
raise UndefinedMethodException
@@ -612,17 +635,17 @@ class PythonTyper(
match expr.callee:
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method=method,
return (
self.call_method(
location=expr.location,
call_expr=expr,
frame=unfolded,
frame_expr=obj,
obj=(obj, obj_type),
method_name=method,
positional=positional,
keywords=keywords,
)
or UnknownType()
)
callee: Type = self.type_of(expr.callee)
result: CallResult = self.dispatcher.get_result(