feat(checker): add add/mean/groupby on columns
This commit is contained in:
@@ -1,13 +1,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.frames.utils import MethodRegistry
|
from midas.checker.dispatcher import CallResult
|
||||||
|
from midas.checker.frames.utils import MethodRegistry, method
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
|
ColumnGroupBy,
|
||||||
ColumnType,
|
ColumnType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -24,4 +34,183 @@ class Call:
|
|||||||
keywords: dict[str, TypedExpr]
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
|
||||||
class ColumnMethodRegistry(MethodRegistry[Call]): ...
|
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||||
|
@method("add", "__add__")
|
||||||
|
def add(self, call: Call) -> Type:
|
||||||
|
# TODO: support add with scalar
|
||||||
|
# TODO: check operation exists on inner column types
|
||||||
|
|
||||||
|
column2: Optional[ColumnType] = None
|
||||||
|
|
||||||
|
col_type1: Type = call.column.type
|
||||||
|
new_column: Type = ColumnType(type=UnknownType())
|
||||||
|
if len(call.positional) != 0:
|
||||||
|
other: Type = call.positional[0][1]
|
||||||
|
unfolded_other: Type = unfold_type(other)
|
||||||
|
if isinstance(unfolded_other, ColumnType):
|
||||||
|
column2 = unfolded_other
|
||||||
|
col_type2: Type = column2.type
|
||||||
|
if self.types.are_equivalent(col_type2, col_type1):
|
||||||
|
new_column = ColumnType(type=col_type1)
|
||||||
|
|
||||||
|
# Build signature with new column type and generic operand
|
||||||
|
param_type: TypeVar = TypeVar(name="T", bound=None)
|
||||||
|
signature = GenericType(
|
||||||
|
name="add",
|
||||||
|
params=[param_type],
|
||||||
|
body=Function(
|
||||||
|
args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="other",
|
||||||
|
type=ColumnType(type=param_type),
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
returns=new_column,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map arguments and compute result type
|
||||||
|
result: CallResult = self.dispatcher.get_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=signature,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
if result.is_valid:
|
||||||
|
self._assert_same_length(
|
||||||
|
call.call_expr, call.column_expr, call.positional[0][0]
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.result
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mean(self, call: Call) -> Type:
|
||||||
|
signature = Function(
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="axis",
|
||||||
|
type=TopType(),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
returns=ColumnType(type=TopType()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result: CallResult = self.dispatcher.get_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=signature,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def groupby(self, call: Call) -> Type:
|
||||||
|
bool_: Type = self.types.get_type("bool")
|
||||||
|
function: Function = Function(
|
||||||
|
args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="by",
|
||||||
|
type=TopType(),
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
Function.Argument(
|
||||||
|
pos=1,
|
||||||
|
name="level",
|
||||||
|
type=TopType(),
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=2,
|
||||||
|
name="as_index",
|
||||||
|
type=bool_,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
Function.Argument(
|
||||||
|
pos=3,
|
||||||
|
name="sort",
|
||||||
|
type=bool_,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
Function.Argument(
|
||||||
|
pos=4,
|
||||||
|
name="group_keys",
|
||||||
|
type=bool_,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
Function.Argument(
|
||||||
|
pos=5,
|
||||||
|
name="observed",
|
||||||
|
type=bool_,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
Function.Argument(
|
||||||
|
pos=6,
|
||||||
|
name="dropna",
|
||||||
|
type=bool_,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
returns=ColumnGroupBy(column=call.column),
|
||||||
|
)
|
||||||
|
|
||||||
|
result: CallResult = self.dispatcher.get_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=function,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
return result.result
|
||||||
|
|
||||||
|
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||||
|
func_name: str = "__midas_column_same_length__"
|
||||||
|
self.assertions.define(
|
||||||
|
func_name,
|
||||||
|
ast.FunctionDef(
|
||||||
|
name=func_name,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
ast.arg(arg="column1"),
|
||||||
|
ast.arg(arg="column2"),
|
||||||
|
],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Compare(
|
||||||
|
left=ast.Attribute(
|
||||||
|
value=ast.Name(id="column1"),
|
||||||
|
attr="size",
|
||||||
|
),
|
||||||
|
ops=[ast.Eq()],
|
||||||
|
comparators=[
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="column2"),
|
||||||
|
attr="size",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertions.add(
|
||||||
|
bound_expr=call_expr,
|
||||||
|
inputs=[column1, column2],
|
||||||
|
builder=lambda c1, c2: ast.Call(
|
||||||
|
func=ast.Name(id=func_name),
|
||||||
|
args=[c1, c2],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
message="Columns must have the same length",
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user