Compare commits
22 Commits
be2fd4c837
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 03bc32400b | |||
|
4a93ee45d9
|
|||
|
8197131d8d
|
|||
|
cf91187b7a
|
|||
|
1b2bdf0b79
|
|||
| c6cc38bfeb | |||
|
4d3e3f44a1
|
|||
|
ec80b1e92e
|
|||
|
4ea15519f3
|
|||
|
7a6e01cff8
|
|||
|
733c8736b8
|
|||
|
20173a0b07
|
|||
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
@@ -198,10 +198,26 @@ python3 build/midas/script.py
|
||||
In this chapter, you will find a complete reference for the Midas definition language.
|
||||
|
||||
A `*.midas` file contains a number of statements, which can be:
|
||||
- *`alias`* statements (see @alias-stmt): to define a new type alias
|
||||
- *`type`* statements (see @type-stmt): to define a new type
|
||||
- *`extend`* statements (see @extend-stmt): to define member of a type
|
||||
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
|
||||
|
||||
== Alias Statement <alias-stmt>
|
||||
|
||||
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
|
||||
|
||||
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely a for givin an alternative name to a type.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
alias MyType = float
|
||||
```,
|
||||
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
|
||||
) <midas-simple-alias>
|
||||
|
||||
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
|
||||
|
||||
== Type Statement <type-stmt>
|
||||
|
||||
A *`type`* statement lets you define a new type. It requires a unique name and base type.
|
||||
@@ -212,7 +228,7 @@ The simplest form of a *`type`* statement is:
|
||||
type MyType = float
|
||||
```,
|
||||
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
|
||||
) <midas-simple-alias>
|
||||
) <midas-simple-type>
|
||||
|
||||
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
|
||||
|
||||
@@ -291,8 +307,7 @@ To better refine a generic type, you can also bound type parameters using the fo
|
||||
caption: [Generic container type definition with a bound],
|
||||
)
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
|
||||
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete type as arguments:
|
||||
|
||||
#figure(
|
||||
@@ -318,6 +333,46 @@ The _body_ of a generic type, i.e. the right-hand side of the definition, can co
|
||||
caption: [Type parameters in a generic type's body],
|
||||
)
|
||||
|
||||
=== `Column` / `Frame` types
|
||||
|
||||
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
|
||||
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
|
||||
|
||||
==== `Column`
|
||||
|
||||
The `Column` type is a generic type used to represent a `pandas.Series` object.
|
||||
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Temperature = float
|
||||
alias Temperatures = Column[Temperature]
|
||||
```,
|
||||
caption: [Simple column type definition],
|
||||
)
|
||||
|
||||
==== `Frame` <frame-type>
|
||||
|
||||
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
|
||||
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
|
||||
@simple-frame show how you can define a simple frame type with 3 columns:
|
||||
- `name`: a column of `Name` values
|
||||
- `age`: a column of `int` values
|
||||
- `height`: a column of `float where _ >= 0` values
|
||||
|
||||
Notice that you don't need to specify `Column` types.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Name = str where len(_) != 0
|
||||
alias Data = Frame[
|
||||
name: Name,
|
||||
age: int,
|
||||
height: float where _ >= 0
|
||||
]
|
||||
```,
|
||||
) <simple-frame>
|
||||
|
||||
#pagebreak()
|
||||
|
||||
== Extend Statement <extend-stmt>
|
||||
@@ -503,6 +558,7 @@ A simple annotation declaration, without assigning a value, is enough to declare
|
||||
)
|
||||
|
||||
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
|
||||
For more information about type annotations, see @type-annotations
|
||||
|
||||
== Arithmetic
|
||||
|
||||
@@ -578,7 +634,7 @@ Conditional statements are checked relatively strictly by Midas. The test expres
|
||||
|
||||
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
|
||||
|
||||
The `for`-`else` statements are not supported. `while` loops are also not not supported.
|
||||
`for`-`else` statements are not supported. `while` loops are also not supported.
|
||||
|
||||
== Functions
|
||||
|
||||
@@ -686,6 +742,35 @@ There may be some cases where the cost of checking a value at runtime is simply
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||
|
||||
== Annotations / Type Hints <type-annotations>
|
||||
|
||||
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
|
||||
|
||||
Midas use them to type check your code. Additionally, it allows you to use a special syntax to define a `Frame` types directly in these annotations.
|
||||
|
||||
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
|
||||
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
var: Frame[name: str, age: float] # type: ignore # noqa: F821
|
||||
```,
|
||||
caption: [MyPy's and Pylance's complaints about custom type annotation can be silenced with type comments],
|
||||
) <silence-errors>
|
||||
|
||||
=== Frame type annotation
|
||||
|
||||
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
|
||||
|
||||
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
|
||||
```,
|
||||
caption: [Frame type annotation in Python],
|
||||
) <python-frame-type>
|
||||
|
||||
= Commands <commands>
|
||||
|
||||
#TODO
|
||||
|
||||
@@ -37,6 +37,9 @@ contexts:
|
||||
pop: true
|
||||
|
||||
keywords:
|
||||
- match: \balias\b
|
||||
scope: keyword.declaration.midas
|
||||
push: alias-stmt
|
||||
- match: \btype\b
|
||||
scope: keyword.declaration.midas
|
||||
push: type-stmt
|
||||
@@ -47,6 +50,15 @@ contexts:
|
||||
scope: keyword.declaration.midas
|
||||
push: predicate-stmt
|
||||
|
||||
alias-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: "="
|
||||
scope: keyword.operator.equal.midas
|
||||
push: type-expr
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
type-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
@@ -67,6 +79,13 @@ contexts:
|
||||
- match: \b(where)\b
|
||||
scope: keyword.other.midas
|
||||
set: constraint
|
||||
- match: "Frame"
|
||||
scope: entity.name.type
|
||||
push:
|
||||
- match: \[
|
||||
push: frame-schema
|
||||
- match: $
|
||||
pop: true
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: $
|
||||
@@ -178,3 +197,15 @@ contexts:
|
||||
|
||||
- match: '\)'
|
||||
pop: true
|
||||
|
||||
frame-schema:
|
||||
- include: frame-column
|
||||
- match: \]
|
||||
# scope: punctuation.section.block.end
|
||||
pop: true
|
||||
|
||||
frame-column:
|
||||
- match: "{{identifier}}"
|
||||
scope: variable.other.member
|
||||
- match: ":"
|
||||
push: type-expr
|
||||
|
||||
@@ -14,10 +14,11 @@ if TYPE_CHECKING:
|
||||
from midas.checker.registry import TypesRegistry
|
||||
|
||||
|
||||
# Hard-coded subtype relationships between builtin types
|
||||
# Circular dependencies and diamond inheritance MUST be avoided
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import ColumnGroupBy, Function, Type
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Function, TopType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
@@ -22,39 +22,52 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
args: list[str | tuple[str, str, bool]] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
real_args: list[Function.Argument] = []
|
||||
for i, arg in enumerate(args):
|
||||
match arg:
|
||||
case str() as name:
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
arg = Function.Argument(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_args.append(arg)
|
||||
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.column,
|
||||
args=real_args,
|
||||
returns=(
|
||||
call.groupby.column
|
||||
if preserve_inner_type
|
||||
else ColumnType(type=TopType())
|
||||
),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -64,3 +77,127 @@ class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["skipna", "numeric_only"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -60,3 +61,18 @@ class ColumnManager:
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int")
|
||||
|
||||
case "T":
|
||||
return column
|
||||
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -33,6 +33,10 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.column_expr, self.column)
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
||||
@@ -69,8 +73,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
new_column = ColumnType(type=new_inner_type)
|
||||
return new_column
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support add with scalar
|
||||
|
||||
# Build signature with new column type and generic operand
|
||||
@@ -87,7 +90,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, "__add__"),
|
||||
returns=self._element_binary_op(call, method),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,8 +108,65 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
kwargs: list[Function.Argument] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
signature = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
@@ -114,9 +174,114 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
name="axis",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -190,6 +355,21 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||
func_name: str = "__midas_column_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_col(col: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=col,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
@@ -207,16 +387,10 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="column1"),
|
||||
attr="size",
|
||||
),
|
||||
left=len_of_col(ast.Name(id="column1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="column2"),
|
||||
attr="size",
|
||||
)
|
||||
len_of_col(ast.Name(id="column2")),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,9 +5,15 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import FrameGroupBy, Function, Type
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
@@ -22,45 +28,76 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(self, call: Call, method: str) -> Type:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
for column in call.groupby.frame.columns:
|
||||
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
|
||||
result_type: Type = self.typer.call_method(
|
||||
location=call.location,
|
||||
call_expr=call.call_expr,
|
||||
obj=(call.groupby_expr, column_groupby),
|
||||
method_name=method,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
if not isinstance(result_type, ColumnType):
|
||||
result_type = ColumnType(type=UnknownType())
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=result_type,
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "kurt")
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "max")
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.frame,
|
||||
)
|
||||
return self._aggregate(call, "mean")
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "median")
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "min")
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "prod")
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "std")
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "sum")
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "var")
|
||||
|
||||
@@ -7,6 +7,7 @@ from midas.ast.location import Location
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
@@ -240,3 +241,15 @@ class FrameManager:
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int", "int")
|
||||
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -33,6 +33,10 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.frame_expr, self.frame)
|
||||
|
||||
|
||||
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
def _get_method_result(
|
||||
@@ -142,10 +146,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
# TODO: support add with scalar, sequence, Series, dict
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support scalar, sequence, Series, dict operand
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
@@ -156,7 +158,7 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, "__add__"),
|
||||
returns=self._element_binary_op(call, method),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
@@ -173,8 +175,59 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(self, call: Call, kwargs: list[Function.Argument] = []) -> Type:
|
||||
with_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
@@ -182,7 +235,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
name="axis",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
@@ -193,7 +247,8 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
name="axis",
|
||||
type=self.types.get_type("None"),
|
||||
required=True,
|
||||
)
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
returns=TopType(),
|
||||
)
|
||||
@@ -212,6 +267,110 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
@@ -275,6 +434,21 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_df(df: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=df,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
@@ -292,17 +466,9 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="frame1"),
|
||||
attr="size",
|
||||
),
|
||||
left=len_of_df(ast.Name(id="frame1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="frame2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
comparators=[len_of_df(ast.Name(id="frame2"))],
|
||||
)
|
||||
)
|
||||
],
|
||||
|
||||
@@ -20,7 +20,7 @@ from midas.checker.types import Type, UnknownType
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
@@ -41,12 +41,18 @@ class _MethodRegistryMeta(type):
|
||||
return new_class
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
class MethodCall(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
@property
|
||||
def call_expr(self) -> p.Expr: ...
|
||||
|
||||
T = TypeVar("T", bound=HasLocation)
|
||||
@property
|
||||
def subject(self) -> TypedExpr: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MethodCall)
|
||||
|
||||
|
||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
@@ -72,7 +78,9 @@ class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
def call(self, method: str, call: T) -> Type:
|
||||
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(call.location, f"Unknown method {method}")
|
||||
self.reporter.warning(
|
||||
call.location, f"Unknown method {method} on {call.subject[1]}"
|
||||
)
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
|
||||
@@ -108,8 +108,8 @@ class Preamble(Environment):
|
||||
],
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
def _list_of(self, item_type: str | Type) -> Type:
|
||||
return self._types.list_of(item_type)
|
||||
|
||||
def _def_type_constructor(
|
||||
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||
|
||||
@@ -222,7 +222,7 @@ class PythonTyper(
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
) -> Type:
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
@@ -580,9 +580,8 @@ class PythonTyper(
|
||||
right: TypedExpr,
|
||||
method: str,
|
||||
) -> Type:
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
return self.call_method(
|
||||
location=location,
|
||||
call_expr=expr,
|
||||
obj=left,
|
||||
@@ -597,8 +596,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
@@ -610,9 +607,8 @@ class PythonTyper(
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(expr.right, operand),
|
||||
@@ -627,8 +623,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
match expr.callee:
|
||||
case p.VariableExpr(name="TypeVar"):
|
||||
@@ -644,16 +638,13 @@ class PythonTyper(
|
||||
match expr.callee:
|
||||
case p.GetExpr(object=obj, name=method):
|
||||
obj_type: Type = self.type_of(obj)
|
||||
return (
|
||||
self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(obj, obj_type),
|
||||
method_name=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(obj, obj_type),
|
||||
method_name=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
@@ -668,6 +659,14 @@ class PythonTyper(
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||
|
||||
if member is None:
|
||||
match object:
|
||||
case DataFrameType():
|
||||
member = self.frame_mgr.get_attribute(object, expr.name)
|
||||
case ColumnType():
|
||||
member = self.column_mgr.get_attribute(object, expr.name)
|
||||
|
||||
if member is None:
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
|
||||
@@ -113,6 +113,15 @@ class TypesRegistry:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
|
||||
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
|
||||
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
|
||||
if name1 in subtypes:
|
||||
return True
|
||||
for subtype in subtypes:
|
||||
if self.is_builtin_subtype(name1, subtype):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
@@ -150,7 +159,7 @@ class TypesRegistry:
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
return self.is_builtin_subtype(name1, name2)
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
@@ -443,3 +452,29 @@ class TypesRegistry:
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
|
||||
if isinstance(name_or_type, str):
|
||||
return self.get_type(name_or_type)
|
||||
return name_or_type
|
||||
|
||||
def list_of(self, item_type: str | Type) -> Type:
|
||||
list_ = self.get_type("list")
|
||||
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
|
||||
|
||||
def tuple_of(self, *item_types: str | Type) -> Type:
|
||||
tuple_ = self.get_type("tuple")
|
||||
return self.apply_generic(
|
||||
tuple_,
|
||||
[self._by_name_or_type(item_type) for item_type in item_types],
|
||||
)
|
||||
|
||||
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
|
||||
dict_ = self.get_type("dict")
|
||||
return self.apply_generic(
|
||||
dict_,
|
||||
[
|
||||
self._by_name_or_type(key_type),
|
||||
self._by_name_or_type(value_type),
|
||||
],
|
||||
)
|
||||
|
||||
43
tests/__main__.py
Normal file
43
tests/__main__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Type
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
from tests.base import Tester
|
||||
from tests.checker import CheckerTester
|
||||
from tests.generator import GeneratorTester
|
||||
from tests.midas import MidasTester
|
||||
from tests.python import PythonTester
|
||||
|
||||
|
||||
def print_banner(name: str):
|
||||
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
|
||||
print(horizontal)
|
||||
print(f"| {name} |")
|
||||
print(horizontal)
|
||||
|
||||
|
||||
def run_tests(tester_cls: Type[Tester]) -> bool:
|
||||
print_banner(tester_cls.__name__)
|
||||
tester: Tester = tester_cls()
|
||||
success: bool = tester.run_all_tests()
|
||||
print()
|
||||
return success
|
||||
|
||||
|
||||
def main():
|
||||
testers: list[Type[Tester]] = [
|
||||
PythonTester,
|
||||
MidasTester,
|
||||
CheckerTester,
|
||||
GeneratorTester,
|
||||
]
|
||||
|
||||
success: bool = all(map(run_tests, testers))
|
||||
|
||||
if success:
|
||||
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
|
||||
else:
|
||||
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
@@ -44,8 +46,11 @@ class Tester(ABC):
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
path: Path = test.resolve().relative_to(self.CASES_DIR)
|
||||
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
|
||||
print(Ansi.DIM, end="")
|
||||
success: bool = self._run_test(test)
|
||||
print(Ansi.RESET, end="")
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
@@ -146,8 +151,9 @@ class Tester(ABC):
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case None:
|
||||
print("No subcommand provided. Available subcommands: run, update")
|
||||
sys.exit(1)
|
||||
success: bool = tester.run_all_tests()
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case _:
|
||||
print(f"Unknown subcommand '{args.subcommand}'")
|
||||
sys.exit(1)
|
||||
|
||||
117
tests/cases/checker/09_frame_ops.py
Normal file
117
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
df1: Frame[a:int, b:float]
|
||||
df2: Frame[a:int, b:float]
|
||||
|
||||
_: Any
|
||||
|
||||
# Arithmetic
|
||||
_ = df1 + df2
|
||||
_ = df1 - df2
|
||||
_ = df1 * df2
|
||||
_ = df1 / df2
|
||||
_ = df1 // df2
|
||||
_ = df1 % df2
|
||||
_ = df1**df2
|
||||
|
||||
# Comparisons
|
||||
_ = df1 < df2
|
||||
_ = df1 > df2
|
||||
_ = df1 <= df2
|
||||
_ = df1 >= df2
|
||||
_ = df1 != df2
|
||||
_ = df1 == df2
|
||||
|
||||
# Aggregate
|
||||
_ = df1.kurt()
|
||||
_ = df1.kurtosis()
|
||||
_ = df1.max()
|
||||
_ = df1.mean()
|
||||
_ = df1.median()
|
||||
_ = df1.min()
|
||||
_ = df1.mode()
|
||||
_ = df1.prod()
|
||||
_ = df1.product()
|
||||
_ = df1.std()
|
||||
_ = df1.sum()
|
||||
_ = df1.var()
|
||||
|
||||
# Groupby
|
||||
df_gb = df1.groupby(by="a")
|
||||
|
||||
_ = df_gb.kurt()
|
||||
_ = df_gb.max()
|
||||
_ = df_gb.mean()
|
||||
_ = df_gb.median()
|
||||
_ = df_gb.min()
|
||||
_ = df_gb.prod()
|
||||
_ = df_gb.std()
|
||||
_ = df_gb.sum()
|
||||
_ = df_gb.var()
|
||||
|
||||
|
||||
# Columns
|
||||
|
||||
col1 = df1["a"]
|
||||
col2 = df1["a"]
|
||||
|
||||
# Arithmetic
|
||||
_ = col1 + col2
|
||||
_ = col1 - col2
|
||||
_ = col1 * col2
|
||||
_ = col1 / col2
|
||||
_ = col1 // col2
|
||||
_ = col1 % col2
|
||||
_ = col1**col2
|
||||
|
||||
# Comparisons
|
||||
_ = col1 < col2
|
||||
_ = col1 > col2
|
||||
_ = col1 <= col2
|
||||
_ = col1 >= col2
|
||||
_ = col1 != col2
|
||||
_ = col1 == col2
|
||||
|
||||
# Aggregate
|
||||
_ = col1.kurt()
|
||||
_ = col1.kurtosis()
|
||||
_ = col1.max()
|
||||
_ = col1.mean()
|
||||
_ = col1.median()
|
||||
_ = col1.min()
|
||||
_ = col1.mode()
|
||||
_ = col1.prod()
|
||||
_ = col1.product()
|
||||
_ = col1.std()
|
||||
_ = col1.sum()
|
||||
_ = col1.var()
|
||||
|
||||
# Groupby
|
||||
col_gb = col1.groupby(level=0)
|
||||
|
||||
_ = col_gb.kurt()
|
||||
_ = col_gb.max()
|
||||
_ = col_gb.mean()
|
||||
_ = col_gb.median()
|
||||
_ = col_gb.min()
|
||||
_ = col_gb.prod()
|
||||
_ = col_gb.std()
|
||||
_ = col_gb.sum()
|
||||
_ = col_gb.var()
|
||||
|
||||
# Attributes
|
||||
_ = df1.ndim # int
|
||||
_ = df1.size # int
|
||||
_ = df1.shape # (int, int)
|
||||
_ = col1.ndim # int
|
||||
_ = col1.size # int
|
||||
_ = col1.shape # (int)
|
||||
_ = col1.T # Column[int]
|
||||
|
||||
|
||||
# Misc
|
||||
_ = df1.head()
|
||||
_ = df1.tail()
|
||||
_ = col1.head()
|
||||
_ = col1.tail()
|
||||
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user