feat(checker): add aggregation ops on column groupby

This commit is contained in:
2026-07-03 11:25:06 +02:00
parent 20173a0b07
commit 733c8736b8
4 changed files with 1949 additions and 95 deletions

View File

@@ -7,7 +7,7 @@ import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import ColumnGroupBy, Function, Type
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@@ -28,37 +28,46 @@ class Call:
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
@method()
def mean(self, call: Call) -> Type:
bool_ = self.types.get_type("bool")
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self,
call: Call,
args: list[str | tuple[str, str, bool]] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
real_args: list[Function.Argument] = []
for i, arg in enumerate(args):
match arg:
case str() as name:
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_args.append(arg)
signature = Function(
args=[
Function.Argument(
pos=0,
name="numeric_only",
type=bool_,
required=False,
),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
returns=call.groupby.column,
args=real_args,
returns=(
call.groupby.column
if preserve_inner_type
else ColumnType(type=TopType())
),
)
result: CallResult = self.dispatcher.get_result(
@@ -68,3 +77,127 @@ class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
keywords=call.keywords,
)
return result.result
@method()
def kurt(self, call: Call) -> Type:
    """Type-check a ``kurt()`` call made on a column group-by.

    The optional arguments accepted are, in positional order, ``skipna``
    and ``numeric_only``; both are given by name only, so ``_aggregate``
    resolves their types itself. ``preserve_inner_type`` is left at its
    default.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call.
    """
    optional_args = ["skipna", "numeric_only"]
    return self._aggregate(call, optional_args)
@method()
def max(self, call: Call) -> Type:
    """Type-check a ``max()`` call made on a column group-by.

    Optional arguments, in positional order: ``numeric_only``,
    ``min_count`` (declared explicitly as an optional ``int``),
    ``skipna``, ``engine`` and ``engine_kwargs``.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call, requested
        with ``preserve_inner_type=True``.
    """
    # Explicit (name, type name, required) triple; the plain-string entries
    # are resolved by `_aggregate`.
    min_count = ("min_count", "int", False)
    optional_args = [
        "numeric_only",
        min_count,
        "skipna",
        "engine",
        "engine_kwargs",
    ]
    return self._aggregate(call, optional_args, preserve_inner_type=True)
@method()
def mean(self, call: Call) -> Type:
    """Type-check a ``mean()`` call made on a column group-by.

    Optional arguments, in positional order: ``numeric_only``,
    ``skipna``, ``engine`` and ``engine_kwargs``. ``preserve_inner_type``
    is left at its default.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call.
    """
    optional_args = [
        "numeric_only",
        "skipna",
        "engine",
        "engine_kwargs",
    ]
    return self._aggregate(call, optional_args)
@method()
def median(self, call: Call) -> Type:
    """Type-check a ``median()`` call made on a column group-by.

    Optional arguments, in positional order: ``numeric_only`` and
    ``skipna``.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call, requested
        with ``preserve_inner_type=True``.
    """
    optional_args = ["numeric_only", "skipna"]
    return self._aggregate(call, optional_args, preserve_inner_type=True)
@method()
def min(self, call: Call) -> Type:
    """Type-check a ``min()`` call made on a column group-by.

    Optional arguments, in positional order: ``numeric_only``,
    ``min_count`` (declared explicitly as an optional ``int``),
    ``skipna``, ``engine`` and ``engine_kwargs``.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call, requested
        with ``preserve_inner_type=True``.
    """
    # Explicit (name, type name, required) triple; the plain-string entries
    # are resolved by `_aggregate`.
    min_count = ("min_count", "int", False)
    optional_args = [
        "numeric_only",
        min_count,
        "skipna",
        "engine",
        "engine_kwargs",
    ]
    return self._aggregate(call, optional_args, preserve_inner_type=True)
@method()
def prod(self, call: Call) -> Type:
    """Type-check a ``prod()`` call made on a column group-by.

    Optional arguments, in positional order: ``numeric_only``,
    ``min_count`` (declared explicitly as an optional ``int``) and
    ``skipna``. ``preserve_inner_type`` is left at its default.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call.
    """
    # Explicit (name, type name, required) triple; the plain-string entries
    # are resolved by `_aggregate`.
    min_count = ("min_count", "int", False)
    optional_args = ["numeric_only", min_count, "skipna"]
    return self._aggregate(call, optional_args)
@method()
def std(self, call: Call) -> Type:
    """Type-check a ``std()`` call made on a column group-by.

    Optional arguments, in positional order: ``ddof`` (declared
    explicitly as an optional ``int``), ``engine``, ``engine_kwargs``,
    ``numeric_only`` and ``skipna``. ``preserve_inner_type`` is left at
    its default.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call.
    """
    # Explicit (name, type name, required) triple; the plain-string entries
    # are resolved by `_aggregate`.
    ddof = ("ddof", "int", False)
    optional_args = [
        ddof,
        "engine",
        "engine_kwargs",
        "numeric_only",
        "skipna",
    ]
    return self._aggregate(call, optional_args)
@method()
def sum(self, call: Call) -> Type:
    """Type-check a ``sum()`` call made on a column group-by.

    Optional arguments, in positional order: ``numeric_only``,
    ``min_count`` (declared explicitly as an optional ``int``),
    ``skipna``, ``engine`` and ``engine_kwargs``. ``preserve_inner_type``
    is left at its default.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call.
    """
    # Explicit (name, type name, required) triple; the plain-string entries
    # are resolved by `_aggregate`.
    min_count = ("min_count", "int", False)
    optional_args = [
        "numeric_only",
        min_count,
        "skipna",
        "engine",
        "engine_kwargs",
    ]
    return self._aggregate(call, optional_args)
@method()
def var(self, call: Call) -> Type:
    """Type-check a ``var()`` call made on a column group-by.

    Optional arguments, in positional order: ``ddof`` (declared
    explicitly as an optional ``int``), ``engine``, ``engine_kwargs``,
    ``numeric_only`` and ``skipna`` -- the same list the sibling ``std``
    method registers. ``preserve_inner_type`` is left at its default.

    Args:
        call: The group-by method call being checked.

    Returns:
        Whatever type ``_aggregate`` computes for the call.
    """
    # The leading argument is the delta degrees of freedom. It used to be
    # registered under the name "var" (the method's own name), which left
    # the `ddof` keyword out of the signature handed to `_aggregate`.
    ddof = ("ddof", "int", False)
    return self._aggregate(
        call,
        [
            ddof,
            "engine",
            "engine_kwargs",
            "numeric_only",
            "skipna",
        ],
    )

View File

@@ -160,7 +160,13 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
def eq(self, call: Call) -> Type:
    """Type-check an ``eq`` call by delegating to ``_element_wise``.

    Args:
        call: The column method call being checked.

    Returns:
        The type ``_element_wise`` yields when given the dunder name
        ``"__eq__"``.
    """
    dunder_name = "__eq__"
    return self._element_wise(call, dunder_name)
def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
def _aggregate(
self,
call: Call,
kwargs: list[Function.Argument] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
signature = Function(
kw_args=[
Function.Argument(
@@ -171,7 +177,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
),
*kwargs,
],
returns=ColumnType(type=TopType()),
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
)
result: CallResult = self.dispatcher.get_result(
@@ -184,35 +190,35 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that `_aggregate(call)` is the intended body.
return self._statistical(call)
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that the `_aggregate` call with
# `preserve_inner_type=True` is the intended body.
return self._statistical(call)
return self._aggregate(call, preserve_inner_type=True)
@method()
def mean(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that `_aggregate(call)` is the intended body.
return self._statistical(call)
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that the `_aggregate` call with
# `preserve_inner_type=True` is the intended body.
return self._statistical(call)
return self._aggregate(call, preserve_inner_type=True)
@method()
def min(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that the `_aggregate` call with
# `preserve_inner_type=True` is the intended body.
return self._statistical(call)
return self._aggregate(call, preserve_inner_type=True)
@method()
def mode(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that the `_aggregate` call with
# `preserve_inner_type=True` is the intended body.
return self._statistical(call)
return self._aggregate(call, preserve_inner_type=True)
@method("product", "prod")
def product(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that `_aggregate(call)` is the intended body.
return self._statistical(call)
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._statistical(
return self._aggregate(
call,
[
Function.Argument(
@@ -226,11 +232,11 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
@method()
def sum(self, call: Call) -> Type:
# NOTE(review): two consecutive `return` statements follow. As written, only
# the first one (`_statistical`) can execute; the `_aggregate` call after it
# is unreachable. This looks like the removed/added line pair of a diff shown
# without +/- markers -- confirm that `_aggregate(call)` is the intended body.
return self._statistical(call)
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._statistical(
return self._aggregate(
call,
[
Function.Argument(

View File

@@ -38,14 +38,64 @@ _ = df1.sum()
_ = df1.var()
# Groupby
gb = df1.groupby(by="a")
df_gb = df1.groupby(by="a")
_ = gb.kurt()
_ = gb.max()
_ = gb.mean()
_ = gb.median()
_ = gb.min()
_ = gb.prod()
_ = gb.std()
_ = gb.sum()
_ = gb.var()
_ = df_gb.kurt()
_ = df_gb.max()
_ = df_gb.mean()
_ = df_gb.median()
_ = df_gb.min()
_ = df_gb.prod()
_ = df_gb.std()
_ = df_gb.sum()
_ = df_gb.var()
# Columns
col1 = df1["a"]
col2 = df1["a"]
# Arithmetic
_ = col1 + col2
_ = col1 - col2
_ = col1 * col2
_ = col1 / col2
_ = col1 // col2
_ = col1 % col2
_ = col1**col2
# Comparisons
_ = col1 < col2
_ = col1 > col2
_ = col1 <= col2
_ = col1 >= col2
_ = col1 != col2
_ = col1 == col2
# Aggregate
_ = col1.kurt()
_ = col1.kurtosis()
_ = col1.max()
_ = col1.mean()
_ = col1.median()
_ = col1.min()
_ = col1.mode()
_ = col1.prod()
_ = col1.product()
_ = col1.std()
_ = col1.sum()
_ = col1.var()
# Groupby
col_gb = col1.groupby(level=0)
_ = col_gb.kurt()
_ = col_gb.max()
_ = col_gb.mean()
_ = col_gb.median()
_ = col_gb.min()
_ = col_gb.prod()
_ = col_gb.std()
_ = col_gb.sum()
_ = col_gb.var()

File diff suppressed because it is too large Load Diff