Compare commits
10 Commits
be2fd4c837
...
a143972ef1
| Author | SHA1 | Date | |
|---|---|---|---|
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
@@ -14,6 +14,8 @@ if TYPE_CHECKING:
|
|||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# Hard-coded subtype relationships between builtin types
|
||||||
|
# Circular dependencies and diamond inheritance MUST be avoided
|
||||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||||
"float": {"int"},
|
"float": {"int"},
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ class Call:
|
|||||||
positional: list[TypedExpr]
|
positional: list[TypedExpr]
|
||||||
keywords: dict[str, TypedExpr]
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.groupby_expr, self.groupby)
|
||||||
|
|
||||||
|
|
||||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||||
@method()
|
@method()
|
||||||
|
|||||||
@@ -33,6 +33,10 @@ class Call:
|
|||||||
positional: list[TypedExpr]
|
positional: list[TypedExpr]
|
||||||
keywords: dict[str, TypedExpr]
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.column_expr, self.column)
|
||||||
|
|
||||||
|
|
||||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
||||||
@@ -69,8 +73,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
|||||||
new_column = ColumnType(type=new_inner_type)
|
new_column = ColumnType(type=new_inner_type)
|
||||||
return new_column
|
return new_column
|
||||||
|
|
||||||
@method("add", "__add__")
|
def _element_wise(self, call: Call, method: str) -> Type:
|
||||||
def add(self, call: Call) -> Type:
|
|
||||||
# TODO: support add with scalar
|
# TODO: support add with scalar
|
||||||
|
|
||||||
# Build signature with new column type and generic operand
|
# Build signature with new column type and generic operand
|
||||||
@@ -87,7 +90,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
|||||||
required=True,
|
required=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
returns=self._element_binary_op(call, "__add__"),
|
returns=self._element_binary_op(call, method),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,8 +108,59 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
|||||||
|
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
@method()
|
@method("add", "__add__")
|
||||||
def mean(self, call: Call) -> Type:
|
def add(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__add__")
|
||||||
|
|
||||||
|
@method("sub", "__sub__")
|
||||||
|
def sub(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__sub__")
|
||||||
|
|
||||||
|
@method("mul", "__mul__")
|
||||||
|
def mul(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__mul__")
|
||||||
|
|
||||||
|
@method("div", "truediv", "__truediv__")
|
||||||
|
def truediv(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__truediv__")
|
||||||
|
|
||||||
|
@method("floordiv", "__floordiv__")
|
||||||
|
def floordiv(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__floordiv__")
|
||||||
|
|
||||||
|
@method("mod", "__mod__")
|
||||||
|
def mod(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__mod__")
|
||||||
|
|
||||||
|
@method("pow", "__pow__")
|
||||||
|
def pow(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__pow__")
|
||||||
|
|
||||||
|
@method("lt", "__lt__")
|
||||||
|
def lt(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__lt__")
|
||||||
|
|
||||||
|
@method("gt", "__gt__")
|
||||||
|
def gt(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__gt__")
|
||||||
|
|
||||||
|
@method("le", "__le__")
|
||||||
|
def le(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__le__")
|
||||||
|
|
||||||
|
@method("ge", "__ge__")
|
||||||
|
def ge(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__ge__")
|
||||||
|
|
||||||
|
@method("ne", "__ne__")
|
||||||
|
def ne(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__ne__")
|
||||||
|
|
||||||
|
@method("eq", "__eq__")
|
||||||
|
def eq(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__eq__")
|
||||||
|
|
||||||
|
def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||||
signature = Function(
|
signature = Function(
|
||||||
kw_args=[
|
kw_args=[
|
||||||
Function.Argument(
|
Function.Argument(
|
||||||
@@ -114,7 +168,8 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
|||||||
name="axis",
|
name="axis",
|
||||||
type=TopType(),
|
type=TopType(),
|
||||||
required=False,
|
required=False,
|
||||||
)
|
),
|
||||||
|
*kwargs,
|
||||||
],
|
],
|
||||||
returns=ColumnType(type=TopType()),
|
returns=ColumnType(type=TopType()),
|
||||||
)
|
)
|
||||||
@@ -127,6 +182,66 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
|||||||
)
|
)
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
|
@method("kurtosis", "kurt")
|
||||||
|
def kurtosis(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def max(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mean(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def median(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def min(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mode(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method("product", "prod")
|
||||||
|
def product(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def std(self, call: Call) -> Type:
|
||||||
|
return self._statistical(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
Function.Argument(
|
||||||
|
pos=1,
|
||||||
|
name="ddof",
|
||||||
|
type=self.types.get_type("int"),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def sum(self, call: Call) -> Type:
|
||||||
|
return self._statistical(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def var(self, call: Call) -> Type:
|
||||||
|
return self._statistical(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
Function.Argument(
|
||||||
|
pos=1,
|
||||||
|
name="var",
|
||||||
|
type=self.types.get_type("int"),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@method()
|
@method()
|
||||||
def groupby(self, call: Call) -> Type:
|
def groupby(self, call: Call) -> Type:
|
||||||
bool_: Type = self.types.get_type("bool")
|
bool_: Type = self.types.get_type("bool")
|
||||||
|
|||||||
@@ -22,38 +22,43 @@ class Call:
|
|||||||
positional: list[TypedExpr]
|
positional: list[TypedExpr]
|
||||||
keywords: dict[str, TypedExpr]
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.groupby_expr, self.groupby)
|
||||||
|
|
||||||
|
|
||||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||||
@method()
|
NAMED_ARGS: dict[str, str] = {
|
||||||
def mean(self, call: Call) -> Type:
|
"numeric_only": "bool",
|
||||||
bool_ = self.types.get_type("bool")
|
"skipna": "bool",
|
||||||
|
"engine": "str",
|
||||||
|
"engine_kwargs": "dict",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _aggregate(
|
||||||
|
self, call: Call, args: list[str | tuple[str, str, bool]] = []
|
||||||
|
) -> Type:
|
||||||
|
real_args: list[Function.Argument] = []
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
match arg:
|
||||||
|
case str() as name:
|
||||||
|
arg = Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=name,
|
||||||
|
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
case (name, type, required):
|
||||||
|
arg = Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=name,
|
||||||
|
type=self.types.get_type(type),
|
||||||
|
required=required,
|
||||||
|
)
|
||||||
|
real_args.append(arg)
|
||||||
|
|
||||||
signature = Function(
|
signature = Function(
|
||||||
args=[
|
args=real_args,
|
||||||
Function.Argument(
|
|
||||||
pos=0,
|
|
||||||
name="numeric_only",
|
|
||||||
type=bool_,
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
Function.Argument(
|
|
||||||
pos=1,
|
|
||||||
name="skipna",
|
|
||||||
type=bool_,
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
Function.Argument(
|
|
||||||
pos=2,
|
|
||||||
name="engine",
|
|
||||||
type=self.types.get_type("str"),
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
Function.Argument(
|
|
||||||
pos=3,
|
|
||||||
name="engine_kwargs",
|
|
||||||
type=self.types.get_type("dict"),
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
returns=call.groupby.frame,
|
returns=call.groupby.frame,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,3 +69,127 @@ class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
|||||||
keywords=call.keywords,
|
keywords=call.keywords,
|
||||||
)
|
)
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def kurt(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"skipna",
|
||||||
|
"numeric_only",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def max(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mean(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def median(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
["numeric_only", "skipna"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def min(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def prod(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def std(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"ddof",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
"numeric_only",
|
||||||
|
"skipna",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def sum(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
"numeric_only",
|
||||||
|
(
|
||||||
|
"min_count",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"skipna",
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def var(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"var",
|
||||||
|
"int",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
"engine",
|
||||||
|
"engine_kwargs",
|
||||||
|
"numeric_only",
|
||||||
|
"skipna",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -33,6 +33,10 @@ class Call:
|
|||||||
positional: list[TypedExpr]
|
positional: list[TypedExpr]
|
||||||
keywords: dict[str, TypedExpr]
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subject(self) -> TypedExpr:
|
||||||
|
return (self.frame_expr, self.frame)
|
||||||
|
|
||||||
|
|
||||||
class FrameMethodRegistry(MethodRegistry[Call]):
|
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||||
def _get_method_result(
|
def _get_method_result(
|
||||||
@@ -142,10 +146,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
|||||||
|
|
||||||
return DataFrameType(columns=new_columns)
|
return DataFrameType(columns=new_columns)
|
||||||
|
|
||||||
@method("add", "__add__")
|
def _element_wise(self, call: Call, method: str) -> Type:
|
||||||
def add(self, call: Call) -> Type:
|
# TODO: support scalar, sequence, Series, dict operand
|
||||||
# TODO: support add with scalar, sequence, Series, dict
|
|
||||||
|
|
||||||
# Build signature with new schema and generic operand
|
# Build signature with new schema and generic operand
|
||||||
signature = Function(
|
signature = Function(
|
||||||
args=[
|
args=[
|
||||||
@@ -156,7 +158,7 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
|||||||
required=True,
|
required=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
returns=self._element_binary_op(call, "__add__"),
|
returns=self._element_binary_op(call, method),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map arguments and compute result type
|
# Map arguments and compute result type
|
||||||
@@ -173,8 +175,59 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
|||||||
|
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
@method()
|
@method("add", "__add__")
|
||||||
def mean(self, call: Call) -> Type:
|
def add(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__add__")
|
||||||
|
|
||||||
|
@method("sub", "__sub__")
|
||||||
|
def sub(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__sub__")
|
||||||
|
|
||||||
|
@method("mul", "__mul__")
|
||||||
|
def mul(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__mul__")
|
||||||
|
|
||||||
|
@method("div", "truediv", "__truediv__")
|
||||||
|
def truediv(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__truediv__")
|
||||||
|
|
||||||
|
@method("floordiv", "__floordiv__")
|
||||||
|
def floordiv(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__floordiv__")
|
||||||
|
|
||||||
|
@method("mod", "__mod__")
|
||||||
|
def mod(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__mod__")
|
||||||
|
|
||||||
|
@method("pow", "__pow__")
|
||||||
|
def pow(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__pow__")
|
||||||
|
|
||||||
|
@method("lt", "__lt__")
|
||||||
|
def lt(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__lt__")
|
||||||
|
|
||||||
|
@method("gt", "__gt__")
|
||||||
|
def gt(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__gt__")
|
||||||
|
|
||||||
|
@method("le", "__le__")
|
||||||
|
def le(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__le__")
|
||||||
|
|
||||||
|
@method("ge", "__ge__")
|
||||||
|
def ge(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__ge__")
|
||||||
|
|
||||||
|
@method("ne", "__ne__")
|
||||||
|
def ne(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__ne__")
|
||||||
|
|
||||||
|
@method("eq", "__eq__")
|
||||||
|
def eq(self, call: Call) -> Type:
|
||||||
|
return self._element_wise(call, "__eq__")
|
||||||
|
|
||||||
|
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||||
with_axis = Function(
|
with_axis = Function(
|
||||||
kw_args=[
|
kw_args=[
|
||||||
Function.Argument(
|
Function.Argument(
|
||||||
@@ -182,7 +235,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
|||||||
name="axis",
|
name="axis",
|
||||||
type=self.types.get_type("int"),
|
type=self.types.get_type("int"),
|
||||||
required=False,
|
required=False,
|
||||||
)
|
),
|
||||||
|
*kwargs,
|
||||||
],
|
],
|
||||||
returns=ColumnType(type=TopType()),
|
returns=ColumnType(type=TopType()),
|
||||||
)
|
)
|
||||||
@@ -193,7 +247,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
|||||||
name="axis",
|
name="axis",
|
||||||
type=self.types.get_type("None"),
|
type=self.types.get_type("None"),
|
||||||
required=True,
|
required=True,
|
||||||
)
|
),
|
||||||
|
*kwargs,
|
||||||
],
|
],
|
||||||
returns=TopType(),
|
returns=TopType(),
|
||||||
)
|
)
|
||||||
@@ -212,6 +267,66 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
|||||||
)
|
)
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
|
@method("kurtosis", "kurt")
|
||||||
|
def kurtosis(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def max(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mean(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def median(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def min(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def mode(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method("product", "prod")
|
||||||
|
def product(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def std(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
Function.Argument(
|
||||||
|
pos=1,
|
||||||
|
name="ddof",
|
||||||
|
type=self.types.get_type("int"),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def sum(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(call)
|
||||||
|
|
||||||
|
@method()
|
||||||
|
def var(self, call: Call) -> Type:
|
||||||
|
return self._aggregate(
|
||||||
|
call,
|
||||||
|
[
|
||||||
|
Function.Argument(
|
||||||
|
pos=1,
|
||||||
|
name="var",
|
||||||
|
type=self.types.get_type("int"),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@method()
|
@method()
|
||||||
def groupby(self, call: Call) -> Type:
|
def groupby(self, call: Call) -> Type:
|
||||||
bool_: Type = self.types.get_type("bool")
|
bool_: Type = self.types.get_type("bool")
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from midas.checker.types import Type, UnknownType
|
|||||||
from midas.generator.collector import AssertionCollector
|
from midas.generator.collector import AssertionCollector
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from midas.checker.python import PythonTyper
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
class _MethodRegistryMeta(type):
|
class _MethodRegistryMeta(type):
|
||||||
@@ -41,12 +41,18 @@ class _MethodRegistryMeta(type):
|
|||||||
return new_class
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
class HasLocation(Protocol):
|
class MethodCall(Protocol):
|
||||||
@property
|
@property
|
||||||
def location(self) -> Location: ...
|
def location(self) -> Location: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def call_expr(self) -> p.Expr: ...
|
||||||
|
|
||||||
T = TypeVar("T", bound=HasLocation)
|
@property
|
||||||
|
def subject(self) -> TypedExpr: ...
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=MethodCall)
|
||||||
|
|
||||||
|
|
||||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||||
@@ -72,7 +78,9 @@ class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
|||||||
def call(self, method: str, call: T) -> Type:
|
def call(self, method: str, call: T) -> Type:
|
||||||
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
||||||
if func is None:
|
if func is None:
|
||||||
self.reporter.warning(call.location, f"Unknown method {method}")
|
self.reporter.warning(
|
||||||
|
call.location, f"Unknown method {method} on {call.subject[1]}"
|
||||||
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
return func(self, call)
|
return func(self, call)
|
||||||
|
|
||||||
|
|||||||
@@ -113,6 +113,15 @@ class TypesRegistry:
|
|||||||
raise ValueError(f"Predicate {name} already defined")
|
raise ValueError(f"Predicate {name} already defined")
|
||||||
self._predicates[name] = predicate
|
self._predicates[name] = predicate
|
||||||
|
|
||||||
|
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
|
||||||
|
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
|
||||||
|
if name1 in subtypes:
|
||||||
|
return True
|
||||||
|
for subtype in subtypes:
|
||||||
|
if self.is_builtin_subtype(name1, subtype):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
@@ -150,7 +159,7 @@ class TypesRegistry:
|
|||||||
return self.is_subtype(base1, type2)
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
case (BaseType(name=name1), BaseType(name=name2)):
|
case (BaseType(name=name1), BaseType(name=name2)):
|
||||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
return self.is_builtin_subtype(name1, name2)
|
||||||
|
|
||||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||||
for k, t in props2.items():
|
for k, t in props2.items():
|
||||||
|
|||||||
51
tests/cases/checker/09_frame_ops.py
Normal file
51
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable [F821]
|
||||||
|
|
||||||
|
df1: Frame[a:int, b:float]
|
||||||
|
df2: Frame[a:int, b:float]
|
||||||
|
|
||||||
|
_: Any
|
||||||
|
|
||||||
|
# Arithmetic
|
||||||
|
_ = df1 + df2
|
||||||
|
_ = df1 - df2
|
||||||
|
_ = df1 * df2
|
||||||
|
_ = df1 / df2
|
||||||
|
_ = df1 // df2
|
||||||
|
_ = df1 % df2
|
||||||
|
_ = df1**df2
|
||||||
|
|
||||||
|
# Comparisons
|
||||||
|
_ = df1 < df2
|
||||||
|
_ = df1 > df2
|
||||||
|
_ = df1 <= df2
|
||||||
|
_ = df1 >= df2
|
||||||
|
_ = df1 != df2
|
||||||
|
_ = df1 == df2
|
||||||
|
|
||||||
|
# Aggregate
|
||||||
|
_ = df1.kurt()
|
||||||
|
_ = df1.kurtosis()
|
||||||
|
_ = df1.max()
|
||||||
|
_ = df1.mean()
|
||||||
|
_ = df1.median()
|
||||||
|
_ = df1.min()
|
||||||
|
_ = df1.mode()
|
||||||
|
_ = df1.prod()
|
||||||
|
_ = df1.product()
|
||||||
|
_ = df1.std()
|
||||||
|
_ = df1.sum()
|
||||||
|
_ = df1.var()
|
||||||
|
|
||||||
|
# Groupby
|
||||||
|
gb = df1.groupby(by="a")
|
||||||
|
|
||||||
|
_ = gb.kurt()
|
||||||
|
_ = gb.max()
|
||||||
|
_ = gb.mean()
|
||||||
|
_ = gb.median()
|
||||||
|
_ = gb.min()
|
||||||
|
_ = gb.prod()
|
||||||
|
_ = gb.std()
|
||||||
|
_ = gb.sum()
|
||||||
|
_ = gb.var()
|
||||||
2771
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
2771
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user