feat(checker): add aggregation ops on column groupby
This commit is contained in:
@@ -7,7 +7,7 @@ import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import ColumnGroupBy, Function, Type
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
@@ -28,37 +28,46 @@ class Call:
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
args: list[str | tuple[str, str, bool]] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
real_args: list[Function.Argument] = []
|
||||
for i, arg in enumerate(args):
|
||||
match arg:
|
||||
case str() as name:
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_args.append(arg)
|
||||
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.column,
|
||||
args=real_args,
|
||||
returns=(
|
||||
call.groupby.column
|
||||
if preserve_inner_type
|
||||
else ColumnType(type=TopType())
|
||||
),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -68,3 +77,127 @@ class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["skipna", "numeric_only"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -160,7 +160,13 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
kwargs: list[Function.Argument] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
signature = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
@@ -171,7 +177,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -184,35 +190,35 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._statistical(
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
@@ -226,11 +232,11 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._statistical(call)
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._statistical(
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
|
||||
@@ -38,14 +38,64 @@ _ = df1.sum()
|
||||
_ = df1.var()
|
||||
|
||||
# Groupby
|
||||
gb = df1.groupby(by="a")
|
||||
df_gb = df1.groupby(by="a")
|
||||
|
||||
_ = gb.kurt()
|
||||
_ = gb.max()
|
||||
_ = gb.mean()
|
||||
_ = gb.median()
|
||||
_ = gb.min()
|
||||
_ = gb.prod()
|
||||
_ = gb.std()
|
||||
_ = gb.sum()
|
||||
_ = gb.var()
|
||||
_ = df_gb.kurt()
|
||||
_ = df_gb.max()
|
||||
_ = df_gb.mean()
|
||||
_ = df_gb.median()
|
||||
_ = df_gb.min()
|
||||
_ = df_gb.prod()
|
||||
_ = df_gb.std()
|
||||
_ = df_gb.sum()
|
||||
_ = df_gb.var()
|
||||
|
||||
|
||||
# Columns
|
||||
|
||||
col1 = df1["a"]
|
||||
col2 = df1["a"]
|
||||
|
||||
# Arithmetic
|
||||
_ = col1 + col2
|
||||
_ = col1 - col2
|
||||
_ = col1 * col2
|
||||
_ = col1 / col2
|
||||
_ = col1 // col2
|
||||
_ = col1 % col2
|
||||
_ = col1**col2
|
||||
|
||||
# Comparisons
|
||||
_ = col1 < col2
|
||||
_ = col1 > col2
|
||||
_ = col1 <= col2
|
||||
_ = col1 >= col2
|
||||
_ = col1 != col2
|
||||
_ = col1 == col2
|
||||
|
||||
# Aggregate
|
||||
_ = col1.kurt()
|
||||
_ = col1.kurtosis()
|
||||
_ = col1.max()
|
||||
_ = col1.mean()
|
||||
_ = col1.median()
|
||||
_ = col1.min()
|
||||
_ = col1.mode()
|
||||
_ = col1.prod()
|
||||
_ = col1.product()
|
||||
_ = col1.std()
|
||||
_ = col1.sum()
|
||||
_ = col1.var()
|
||||
|
||||
# Groupby
|
||||
col_gb = col1.groupby(level=0)
|
||||
|
||||
_ = col_gb.kurt()
|
||||
_ = col_gb.max()
|
||||
_ = col_gb.mean()
|
||||
_ = col_gb.median()
|
||||
_ = col_gb.min()
|
||||
_ = col_gb.prod()
|
||||
_ = col_gb.std()
|
||||
_ = col_gb.sum()
|
||||
_ = col_gb.var()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user