Compare commits

...

10 Commits

9 changed files with 3253 additions and 49 deletions

View File

@@ -14,6 +14,8 @@ if TYPE_CHECKING:
from midas.checker.registry import TypesRegistry
# Hard-coded subtype relationships between builtin types
# Circular dependencies and diamond inheritance MUST be avoided
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"},

View File

@@ -22,6 +22,10 @@ class Call:
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
@method()

View File

@@ -33,6 +33,10 @@ class Call:
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.column_expr, self.column)
class ColumnMethodRegistry(MethodRegistry[Call]):
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
@@ -69,8 +73,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
new_column = ColumnType(type=new_inner_type)
return new_column
@method("add", "__add__")
def add(self, call: Call) -> Type:
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support add with scalar
# Build signature with new column type and generic operand
@@ -87,7 +90,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
required=True,
),
],
returns=self._element_binary_op(call, "__add__"),
returns=self._element_binary_op(call, method),
),
)
@@ -105,8 +108,59 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
return result.result
@method()
def mean(self, call: Call) -> Type:
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _statistical(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
signature = Function(
kw_args=[
Function.Argument(
@@ -114,7 +168,8 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
name="axis",
type=TopType(),
required=False,
)
),
*kwargs,
],
returns=ColumnType(type=TopType()),
)
@@ -127,6 +182,66 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._statistical(call)
@method()
def max(self, call: Call) -> Type:
return self._statistical(call)
@method()
def mean(self, call: Call) -> Type:
return self._statistical(call)
@method()
def median(self, call: Call) -> Type:
return self._statistical(call)
@method()
def min(self, call: Call) -> Type:
return self._statistical(call)
@method()
def mode(self, call: Call) -> Type:
return self._statistical(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._statistical(call)
@method()
def std(self, call: Call) -> Type:
return self._statistical(
call,
[
Function.Argument(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._statistical(call)
@method()
def var(self, call: Call) -> Type:
    """Type a `var()` call on a column.

    Delegates to `_statistical`, additionally accepting the optional
    integer keyword `ddof`, exactly like `std` does.
    """
    return self._statistical(
        call,
        [
            Function.Argument(
                pos=1,
                # pandas calls this parameter `ddof`; it was previously
                # registered under the method's own name ("var"), so
                # `col.var(ddof=...)` could not be matched.
                name="ddof",
                type=self.types.get_type("int"),
                required=False,
            )
        ],
    )
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")

View File

@@ -22,38 +22,43 @@ class Call:
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
@method()
def mean(self, call: Call) -> Type:
bool_ = self.types.get_type("bool")
NAMED_ARGS: dict[str, str] = {
"numeric_only": "bool",
"skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self, call: Call, args: list[str | tuple[str, str, bool]] = []
) -> Type:
real_args: list[Function.Argument] = []
for i, arg in enumerate(args):
match arg:
case str() as name:
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
arg = Function.Argument(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_args.append(arg)
signature = Function(
args=[
Function.Argument(
pos=0,
name="numeric_only",
type=bool_,
required=False,
),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
args=real_args,
returns=call.groupby.frame,
)
@@ -64,3 +69,127 @@ class FrameGroupByMethodRegistry(MethodRegistry[Call]):
keywords=call.keywords,
)
return result.result
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(
call,
[
"skipna",
"numeric_only",
],
)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna", "engine", "engine_kwargs"],
)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna"],
)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
],
)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def var(self, call: Call) -> Type:
    """Type a `var()` call on a grouped frame.

    Delegates to `_aggregate` with the same argument list as `std`:
    an optional integer `ddof` followed by the shared named arguments.
    """
    return self._aggregate(
        call,
        [
            # pandas calls this parameter `ddof`; it was previously
            # registered under the method's own name ("var").
            (
                "ddof",
                "int",
                False,
            ),
            "engine",
            "engine_kwargs",
            "numeric_only",
            "skipna",
        ],
    )

View File

@@ -33,6 +33,10 @@ class Call:
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.frame_expr, self.frame)
class FrameMethodRegistry(MethodRegistry[Call]):
def _get_method_result(
@@ -142,10 +146,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
return DataFrameType(columns=new_columns)
@method("add", "__add__")
def add(self, call: Call) -> Type:
# TODO: support add with scalar, sequence, Series, dict
def _element_wise(self, call: Call, method: str) -> Type:
# TODO: support scalar, sequence, Series, dict operand
# Build signature with new schema and generic operand
signature = Function(
args=[
@@ -156,7 +158,7 @@ class FrameMethodRegistry(MethodRegistry[Call]):
required=True,
),
],
returns=self._element_binary_op(call, "__add__"),
returns=self._element_binary_op(call, method),
)
# Map arguments and compute result type
@@ -173,8 +175,59 @@ class FrameMethodRegistry(MethodRegistry[Call]):
return result.result
@method()
def mean(self, call: Call) -> Type:
@method("add", "__add__")
def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
with_axis = Function(
kw_args=[
Function.Argument(
@@ -182,7 +235,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
name="axis",
type=self.types.get_type("int"),
required=False,
)
),
*kwargs,
],
returns=ColumnType(type=TopType()),
)
@@ -193,7 +247,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
name="axis",
type=self.types.get_type("None"),
required=True,
)
),
*kwargs,
],
returns=TopType(),
)
@@ -212,6 +267,66 @@ class FrameMethodRegistry(MethodRegistry[Call]):
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Argument(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
    """Type a `var()` call on a frame.

    Delegates to `_aggregate`, additionally accepting the optional
    integer keyword `ddof`, exactly like `std` does.
    """
    return self._aggregate(
        call,
        [
            Function.Argument(
                pos=1,
                # pandas calls this parameter `ddof`; it was previously
                # registered under the method's own name ("var"), so
                # `df.var(ddof=...)` could not be matched.
                name="ddof",
                type=self.types.get_type("int"),
                required=False,
            )
        ],
    )
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")

View File

@@ -20,7 +20,7 @@ from midas.checker.types import Type, UnknownType
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper
from midas.checker.python import PythonTyper, TypedExpr
class _MethodRegistryMeta(type):
@@ -41,12 +41,18 @@ class _MethodRegistryMeta(type):
return new_class
class HasLocation(Protocol):
class MethodCall(Protocol):
@property
def location(self) -> Location: ...
@property
def call_expr(self) -> p.Expr: ...
T = TypeVar("T", bound=HasLocation)
@property
def subject(self) -> TypedExpr: ...
T = TypeVar("T", bound=MethodCall)
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
@@ -72,7 +78,9 @@ class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
def call(self, method: str, call: T) -> Type:
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
if func is None:
self.reporter.warning(call.location, f"Unknown method {method}")
self.reporter.warning(
call.location, f"Unknown method {method} on {call.subject[1]}"
)
return UnknownType()
return func(self, call)

View File

@@ -113,6 +113,15 @@ class TypesRegistry:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
    """Check whether builtin `name1` descends from builtin `name2`

    The lookup follows `BUILTIN_SUBTYPES` transitively: `name1` matches
    when it is listed directly under `name2` or under any of the types
    listed there. A `name2` without an entry has no subtypes, so the
    result is False.

    The recursion keeps no record of visited names, so it only
    terminates if the table contains no cycle.
    """
    direct: set[str] = BUILTIN_SUBTYPES.get(name2, set())
    if name1 in direct:
        return True
    # Not a direct child: search below each direct child in turn
    return any(self.is_builtin_subtype(name1, child) for child in direct)
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -150,7 +159,7 @@ class TypesRegistry:
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set())
return self.is_builtin_subtype(name1, name2)
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():

View File

@@ -0,0 +1,51 @@
# type: ignore
# ruff: disable [F821]
df1: Frame[a:int, b:float]
df2: Frame[a:int, b:float]
_: Any
# Arithmetic
_ = df1 + df2
_ = df1 - df2
_ = df1 * df2
_ = df1 / df2
_ = df1 // df2
_ = df1 % df2
_ = df1**df2
# Comparisons
_ = df1 < df2
_ = df1 > df2
_ = df1 <= df2
_ = df1 >= df2
_ = df1 != df2
_ = df1 == df2
# Aggregate
_ = df1.kurt()
_ = df1.kurtosis()
_ = df1.max()
_ = df1.mean()
_ = df1.median()
_ = df1.min()
_ = df1.mode()
_ = df1.prod()
_ = df1.product()
_ = df1.std()
_ = df1.sum()
_ = df1.var()
# Groupby
gb = df1.groupby(by="a")
_ = gb.kurt()
_ = gb.max()
_ = gb.mean()
_ = gb.median()
_ = gb.min()
_ = gb.prod()
_ = gb.std()
_ = gb.sum()
_ = gb.var()

File diff suppressed because it is too large Load Diff