Compare commits
28 Commits
be2fd4c837
...
feat/add-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
9764484fd9
|
|||
|
5b9e322c91
|
|||
|
c18d9c18de
|
|||
|
9229f00375
|
|||
|
6b7a682dc5
|
|||
|
35b97fd17b
|
|||
| 03bc32400b | |||
|
4a93ee45d9
|
|||
|
8197131d8d
|
|||
|
cf91187b7a
|
|||
|
1b2bdf0b79
|
|||
| c6cc38bfeb | |||
|
4d3e3f44a1
|
|||
|
ec80b1e92e
|
|||
|
4ea15519f3
|
|||
|
7a6e01cff8
|
|||
|
733c8736b8
|
|||
|
20173a0b07
|
|||
|
a143972ef1
|
|||
|
0c70048b62
|
|||
|
1c0c917873
|
|||
|
1f6189daa4
|
|||
|
66b585c3d6
|
|||
|
819ab3c2bf
|
|||
|
d8c0b17512
|
|||
|
6e06f9078e
|
|||
|
ece2e3a6a3
|
|||
|
74c07c9afb
|
@@ -198,10 +198,26 @@ python3 build/midas/script.py
|
||||
In this chapter, you will find a complete reference for the Midas definition language.
|
||||
|
||||
A `*.midas` file contains a number of statements, which can be:
|
||||
- *`alias`* statements (see @alias-stmt): to define a new type alias
|
||||
- *`type`* statements (see @type-stmt): to define a new type
|
||||
- *`extend`* statements (see @extend-stmt): to define member of a type
|
||||
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
|
||||
|
||||
== Alias Statement <alias-stmt>
|
||||
|
||||
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
|
||||
|
||||
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely a for givin an alternative name to a type.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
alias MyType = float
|
||||
```,
|
||||
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
|
||||
) <midas-simple-alias>
|
||||
|
||||
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
|
||||
|
||||
== Type Statement <type-stmt>
|
||||
|
||||
A *`type`* statement lets you define a new type. It requires a unique name and base type.
|
||||
@@ -212,7 +228,7 @@ The simplest form of a *`type`* statement is:
|
||||
type MyType = float
|
||||
```,
|
||||
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
|
||||
) <midas-simple-alias>
|
||||
) <midas-simple-type>
|
||||
|
||||
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
|
||||
|
||||
@@ -291,8 +307,7 @@ To better refine a generic type, you can also bound type parameters using the fo
|
||||
caption: [Generic container type definition with a bound],
|
||||
)
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".
|
||||
|
||||
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
|
||||
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete type as arguments:
|
||||
|
||||
#figure(
|
||||
@@ -318,6 +333,46 @@ The _body_ of a generic type, i.e. the right-hand side of the definition, can co
|
||||
caption: [Type parameters in a generic type's body],
|
||||
)
|
||||
|
||||
=== `Column` / `Frame` types
|
||||
|
||||
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
|
||||
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
|
||||
|
||||
==== `Column`
|
||||
|
||||
The `Column` type is a generic type used to represent a `pandas.Series` object.
|
||||
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Temperature = float
|
||||
alias Temperatures = Column[Temperature]
|
||||
```,
|
||||
caption: [Simple column type definition],
|
||||
)
|
||||
|
||||
==== `Frame` <frame-type>
|
||||
|
||||
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
|
||||
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
|
||||
@simple-frame show how you can define a simple frame type with 3 columns:
|
||||
- `name`: a column of `Name` values
|
||||
- `age`: a column of `int` values
|
||||
- `height`: a column of `float where _ >= 0` values
|
||||
|
||||
Notice that you don't need to specify `Column` types.
|
||||
|
||||
#figure(
|
||||
```midas
|
||||
type Name = str where len(_) != 0
|
||||
alias Data = Frame[
|
||||
name: Name,
|
||||
age: int,
|
||||
height: float where _ >= 0
|
||||
]
|
||||
```,
|
||||
) <simple-frame>
|
||||
|
||||
#pagebreak()
|
||||
|
||||
== Extend Statement <extend-stmt>
|
||||
@@ -503,6 +558,7 @@ A simple annotation declaration, without assigning a value, is enough to declare
|
||||
)
|
||||
|
||||
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
|
||||
For more information about type annotations, see @type-annotations
|
||||
|
||||
== Arithmetic
|
||||
|
||||
@@ -578,7 +634,7 @@ Conditional statements are checked relatively strictly by Midas. The test expres
|
||||
|
||||
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
|
||||
|
||||
The `for`-`else` statements are not supported. `while` loops are also not not supported.
|
||||
`for`-`else` statements are not supported. `while` loops are also not supported.
|
||||
|
||||
== Functions
|
||||
|
||||
@@ -686,6 +742,35 @@ There may be some cases where the cost of checking a value at runtime is simply
|
||||
|
||||
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
|
||||
|
||||
== Annotations / Type Hints <type-annotations>
|
||||
|
||||
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
|
||||
|
||||
Midas use them to type check your code. Additionally, it allows you to use a special syntax to define a `Frame` types directly in these annotations.
|
||||
|
||||
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
|
||||
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
var: Frame[name: str, age: float] # type: ignore # noqa: F821
|
||||
```,
|
||||
caption: [MyPy's and Pylance's complaints about custom type annotation can be silenced with type comments],
|
||||
) <silence-errors>
|
||||
|
||||
=== Frame type annotation
|
||||
|
||||
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
|
||||
|
||||
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
|
||||
|
||||
#figure(
|
||||
```python
|
||||
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
|
||||
```,
|
||||
caption: [Frame type annotation in Python],
|
||||
) <python-frame-type>
|
||||
|
||||
= Commands <commands>
|
||||
|
||||
#TODO
|
||||
|
||||
@@ -37,6 +37,9 @@ contexts:
|
||||
pop: true
|
||||
|
||||
keywords:
|
||||
- match: \balias\b
|
||||
scope: keyword.declaration.midas
|
||||
push: alias-stmt
|
||||
- match: \btype\b
|
||||
scope: keyword.declaration.midas
|
||||
push: type-stmt
|
||||
@@ -47,6 +50,15 @@ contexts:
|
||||
scope: keyword.declaration.midas
|
||||
push: predicate-stmt
|
||||
|
||||
alias-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: "="
|
||||
scope: keyword.operator.equal.midas
|
||||
push: type-expr
|
||||
- match: $
|
||||
pop: true
|
||||
|
||||
type-stmt:
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
@@ -67,6 +79,13 @@ contexts:
|
||||
- match: \b(where)\b
|
||||
scope: keyword.other.midas
|
||||
set: constraint
|
||||
- match: "Frame"
|
||||
scope: entity.name.type
|
||||
push:
|
||||
- match: \[
|
||||
push: frame-schema
|
||||
- match: $
|
||||
pop: true
|
||||
- match: "{{identifier}}"
|
||||
scope: entity.name.type
|
||||
- match: $
|
||||
@@ -178,3 +197,15 @@ contexts:
|
||||
|
||||
- match: '\)'
|
||||
pop: true
|
||||
|
||||
frame-schema:
|
||||
- include: frame-column
|
||||
- match: \]
|
||||
# scope: punctuation.section.block.end
|
||||
pop: true
|
||||
|
||||
frame-column:
|
||||
- match: "{{identifier}}"
|
||||
scope: variable.other.member
|
||||
- match: ":"
|
||||
push: type-expr
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
"""
|
||||
Helper script to generate AST nodes for Midas and Python.
|
||||
|
||||
Takes in simple templates and generates full dataclasses and a visitor interface
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -29,9 +29,9 @@ class MemberKind(Enum):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
pos: list[FunctionType.Parameter]
|
||||
mixed: list[FunctionType.Parameter]
|
||||
kw: list[FunctionType.Parameter]
|
||||
|
||||
|
||||
###<
|
||||
@@ -150,7 +150,7 @@ class FunctionType:
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
|
||||
@@ -12,6 +12,21 @@ from midas.ast.location import Location
|
||||
###<
|
||||
|
||||
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter]
|
||||
mixed: list[Function.Parameter]
|
||||
kw: list[Function.Parameter]
|
||||
|
||||
@property
|
||||
def all(self) -> list[Function.Parameter]:
|
||||
return self.pos + self.mixed + self.kw
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
@@ -42,25 +57,17 @@ class ExpressionStmt:
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
params: ParamSpec
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
|
||||
class TypeAssign:
|
||||
name: str
|
||||
|
||||
@@ -13,6 +13,8 @@ class HasLocation(Protocol):
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Location:
|
||||
"""Information about the location of an AST node"""
|
||||
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
@@ -29,6 +31,16 @@ class Location:
|
||||
|
||||
@staticmethod
|
||||
def span(start: Location, end: Location) -> Location:
|
||||
"""Create a new location spanning from one location to another
|
||||
|
||||
Args:
|
||||
start (Location): the starting location
|
||||
end (Location): the end location
|
||||
|
||||
Returns:
|
||||
Location: a new location spanning from the start of `start`
|
||||
to the end of `end`
|
||||
"""
|
||||
return Location(
|
||||
lineno=start.lineno,
|
||||
col_offset=start.col_offset,
|
||||
|
||||
@@ -30,9 +30,9 @@ class MemberKind(Enum):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
l_paren: Token
|
||||
pos: list[FunctionType.Argument]
|
||||
mixed: list[FunctionType.Argument]
|
||||
kw: list[FunctionType.Argument]
|
||||
pos: list[FunctionType.Parameter]
|
||||
mixed: list[FunctionType.Parameter]
|
||||
kw: list[FunctionType.Parameter]
|
||||
|
||||
|
||||
##############
|
||||
@@ -318,7 +318,7 @@ class FunctionType(Type):
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
|
||||
@@ -1,896 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(single=True):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class MidasAstPrinter(
|
||||
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
# Statements
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||
self._write_line("AliasStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
self._write_line("MemberStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f"kind: {stmt.kind.name}")
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line("members", last=True)
|
||||
with self._child_level():
|
||||
for i, member in enumerate(stmt.members):
|
||||
self._idx = i
|
||||
if i == len(stmt.members) - 1:
|
||||
self._mark_last()
|
||||
member.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, spec in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._visit_param_spec(spec)
|
||||
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.body.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
self._write_line("GroupingExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self._write_line("NamedType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{type.name.lexeme}"', last=True)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self._write_line("GenericType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("args", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
type.type.accept(self)
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_line("members", last=True)
|
||||
with self._child_level():
|
||||
for i, member in enumerate(type.members):
|
||||
self._idx = i
|
||||
if i == len(type.members) - 1:
|
||||
self._mark_last()
|
||||
member.accept(self)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||
self._write_line("ExtensionType")
|
||||
with self._child_level():
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
type.base.accept(self)
|
||||
self._write_line("extension", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.extension.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level(single=True):
|
||||
self._visit_param_spec(type.params)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_line("pos")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.pos):
|
||||
self._idx = i
|
||||
if i == len(spec.pos) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("mixed")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.mixed):
|
||||
self._idx = i
|
||||
if i == len(spec.mixed) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(spec.kw):
|
||||
self._idx = i
|
||||
if i == len(spec.kw) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
name: str = "None"
|
||||
if arg.name is not None:
|
||||
name = f'"{arg.name.lexeme}"'
|
||||
self._write_line(f"name: {name}")
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
arg.type.accept(self)
|
||||
self._write_line(f"required: {arg.required}", last=True)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level(single=True):
|
||||
self._write_line("columns")
|
||||
with self._child_level():
|
||||
for i, column in enumerate(type.columns):
|
||||
self._idx = i
|
||||
if i == len(type.columns) - 1:
|
||||
self._mark_last()
|
||||
self._print_frame_column(column)
|
||||
|
||||
def _print_frame_column(self, column: m.FrameType.Column) -> None:
|
||||
self._write_line("Column")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{column.name.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
column.type.accept(self)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
|
||||
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
|
||||
|
||||
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
return res
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
keyword: str = {
|
||||
m.MemberKind.PROPERTY: "prop",
|
||||
m.MemberKind.METHOD: "def",
|
||||
}.get(stmt.kind, "")
|
||||
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for member in stmt.members:
|
||||
res += member.accept(self) + "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||
body: str = stmt.body.accept(self)
|
||||
return self.indented(f"predicate {name}{sig} = {body}")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{operator}{right}"
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> str:
|
||||
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
|
||||
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
|
||||
]
|
||||
return f"{expr.callee.accept(self)}({', '.join(args)})"
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
name: str = expr.name.lexeme
|
||||
return f"{expr_}.{name}"
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
return expr.name.lexeme
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
return f"({expr_})"
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||
return str(expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> str:
|
||||
return type.name.lexeme
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.args) != 0:
|
||||
args: list[str] = [param.accept(self) for param in type.args]
|
||||
res += f"[{', '.join(args)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
res += " where " + type.constraint.accept(self)
|
||||
return res
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||
res: str = "{\n"
|
||||
self.level += 1
|
||||
for member in type.members:
|
||||
res += member.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> str:
|
||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
spec: str = self._visit_param_spec(type.params)
|
||||
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
args.append("/")
|
||||
args += mixed_args
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
return f"({', '.join(args)})"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
if arg.name is not None:
|
||||
res += arg.name.lexeme
|
||||
res += ": "
|
||||
res += arg.type.accept(self)
|
||||
if not arg.required:
|
||||
res += "?"
|
||||
return res
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> str:
|
||||
res: str = self.indented("Frame[")
|
||||
if len(type.columns) != 0:
|
||||
res += "\n"
|
||||
self.level += 1
|
||||
columns: list[str] = []
|
||||
for column in type.columns:
|
||||
columns.append(self.indented(self._print_frame_column(column)))
|
||||
res += ",\n".join(columns)
|
||||
self.level -= 1
|
||||
res += "\n"
|
||||
res += "]"
|
||||
return res
|
||||
|
||||
def _print_frame_column(self, column: m.FrameType.Column) -> str:
|
||||
return f"{column.name.lexeme}: {column.type.accept(self)}"
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_line("args:", last=True)
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(node.args):
|
||||
self._idx = i
|
||||
if i == len(node.args) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
node.type.accept(self)
|
||||
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self._write_line("FrameColumn")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_optional_child("type", node.type, last=True)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level():
|
||||
self._write_line("columns", last=True)
|
||||
with self._child_level():
|
||||
for i, col in enumerate(node.columns):
|
||||
self._idx = i
|
||||
if i == len(node.columns) - 1:
|
||||
self._mark_last()
|
||||
col.accept(self)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self._write_line("Function")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
|
||||
self._write_line("posonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.posonlyargs):
|
||||
self._idx = i
|
||||
if i == len(stmt.posonlyargs) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.args):
|
||||
self._idx = i
|
||||
if i == len(stmt.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("kwonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.kwonlyargs):
|
||||
self._idx = i
|
||||
if i == len(stmt.kwonlyargs) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_optional_child("returns", stmt.returns)
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||
self._write_line("FunctionArgument")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {arg.name}")
|
||||
self._write_optional_child("type", arg.type, last=True)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self._write_line("TypeAssign")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self._write_line("AssignStmt")
|
||||
with self._child_level():
|
||||
self._write_line("targets")
|
||||
with self._child_level():
|
||||
for i, target in enumerate(stmt.targets):
|
||||
self._idx = i
|
||||
if i == len(stmt.targets) - 1:
|
||||
self._mark_last()
|
||||
target.accept(self)
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self._write_line("ReturnStmt")
|
||||
with self._child_level():
|
||||
self._write_optional_child("value", stmt.value, last=True)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self._write_line("IfStmt")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
stmt.test.accept(self)
|
||||
self._write_line("body")
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
self._write_line("orelse", last=True)
|
||||
with self._child_level():
|
||||
for i, else_stmt in enumerate(stmt.orelse):
|
||||
self._idx = i
|
||||
if i == len(stmt.orelse) - 1:
|
||||
self._mark_last()
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> None:
|
||||
self._write_line("Pass")
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||
self._write_line("ForStmt")
|
||||
with self._child_level():
|
||||
self._write_line("target")
|
||||
with self._child_level(single=True):
|
||||
stmt.target.accept(self)
|
||||
self._write_line("iterator")
|
||||
with self._child_level(single=True):
|
||||
stmt.iterator.accept(self)
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
|
||||
self._write_line("RawStmt")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self._write_line("CompareExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}", last=True)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"value: {expr.value!r}")
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"name: {expr.name}")
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
self._write_line("ListExpr")
|
||||
with self._child_level():
|
||||
self._write_line("items", last=True)
|
||||
with self._child_level():
|
||||
for i, item in enumerate(expr.items):
|
||||
self._idx = i
|
||||
if i == len(expr.items) - 1:
|
||||
self._mark_last()
|
||||
item.accept(self)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> None:
|
||||
self._write_line("DictExpr")
|
||||
with self._child_level():
|
||||
self._write_line("keys")
|
||||
with self._child_level():
|
||||
for i, key in enumerate(expr.keys):
|
||||
self._idx = i
|
||||
if i == len(expr.keys) - 1:
|
||||
self._mark_last()
|
||||
if key is None:
|
||||
self._write_line("None")
|
||||
else:
|
||||
key.accept(self)
|
||||
self._write_line("values", last=True)
|
||||
with self._child_level():
|
||||
for i, value in enumerate(expr.values):
|
||||
self._idx = i
|
||||
if i == len(expr.values) - 1:
|
||||
self._mark_last()
|
||||
value.accept(self)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||
self._write_line("SubscriptExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line("index", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.index.accept(self)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
|
||||
self._write_line("SliceExpr")
|
||||
with self._child_level():
|
||||
self._write_optional_child("lower", expr.lower)
|
||||
self._write_optional_child("upper", expr.upper)
|
||||
self._write_optional_child("step", expr.step, last=True)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
self._write_line("TupleExpr")
|
||||
with self._child_level():
|
||||
self._write_line("items", last=True)
|
||||
with self._child_level():
|
||||
for i, item in enumerate(expr.items):
|
||||
self._idx = i
|
||||
if i == len(expr.items) - 1:
|
||||
self._mark_last()
|
||||
item.accept(self)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||
self._write_line("RawExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"expr: {ast.unparse(expr.expr)}")
|
||||
3
midas/ast/printer/__init__.py
Normal file
3
midas/ast/printer/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .midas import MidasPrinter as MidasPrinter
|
||||
from .midas_ast import MidasAstPrinter as MidasAstPrinter
|
||||
from .python_ast import PythonAstPrinter as PythonAstPrinter
|
||||
103
midas/ast/printer/base.py
Normal file
103
midas/ast/printer/base.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Generator, Generic, Optional, Protocol, Sequence, TypeVar
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(single=True):
|
||||
child.accept(self)
|
||||
|
||||
def _write_sequence(
|
||||
self,
|
||||
label: str,
|
||||
list_: Sequence[T],
|
||||
*,
|
||||
last: bool = False,
|
||||
print_func: Optional[Callable[[T], None]] = None,
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
|
||||
self._write_line(label)
|
||||
with self._child_level():
|
||||
for i, item in enumerate(list_):
|
||||
self._idx = i
|
||||
if i == len(list_) - 1:
|
||||
self._mark_last()
|
||||
if print_func is not None:
|
||||
print_func(item)
|
||||
else:
|
||||
item.accept(self)
|
||||
183
midas/ast/printer/midas.py
Normal file
183
midas/ast/printer/midas.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import midas.ast.midas as m
|
||||
|
||||
|
||||
class MidasPrinter(
|
||||
m.Expr.Visitor[str],
|
||||
m.Stmt.Visitor[str],
|
||||
m.Type.Visitor[str],
|
||||
):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
# Statements
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
|
||||
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
|
||||
|
||||
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
return res
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
keyword: str = {
|
||||
m.MemberKind.PROPERTY: "prop",
|
||||
m.MemberKind.METHOD: "def",
|
||||
}.get(stmt.kind, "")
|
||||
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for member in stmt.members:
|
||||
res += member.accept(self) + "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||
body: str = stmt.body.accept(self)
|
||||
return self.indented(f"predicate {name}{sig} = {body}")
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{operator}{right}"
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> str:
|
||||
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
|
||||
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
|
||||
]
|
||||
return f"{expr.callee.accept(self)}({', '.join(args)})"
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
name: str = expr.name.lexeme
|
||||
return f"{expr_}.{name}"
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
return expr.name.lexeme
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
return f"({expr_})"
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||
return str(expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
# Types
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> str:
|
||||
return type.name.lexeme
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.args) != 0:
|
||||
args: list[str] = [param.accept(self) for param in type.args]
|
||||
res += f"[{', '.join(args)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
res += " where " + type.constraint.accept(self)
|
||||
return res
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||
res: str = "{\n"
|
||||
self.level += 1
|
||||
for member in type.members:
|
||||
res += member.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> str:
|
||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
spec: str = self._visit_param_spec(type.params)
|
||||
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||
pos: list[str] = [self._print_param(param) for param in spec.pos]
|
||||
mixed: list[str] = [self._print_param(param) for param in spec.mixed]
|
||||
kw: list[str] = [self._print_param(param) for param in spec.kw]
|
||||
params: list[str] = pos
|
||||
|
||||
if len(pos) != 0:
|
||||
params.append("/")
|
||||
params += mixed
|
||||
if len(kw) != 0:
|
||||
params.append("*")
|
||||
params += kw
|
||||
return f"({', '.join(params)})"
|
||||
|
||||
def _print_param(self, param: m.FunctionType.Parameter) -> str:
|
||||
res: str = ""
|
||||
if param.name is not None:
|
||||
res += param.name.lexeme
|
||||
res += ": "
|
||||
res += param.type.accept(self)
|
||||
if not param.required:
|
||||
res += "?"
|
||||
return res
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> str:
|
||||
res: str = self.indented("Frame[")
|
||||
if len(type.columns) != 0:
|
||||
res += "\n"
|
||||
self.level += 1
|
||||
columns: list[str] = []
|
||||
for column in type.columns:
|
||||
columns.append(self.indented(self._print_frame_column(column)))
|
||||
res += ",\n".join(columns)
|
||||
self.level -= 1
|
||||
res += "\n"
|
||||
res += "]"
|
||||
return res
|
||||
|
||||
def _print_frame_column(self, column: m.FrameType.Column) -> str:
|
||||
return f"{column.name.lexeme}: {column.type.accept(self)}"
|
||||
253
midas/ast/printer/midas_ast.py
Normal file
253
midas/ast/printer/midas_ast.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import midas.ast.midas as m
|
||||
from midas.ast.printer.base import AstPrinter
|
||||
|
||||
|
||||
class MidasAstPrinter(
|
||||
AstPrinter,
|
||||
m.Expr.Visitor[None],
|
||||
m.Stmt.Visitor[None],
|
||||
m.Type.Visitor[None],
|
||||
):
|
||||
# Statements
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_sequence(
|
||||
"params",
|
||||
stmt.params,
|
||||
print_func=self._print_type_param,
|
||||
)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||
self._write_line("AliasStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
self._write_line("MemberStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f"kind: {stmt.kind.name}")
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_sequence(
|
||||
"params",
|
||||
stmt.params,
|
||||
print_func=self._print_type_param,
|
||||
)
|
||||
self._write_sequence("members", stmt.members, last=True)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_sequence(
|
||||
"params",
|
||||
stmt.params,
|
||||
print_func=self._visit_param_spec,
|
||||
)
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.body.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
self._write_sequence("arguments", expr.arguments)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
self._write_line("GroupingExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
# Types
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self._write_line("NamedType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{type.name.lexeme}"', last=True)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self._write_line("GenericType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_sequence("args", type.args, last=True)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
type.type.accept(self)
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_sequence("members", type.members, last=True)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||
self._write_line("ExtensionType")
|
||||
with self._child_level():
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
type.base.accept(self)
|
||||
self._write_line("extension", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.extension.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("params")
|
||||
with self._child_level(single=True):
|
||||
self._visit_param_spec(type.params)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_sequence(
|
||||
"pos",
|
||||
spec.pos,
|
||||
print_func=self._print_param,
|
||||
)
|
||||
self._write_sequence(
|
||||
"mixed",
|
||||
spec.mixed,
|
||||
print_func=self._print_param,
|
||||
)
|
||||
self._write_sequence(
|
||||
"kw",
|
||||
spec.kw,
|
||||
print_func=self._print_param,
|
||||
last=True,
|
||||
)
|
||||
|
||||
def _print_param(self, param: m.FunctionType.Parameter) -> None:
|
||||
self._write_line("Parameter")
|
||||
with self._child_level():
|
||||
name: str = "None"
|
||||
if param.name is not None:
|
||||
name = f'"{param.name.lexeme}"'
|
||||
self._write_line(f"name: {name}")
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
param.type.accept(self)
|
||||
self._write_line(f"required: {param.required}", last=True)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level(single=True):
|
||||
self._write_sequence(
|
||||
"columns",
|
||||
type.columns,
|
||||
print_func=self._print_frame_column,
|
||||
)
|
||||
|
||||
def _print_frame_column(self, column: m.FrameType.Column) -> None:
|
||||
self._write_line("Column")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{column.name.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
column.type.accept(self)
|
||||
285
midas/ast/printer/python_ast.py
Normal file
285
midas/ast/printer/python_ast.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import ast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.printer.base import AstPrinter
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
# Types
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_sequence("args", node.args, last=True)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
node.type.accept(self)
|
||||
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self._write_line("FrameColumn")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_optional_child("type", node.type, last=True)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level(single=True):
|
||||
self._write_sequence("columns", node.columns)
|
||||
|
||||
# Statements
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self._write_line("Function")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
self._print_param_spec(stmt.params)
|
||||
|
||||
self._write_optional_child("returns", stmt.returns)
|
||||
self._write_sequence("body", stmt.body, last=True)
|
||||
|
||||
def _print_param_spec(self, spec: p.ParamSpec) -> None:
|
||||
self._write_line("ParamSpec")
|
||||
with self._child_level():
|
||||
self._write_sequence(
|
||||
"pos",
|
||||
spec.pos,
|
||||
print_func=self._print_param,
|
||||
)
|
||||
self._write_sequence(
|
||||
"mixed",
|
||||
spec.mixed,
|
||||
print_func=self._print_param,
|
||||
)
|
||||
self._write_sequence(
|
||||
"kw",
|
||||
spec.kw,
|
||||
print_func=self._print_param,
|
||||
last=True,
|
||||
)
|
||||
|
||||
def _print_param(self, param: p.Function.Parameter) -> None:
|
||||
self._write_line("Parameter")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {param.name}")
|
||||
self._write_optional_child("type", param.type, last=True)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self._write_line("TypeAssign")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self._write_line("AssignStmt")
|
||||
with self._child_level():
|
||||
self._write_sequence("targets", stmt.targets)
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self._write_line("ReturnStmt")
|
||||
with self._child_level():
|
||||
self._write_optional_child("value", stmt.value, last=True)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self._write_line("IfStmt")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
stmt.test.accept(self)
|
||||
self._write_sequence("body", stmt.body)
|
||||
self._write_sequence("orelse", stmt.orelse, last=True)
|
||||
|
||||
def visit_pass(self, stmt: p.Pass) -> None:
|
||||
self._write_line("Pass")
|
||||
|
||||
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||
self._write_line("ForStmt")
|
||||
with self._child_level():
|
||||
self._write_line("target")
|
||||
with self._child_level(single=True):
|
||||
stmt.target.accept(self)
|
||||
self._write_line("iterator")
|
||||
with self._child_level(single=True):
|
||||
stmt.iterator.accept(self)
|
||||
self._write_sequence("body", stmt.body, last=True)
|
||||
|
||||
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
|
||||
self._write_line("RawStmt")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self._write_line("CompareExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
|
||||
self._write_sequence("arguments", expr.arguments)
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}", last=True)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"value: {expr.value!r}")
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"name: {expr.name}")
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
self._write_line("ListExpr")
|
||||
with self._child_level():
|
||||
self._write_sequence("items", expr.items, last=True)
|
||||
|
||||
def visit_dict_expr(self, expr: p.DictExpr) -> None:
|
||||
self._write_line("DictExpr")
|
||||
with self._child_level():
|
||||
self._write_sequence(
|
||||
"keys",
|
||||
expr.keys,
|
||||
print_func=lambda k: (
|
||||
self._write_line("None") if k is None else k.accept(self)
|
||||
),
|
||||
)
|
||||
self._write_sequence("values", expr.values, last=True)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||
self._write_line("SubscriptExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line("index", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.index.accept(self)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
|
||||
self._write_line("SliceExpr")
|
||||
with self._child_level():
|
||||
self._write_optional_child("lower", expr.lower)
|
||||
self._write_optional_child("upper", expr.upper)
|
||||
self._write_optional_child("step", expr.step, last=True)
|
||||
|
||||
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||
self._write_line("TupleExpr")
|
||||
with self._child_level():
|
||||
self._write_sequence("items", expr.items, last=True)
|
||||
|
||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||
self._write_line("RawExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"expr: {ast.unparse(expr.expr)}")
|
||||
@@ -14,6 +14,16 @@ from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter]
|
||||
mixed: list[Function.Parameter]
|
||||
kw: list[Function.Parameter]
|
||||
|
||||
@property
|
||||
def all(self) -> list[Function.Parameter]:
|
||||
return self.pos + self.mixed + self.kw
|
||||
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
@@ -128,25 +138,17 @@ class ExpressionStmt(Stmt):
|
||||
@dataclass(frozen=True)
|
||||
class Function(Stmt):
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
params: ParamSpec
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
|
||||
|
||||
@@ -17,8 +17,12 @@ if TYPE_CHECKING:
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
"""
|
||||
Hard-coded subtype relationships between builtin types
|
||||
|
||||
Circular dependencies and diamond inheritance MUST be avoided
|
||||
"""
|
||||
|
||||
|
||||
def define_builtins(reg: TypesRegistry):
|
||||
|
||||
@@ -10,6 +10,11 @@ from midas.utils import TypedAST
|
||||
|
||||
|
||||
class TypeChecker:
|
||||
"""Type checking dispatcher
|
||||
|
||||
Contains a typer for Midas and one for Python, as well as the types registry
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.types: TypesRegistry = TypesRegistry()
|
||||
self.reporter: Reporter = Reporter()
|
||||
|
||||
@@ -14,6 +14,15 @@ class DiagnosticType(StrEnum):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
"""Information about a diagnostic (warning, errors, etc.)
|
||||
|
||||
Holds a location, a diagnostic type and a message.
|
||||
Optionally bound to a file path
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
file_path: Optional[str]
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
@@ -21,6 +30,18 @@ class Diagnostic:
|
||||
|
||||
@property
|
||||
def location_str(self) -> str:
|
||||
"""The diagnostic type and location as a human readable string
|
||||
|
||||
The location is formatted as "<Type> in <file> from L<start_line>:<start_col> to <end_line>:<end_col>",
|
||||
for example: "Error in /home/user/Desktop/script.py from L12:5 to L12:8"
|
||||
|
||||
If the file is `None`, the "in ..." section is excluded from the result.<br>
|
||||
If the location's end is not specified, the formulation "at L<start_line>:<start_col>" is used.
|
||||
|
||||
Returns:
|
||||
str: _description_
|
||||
"""
|
||||
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
|
||||
@@ -30,9 +30,9 @@ TypedExpr = tuple[E, Type]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument(Generic[E]):
|
||||
expr: E
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
arg_expr: E
|
||||
arg_type: Type
|
||||
parameter: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -219,11 +219,11 @@ class CallDispatcher(Generic[E]):
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||
if not self.types.is_subtype(arg.arg_type, arg.parameter.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
arg.arg_expr.location,
|
||||
f"Wrong type for argument '{arg.parameter.name}', expected {arg.parameter.type}, got {arg.arg_type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
@@ -347,28 +347,30 @@ class CallDispatcher(Generic[E]):
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
set_params: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
param.name
|
||||
for param in function.params.pos + function.params.mixed
|
||||
if param.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
param.name for param in function.params.kw if param.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument[E]] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
pos_params: list[Function.Parameter] = list(function.params.pos)
|
||||
mixed_params: list[Function.Parameter] = list(function.params.mixed)
|
||||
kw_params: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in function.params.kw
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
param: Function.Parameter
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
@@ -385,27 +387,27 @@ class CallDispatcher(Generic[E]):
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
set_params.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
arg_expr=arg[0],
|
||||
arg_type=arg[1],
|
||||
parameter=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
kw_params.update({param.name: param for param in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
param: Function.Parameter
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
if name in set_params:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
arg[0].location, f"Multiple values for parameter '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
arg[0].location, f"Unknown keyword parameter '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
@@ -414,40 +416,40 @@ class CallDispatcher(Generic[E]):
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
set_params.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
arg_expr=arg[0],
|
||||
arg_type=arg[1],
|
||||
parameter=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
def join_params(params: list[str]) -> str:
|
||||
params = list(map(lambda p: f"'{p}'", params))
|
||||
if len(params) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
if len(params) == 1:
|
||||
return params[0]
|
||||
return ", ".join(params[:-1]) + " and " + params[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
params: str = join_params(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
f"Missing required positional argument{plural}: {params}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
params: str = join_params(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
f"Missing required keyword argument{plural}: {params}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
@@ -474,11 +476,11 @@ class CallDispatcher(Generic[E]):
|
||||
"""
|
||||
by_expr: dict[E, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
by_expr[arg.arg_expr] = arg.parameter.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
type2: Type = arg.parameter.type
|
||||
type1: Type = by_expr[arg.arg_expr]
|
||||
if not self.types.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -158,15 +158,17 @@ class Evaluator(m.Expr.Visitor[Any]):
|
||||
return res
|
||||
|
||||
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
|
||||
positional: list[Function.Argument] = function.pos_args + function.args
|
||||
keywords: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.args + function.kw_args
|
||||
positional: list[Function.Parameter] = (
|
||||
function.params.pos + function.params.mixed
|
||||
)
|
||||
keywords: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in function.params.mixed + function.params.kw
|
||||
}
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
param: Function.Argument = positional[i]
|
||||
param: Function.Parameter = positional[i]
|
||||
self.set_value(param.name, arg)
|
||||
|
||||
for name, arg in kwargs.items():
|
||||
param: Function.Argument = keywords[name]
|
||||
param: Function.Parameter = keywords[name]
|
||||
self.set_value(param.name, arg)
|
||||
|
||||
@@ -7,7 +7,14 @@ import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import ColumnGroupBy, Function, Type
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
Function,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
@@ -22,39 +29,52 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
params: list[str | tuple[str, str, bool]] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
real_params: list[Function.Parameter] = []
|
||||
for i, param in enumerate(params):
|
||||
match param:
|
||||
case str() as name:
|
||||
param = Function.Parameter(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(self.NAMED_ARGS[name]),
|
||||
required=False,
|
||||
)
|
||||
case (name, type, required):
|
||||
param = Function.Parameter(
|
||||
pos=i,
|
||||
name=name,
|
||||
type=self.types.get_type(type),
|
||||
required=required,
|
||||
)
|
||||
real_params.append(param)
|
||||
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
params=ParamSpec(mixed=real_params),
|
||||
returns=(
|
||||
call.groupby.column
|
||||
if preserve_inner_type
|
||||
else ColumnType(type=TopType())
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -64,3 +84,127 @@ class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["skipna", "numeric_only"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna", "engine", "engine_kwargs"],
|
||||
)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
["numeric_only", "skipna"],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
preserve_inner_type=True,
|
||||
)
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"ddof",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
"numeric_only",
|
||||
(
|
||||
"min_count",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"skipna",
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
(
|
||||
"var",
|
||||
"int",
|
||||
False,
|
||||
),
|
||||
"engine",
|
||||
"engine_kwargs",
|
||||
"numeric_only",
|
||||
"skipna",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -60,3 +61,18 @@ class ColumnManager:
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int")
|
||||
|
||||
case "T":
|
||||
return column
|
||||
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -13,6 +13,7 @@ from midas.checker.types import (
|
||||
ColumnType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -33,6 +34,10 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.column_expr, self.column)
|
||||
|
||||
|
||||
class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def _element_binary_op(self, call: Call, method: str) -> ColumnType:
|
||||
@@ -69,8 +74,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
new_column = ColumnType(type=new_inner_type)
|
||||
return new_column
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support add with scalar
|
||||
|
||||
# Build signature with new column type and generic operand
|
||||
@@ -79,15 +83,17 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
name="add",
|
||||
params=[param_type],
|
||||
body=Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=ColumnType(type=param_type),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, "__add__"),
|
||||
),
|
||||
returns=self._element_binary_op(call, method),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,18 +111,186 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(
|
||||
self,
|
||||
call: Call,
|
||||
kwargs: list[Function.Parameter] = [],
|
||||
*,
|
||||
preserve_inner_type: bool = False,
|
||||
) -> Type:
|
||||
signature = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
),
|
||||
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call, preserve_inner_type=True)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -131,52 +305,33 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=i + 2,
|
||||
name=name,
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
)
|
||||
for i, name in enumerate(
|
||||
["as_index", "sort", "group_keys", "observed", "dropna"]
|
||||
)
|
||||
],
|
||||
),
|
||||
returns=ColumnGroupBy(column=call.column),
|
||||
)
|
||||
|
||||
@@ -190,6 +345,21 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
|
||||
func_name: str = "__midas_column_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_col(col: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=col,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
@@ -207,16 +377,10 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="column1"),
|
||||
attr="size",
|
||||
),
|
||||
left=len_of_col(ast.Name(id="column1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="column2"),
|
||||
attr="size",
|
||||
)
|
||||
len_of_col(ast.Name(id="column2")),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,9 +5,15 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import FrameGroupBy, Function, Type
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
FrameGroupBy,
|
||||
Type,
|
||||
UnknownType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
@@ -22,45 +28,76 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.groupby_expr, self.groupby)
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.frame,
|
||||
)
|
||||
NAMED_ARGS: dict[str, str] = {
|
||||
"numeric_only": "bool",
|
||||
"skipna": "bool",
|
||||
"engine": "str",
|
||||
"engine_kwargs": "dict",
|
||||
}
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
def _aggregate(self, call: Call, method: str) -> Type:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
for column in call.groupby.frame.columns:
|
||||
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
|
||||
result_type: Type = self.typer.call_method(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
call_expr=call.call_expr,
|
||||
obj=(call.groupby_expr, column_groupby),
|
||||
method_name=method,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
if not isinstance(result_type, ColumnType):
|
||||
result_type = ColumnType(type=UnknownType())
|
||||
new_columns.append(
|
||||
DataFrameType.Column(
|
||||
index=column.index,
|
||||
name=column.name,
|
||||
type=result_type,
|
||||
)
|
||||
)
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method()
|
||||
def kurt(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "kurt")
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "max")
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "mean")
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "median")
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "min")
|
||||
|
||||
@method()
|
||||
def prod(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "prod")
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "std")
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "sum")
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(call, "var")
|
||||
|
||||
@@ -7,6 +7,7 @@ from midas.ast.location import Location
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
@@ -240,3 +241,15 @@ class FrameManager:
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
|
||||
types: TypesRegistry = self.typer.types
|
||||
match name:
|
||||
case "ndim" | "size":
|
||||
return types.get_type("int")
|
||||
|
||||
case "shape":
|
||||
return types.tuple_of("int", "int")
|
||||
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -14,6 +14,7 @@ from midas.checker.types import (
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
UnknownType,
|
||||
@@ -33,6 +34,10 @@ class Call:
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
@property
|
||||
def subject(self) -> TypedExpr:
|
||||
return (self.frame_expr, self.frame)
|
||||
|
||||
|
||||
class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
def _get_method_result(
|
||||
@@ -142,21 +147,21 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
return DataFrameType(columns=new_columns)
|
||||
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
# TODO: support add with scalar, sequence, Series, dict
|
||||
|
||||
def _element_wise(self, call: Call, method: str) -> Type:
|
||||
# TODO: support scalar, sequence, Series, dict operand
|
||||
# Build signature with new schema and generic operand
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="other",
|
||||
type=DataFrameType(columns=[]),
|
||||
required=True,
|
||||
),
|
||||
],
|
||||
returns=self._element_binary_op(call, "__add__"),
|
||||
),
|
||||
returns=self._element_binary_op(call, method),
|
||||
)
|
||||
|
||||
# Map arguments and compute result type
|
||||
@@ -173,28 +178,85 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__add__")
|
||||
|
||||
@method("sub", "__sub__")
|
||||
def sub(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__sub__")
|
||||
|
||||
@method("mul", "__mul__")
|
||||
def mul(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mul__")
|
||||
|
||||
@method("div", "truediv", "__truediv__")
|
||||
def truediv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__truediv__")
|
||||
|
||||
@method("floordiv", "__floordiv__")
|
||||
def floordiv(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__floordiv__")
|
||||
|
||||
@method("mod", "__mod__")
|
||||
def mod(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__mod__")
|
||||
|
||||
@method("pow", "__pow__")
|
||||
def pow(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__pow__")
|
||||
|
||||
@method("lt", "__lt__")
|
||||
def lt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__lt__")
|
||||
|
||||
@method("gt", "__gt__")
|
||||
def gt(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__gt__")
|
||||
|
||||
@method("le", "__le__")
|
||||
def le(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__le__")
|
||||
|
||||
@method("ge", "__ge__")
|
||||
def ge(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ge__")
|
||||
|
||||
@method("ne", "__ne__")
|
||||
def ne(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__ne__")
|
||||
|
||||
@method("eq", "__eq__")
|
||||
def eq(self, call: Call) -> Type:
|
||||
return self._element_wise(call, "__eq__")
|
||||
|
||||
def _aggregate(self, call: Call, kwargs: list[Function.Parameter] = []) -> Type:
|
||||
with_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
),
|
||||
returns=ColumnType(type=TopType()),
|
||||
)
|
||||
without_axis = Function(
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="axis",
|
||||
type=self.types.get_type("None"),
|
||||
required=True,
|
||||
)
|
||||
),
|
||||
*kwargs,
|
||||
],
|
||||
),
|
||||
returns=TopType(),
|
||||
)
|
||||
overload = OverloadedFunction(
|
||||
@@ -212,56 +274,145 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method("kurtosis", "kurt")
|
||||
def kurtosis(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def max(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def median(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def min(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def mode(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method("product", "prod")
|
||||
def product(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def std(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="ddof",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def sum(self, call: Call) -> Type:
|
||||
return self._aggregate(call)
|
||||
|
||||
@method()
|
||||
def var(self, call: Call) -> Type:
|
||||
return self._aggregate(
|
||||
call,
|
||||
[
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="var",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@method()
|
||||
def head(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def tail(self, call: Call) -> Type:
|
||||
signature = Function(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="n",
|
||||
type=self.types.get_type("int"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
returns=call.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="by",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
Function.Parameter(
|
||||
pos=1,
|
||||
name="level",
|
||||
type=TopType(),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="as_index",
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=i + 2,
|
||||
name=name,
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="sort",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=4,
|
||||
name="group_keys",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=5,
|
||||
name="observed",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=6,
|
||||
name="dropna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
)
|
||||
for i, name in enumerate(
|
||||
["as_index", "sort", "group_keys", "observed", "dropna"]
|
||||
)
|
||||
],
|
||||
),
|
||||
returns=FrameGroupBy(frame=call.frame),
|
||||
)
|
||||
|
||||
@@ -275,6 +426,21 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
|
||||
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
|
||||
func_name: str = "__midas_frame_same_length__"
|
||||
|
||||
# Efficiently compute length
|
||||
# https://stackoverflow.com/a/15943975/11109181
|
||||
def len_of_df(df: ast.expr) -> ast.expr:
|
||||
return ast.Call(
|
||||
func=ast.Name(id="len"),
|
||||
args=[
|
||||
ast.Attribute(
|
||||
value=df,
|
||||
attr="index",
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
self.assertions.define(
|
||||
func_name,
|
||||
ast.FunctionDef(
|
||||
@@ -292,17 +458,9 @@ class FrameMethodRegistry(MethodRegistry[Call]):
|
||||
body=[
|
||||
ast.Return(
|
||||
value=ast.Compare(
|
||||
left=ast.Attribute(
|
||||
value=ast.Name(id="frame1"),
|
||||
attr="size",
|
||||
),
|
||||
left=len_of_df(ast.Name(id="frame1")),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="frame2"),
|
||||
attr="size",
|
||||
)
|
||||
],
|
||||
comparators=[len_of_df(ast.Name(id="frame2"))],
|
||||
)
|
||||
)
|
||||
],
|
||||
|
||||
@@ -20,7 +20,7 @@ from midas.checker.types import Type, UnknownType
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
@@ -41,12 +41,18 @@ class _MethodRegistryMeta(type):
|
||||
return new_class
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
class MethodCall(Protocol):
|
||||
@property
|
||||
def location(self) -> Location: ...
|
||||
|
||||
@property
|
||||
def call_expr(self) -> p.Expr: ...
|
||||
|
||||
T = TypeVar("T", bound=HasLocation)
|
||||
@property
|
||||
def subject(self) -> TypedExpr: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MethodCall)
|
||||
|
||||
|
||||
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
@@ -72,7 +78,9 @@ class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
|
||||
def call(self, method: str, call: T) -> Type:
|
||||
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(call.location, f"Unknown method {method}")
|
||||
self.reporter.warning(
|
||||
call.location, f"Unknown method {method} on {call.subject[1]}"
|
||||
)
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from midas.checker.types import (
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
Predicate,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -32,13 +33,6 @@ from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypedParamSpec:
|
||||
pos: list[Function.Argument]
|
||||
mixed: list[Function.Argument]
|
||||
kw: list[Function.Argument]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
@@ -47,7 +41,7 @@ class ReturnException(Exception):
|
||||
class MappedArgument:
|
||||
expr: m.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
argument: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -196,9 +190,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||
|
||||
type: Type = self.type_of(stmt.body)
|
||||
params: list[TypedParamSpec] = [
|
||||
self._visit_param_spec(spec) for spec in stmt.params
|
||||
]
|
||||
params: list[ParamSpec] = [self._visit_param_spec(spec) for spec in stmt.params]
|
||||
|
||||
if not self._is_valid_predicate(type):
|
||||
self.reporter.error(
|
||||
@@ -209,9 +201,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
type = self._bool
|
||||
for spec in reversed(params):
|
||||
type = Function(
|
||||
pos_args=spec.pos,
|
||||
args=spec.mixed,
|
||||
kw_args=spec.kw,
|
||||
params=spec,
|
||||
returns=type,
|
||||
)
|
||||
self._predicate_params = {}
|
||||
@@ -386,30 +376,34 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||
return Function(
|
||||
pos_args=params.pos,
|
||||
args=params.mixed,
|
||||
kw_args=params.kw,
|
||||
params=self._visit_param_spec(type.params),
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||
def _visit_param_spec(self, spec: m.ParamSpec) -> ParamSpec:
|
||||
n_pos: int = len(spec.pos)
|
||||
n_mixed: int = len(spec.mixed)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
def process_param(
|
||||
param: m.FunctionType.Parameter, i: int
|
||||
) -> Function.Parameter:
|
||||
return Function.Parameter(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
name=param.name.lexeme if param.name is not None else str(i),
|
||||
type=param.type.accept(self),
|
||||
required=param.required,
|
||||
)
|
||||
|
||||
return TypedParamSpec(
|
||||
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||
return ParamSpec(
|
||||
pos=[process_param(param, i) for i, param in enumerate(spec.pos)],
|
||||
mixed=[
|
||||
process_param(param, i + n_pos) for i, param in enumerate(spec.mixed)
|
||||
],
|
||||
kw=[
|
||||
process_param(param, i + n_pos + n_mixed)
|
||||
for i, param in enumerate(spec.kw)
|
||||
],
|
||||
)
|
||||
|
||||
def visit_frame_type(self, type: m.FrameType) -> Type:
|
||||
|
||||
@@ -7,6 +7,7 @@ from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -108,8 +109,8 @@ class Preamble(Environment):
|
||||
],
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
def _list_of(self, item_type: str | Type) -> Type:
|
||||
return self._types.list_of(item_type)
|
||||
|
||||
def _def_type_constructor(
|
||||
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||
@@ -132,9 +133,9 @@ class Preamble(Environment):
|
||||
returns: Type = UnitType(),
|
||||
type_vars: list[TypeVar] = [],
|
||||
) -> Type:
|
||||
def map_args(params: list[Param], offset: int) -> list[Function.Argument]:
|
||||
def map_params(params: list[Param], offset: int) -> list[Function.Parameter]:
|
||||
return [
|
||||
Function.Argument(
|
||||
Function.Parameter(
|
||||
pos=i + offset,
|
||||
name=param.name,
|
||||
type=param.type,
|
||||
@@ -144,9 +145,11 @@ class Preamble(Environment):
|
||||
]
|
||||
|
||||
function = Function(
|
||||
pos_args=map_args(pos, 0),
|
||||
args=map_args(mixed, len(pos)),
|
||||
kw_args=map_args(kw, len(pos) + len(mixed)),
|
||||
params=ParamSpec(
|
||||
pos=map_params(pos, 0),
|
||||
mixed=map_params(mixed, len(pos)),
|
||||
kw=map_params(kw, len(pos) + len(mixed)),
|
||||
),
|
||||
returns=returns,
|
||||
)
|
||||
if len(type_vars) != 0:
|
||||
|
||||
@@ -31,6 +31,7 @@ from midas.checker.types import (
|
||||
FrameGroupBy,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
@@ -59,7 +60,7 @@ class UndefinedMethodException(Exception):
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
argument: Function.Parameter
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -222,7 +223,7 @@ class PythonTyper(
|
||||
method_name: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Type]:
|
||||
) -> Type:
|
||||
unfolded: Type = unfold_type(obj[1])
|
||||
match unfolded:
|
||||
case DataFrameType():
|
||||
@@ -289,61 +290,64 @@ class PythonTyper(
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
pos: list[Function.Parameter] = []
|
||||
mixed: list[Function.Parameter] = []
|
||||
kw: list[Function.Parameter] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return self.resolve_type_expr(arg.type)
|
||||
if arg.default is not None:
|
||||
return self.type_of(arg.default)
|
||||
def eval_param_type(param: p.Function.Parameter) -> Type:
|
||||
if param.type is not None:
|
||||
return self.resolve_type_expr(param.type)
|
||||
if param.default is not None:
|
||||
return self.type_of(param.default)
|
||||
return UnknownType()
|
||||
|
||||
pos: int = 0
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
position: int = 0
|
||||
for param in stmt.params.pos:
|
||||
pos.append(
|
||||
Function.Parameter(
|
||||
pos=position,
|
||||
name=param.name,
|
||||
type=eval_param_type(param),
|
||||
required=param.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
position += 1
|
||||
for param in stmt.params.mixed:
|
||||
mixed.append(
|
||||
Function.Parameter(
|
||||
pos=position,
|
||||
name=param.name,
|
||||
type=eval_param_type(param),
|
||||
required=param.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
pos=pos, # not relevant
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
position += 1
|
||||
for param in stmt.params.kw:
|
||||
kw.append(
|
||||
Function.Parameter(
|
||||
pos=position, # not relevant
|
||||
name=param.name,
|
||||
type=eval_param_type(param),
|
||||
required=param.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
position += 1
|
||||
|
||||
all_args: list[Function.Argument] = pos_args + args + kw_args
|
||||
for arg in all_args:
|
||||
env.define(arg.name, arg.type)
|
||||
param_spec: ParamSpec = ParamSpec(
|
||||
pos=pos,
|
||||
mixed=mixed,
|
||||
kw=kw,
|
||||
)
|
||||
all_params: list[Function.Parameter] = pos + mixed + kw
|
||||
for param in all_params:
|
||||
env.define(param.name, param.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = self.resolve_type_expr(stmt.returns)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
params=param_spec,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
@@ -375,13 +379,11 @@ class PythonTyper(
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Type = Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
params=param_spec,
|
||||
returns=returns,
|
||||
)
|
||||
generic_params: list[TypeVar] = []
|
||||
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
|
||||
all_types: list[Type] = [param.type for param in all_params] + [returns]
|
||||
for type in all_types:
|
||||
if isinstance(type, TypeVar):
|
||||
if type not in generic_params:
|
||||
@@ -580,9 +582,8 @@ class PythonTyper(
|
||||
right: TypedExpr,
|
||||
method: str,
|
||||
) -> Type:
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
return self.call_method(
|
||||
location=location,
|
||||
call_expr=expr,
|
||||
obj=left,
|
||||
@@ -597,8 +598,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
@@ -610,9 +609,8 @@ class PythonTyper(
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
|
||||
result: Optional[Type]
|
||||
try:
|
||||
result = self.call_method(
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(expr.right, operand),
|
||||
@@ -627,8 +625,6 @@ class PythonTyper(
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return result or UnknownType()
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
match expr.callee:
|
||||
case p.VariableExpr(name="TypeVar"):
|
||||
@@ -644,8 +640,7 @@ class PythonTyper(
|
||||
match expr.callee:
|
||||
case p.GetExpr(object=obj, name=method):
|
||||
obj_type: Type = self.type_of(obj)
|
||||
return (
|
||||
self.call_method(
|
||||
return self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
obj=(obj, obj_type),
|
||||
@@ -653,8 +648,6 @@ class PythonTyper(
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
@@ -668,6 +661,14 @@ class PythonTyper(
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||
|
||||
if member is None:
|
||||
match object:
|
||||
case DataFrameType():
|
||||
member = self.frame_mgr.get_attribute(object, expr.name)
|
||||
case ColumnType():
|
||||
member = self.column_mgr.get_attribute(object, expr.name)
|
||||
|
||||
if member is None:
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
|
||||
@@ -113,6 +113,15 @@ class TypesRegistry:
|
||||
raise ValueError(f"Predicate {name} already defined")
|
||||
self._predicates[name] = predicate
|
||||
|
||||
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
|
||||
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
|
||||
if name1 in subtypes:
|
||||
return True
|
||||
for subtype in subtypes:
|
||||
if self.is_builtin_subtype(name1, subtype):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
@@ -150,7 +159,7 @@ class TypesRegistry:
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
return self.is_builtin_subtype(name1, name2)
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
@@ -225,92 +234,100 @@ class TypesRegistry:
|
||||
if not self.is_subtype(func1.returns, func2.returns):
|
||||
return False
|
||||
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||
pos1: list[Function.Parameter] = func1.params.pos
|
||||
mixed1: list[Function.Parameter] = func1.params.mixed
|
||||
kw1: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in func1.params.kw
|
||||
}
|
||||
pos2: list[Function.Parameter] = func2.params.pos
|
||||
mixed2: list[Function.Parameter] = func2.params.mixed
|
||||
kw2: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in func2.params.kw
|
||||
}
|
||||
|
||||
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||
mixed_by_pos: dict[int, Function.Parameter] = {
|
||||
param.pos: param for param in mixed2
|
||||
}
|
||||
mixed_by_name: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in mixed2
|
||||
}
|
||||
|
||||
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||
def is_arg_subtype(sub: Function.Parameter, sup: Function.Parameter) -> bool:
|
||||
if not self.is_subtype(sub.type, sup.type):
|
||||
return False
|
||||
if not sup.required and sub.required:
|
||||
return False
|
||||
return True
|
||||
|
||||
for arg1 in pos1:
|
||||
arg2: Function.Argument
|
||||
if arg1.pos < len(pos2):
|
||||
arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
arg2 = mixed_by_pos[arg1.pos]
|
||||
elif not arg1.required:
|
||||
for param1 in pos1:
|
||||
param2: Function.Parameter
|
||||
if param1.pos < len(pos2):
|
||||
param2 = pos2[param1.pos]
|
||||
elif param1.pos in mixed_by_pos:
|
||||
param2 = mixed_by_pos[param1.pos]
|
||||
elif not param1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
if not is_arg_subtype(param2, param1):
|
||||
return False
|
||||
|
||||
for name, arg1 in kw1.items():
|
||||
arg2: Function.Argument
|
||||
for name, param1 in kw1.items():
|
||||
param2: Function.Parameter
|
||||
if name in kw2:
|
||||
arg2 = kw2[name]
|
||||
param2 = kw2[name]
|
||||
elif name in mixed_by_name:
|
||||
arg2 = mixed_by_name[name]
|
||||
elif not arg1.required:
|
||||
param2 = mixed_by_name[name]
|
||||
elif not param1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
if not is_arg_subtype(param2, param1):
|
||||
return False
|
||||
|
||||
for arg1 in mixed1:
|
||||
pos_arg2: Optional[Function.Argument] = None
|
||||
kw_arg2: Optional[Function.Argument] = None
|
||||
if arg1.name in kw2:
|
||||
kw_arg2 = kw2[arg1.name]
|
||||
elif arg1.name in mixed_by_name:
|
||||
kw_arg2 = mixed_by_name[arg1.name]
|
||||
if arg1.pos < len(pos2):
|
||||
pos_arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||
for param1 in mixed1:
|
||||
pos_param2: Optional[Function.Parameter] = None
|
||||
kw_param2: Optional[Function.Parameter] = None
|
||||
if param1.name in kw2:
|
||||
kw_param2 = kw2[param1.name]
|
||||
elif param1.name in mixed_by_name:
|
||||
kw_param2 = mixed_by_name[param1.name]
|
||||
if param1.pos < len(pos2):
|
||||
pos_param2 = pos2[param1.pos]
|
||||
elif param1.pos in mixed_by_pos:
|
||||
pos_param2 = mixed_by_pos[param1.pos]
|
||||
|
||||
# No match in func2 and arg is required
|
||||
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||
if pos_param2 is None and kw_param2 is None and param1.required:
|
||||
return False
|
||||
|
||||
# Matching keyword argument
|
||||
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||
if kw_param2 is not None and not is_arg_subtype(kw_param2, param1):
|
||||
return False
|
||||
|
||||
# Matching positional argument
|
||||
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||
if pos_param2 is not None and not is_arg_subtype(pos_param2, param1):
|
||||
return False
|
||||
|
||||
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||
mixed_names: set[str] = {a.name for a in mixed1}
|
||||
for arg2 in pos2:
|
||||
if not arg2.required:
|
||||
mixed_positions: set[int] = {param.pos for param in mixed1}
|
||||
mixed_names: set[str] = {param.name for param in mixed1}
|
||||
for param2 in pos2:
|
||||
if not param2.required:
|
||||
continue
|
||||
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||
if param2.pos >= len(pos1) and param2.pos not in mixed_positions:
|
||||
return False
|
||||
|
||||
for name, arg2 in kw2.items():
|
||||
if not arg2.required:
|
||||
for name, param2 in kw2.items():
|
||||
if not param2.required:
|
||||
continue
|
||||
if name not in kw1 and name not in mixed_names:
|
||||
return False
|
||||
|
||||
for arg2 in mixed2:
|
||||
if arg2.required:
|
||||
for param2 in mixed2:
|
||||
if param2.required:
|
||||
continue
|
||||
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||
pos_match: bool = param2.pos < len(pos1) or param2.pos in mixed_positions
|
||||
kw_match: bool = param2.name in kw1 or param2.name in mixed_names
|
||||
if not pos_match or not kw_match:
|
||||
return False
|
||||
|
||||
@@ -443,3 +460,29 @@ class TypesRegistry:
|
||||
|
||||
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||
return self._predicates.get(name)
|
||||
|
||||
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
|
||||
if isinstance(name_or_type, str):
|
||||
return self.get_type(name_or_type)
|
||||
return name_or_type
|
||||
|
||||
def list_of(self, item_type: str | Type) -> Type:
|
||||
list_ = self.get_type("list")
|
||||
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
|
||||
|
||||
def tuple_of(self, *item_types: str | Type) -> Type:
|
||||
tuple_ = self.get_type("tuple")
|
||||
return self.apply_generic(
|
||||
tuple_,
|
||||
[self._by_name_or_type(item_type) for item_type in item_types],
|
||||
)
|
||||
|
||||
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
|
||||
dict_ = self.get_type("dict")
|
||||
return self.apply_generic(
|
||||
dict_,
|
||||
[
|
||||
self._by_name_or_type(key_type),
|
||||
self._by_name_or_type(value_type),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
function (p.Function): the function to resolve
|
||||
"""
|
||||
self.begin_scope()
|
||||
for param in function.all_args:
|
||||
for param in function.params.all:
|
||||
self.declare(param.name)
|
||||
self.define(param.name)
|
||||
self.resolve(*function.body)
|
||||
|
||||
@@ -45,28 +45,14 @@ class UnitType:
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
pos_args: list[Argument] = field(default_factory=list)
|
||||
args: list[Argument] = field(default_factory=list)
|
||||
kw_args: list[Argument] = field(default_factory=list)
|
||||
params: ParamSpec
|
||||
returns: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
args: list[str] = []
|
||||
if len(self.pos_args) != 0:
|
||||
args += list(map(str, self.pos_args))
|
||||
args.append("/")
|
||||
|
||||
if len(self.args) != 0:
|
||||
args += list(map(str, self.args))
|
||||
|
||||
if len(self.kw_args) != 0:
|
||||
args.append("*")
|
||||
args += list(map(str, self.kw_args))
|
||||
|
||||
return f"({', '.join(args)}) -> {self.returns}"
|
||||
return f"{self.params} -> {self.returns}"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
class Parameter:
|
||||
pos: int
|
||||
name: str
|
||||
type: Type
|
||||
@@ -77,6 +63,28 @@ class Function:
|
||||
return f"{self.name}: {self.type}{opt}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ParamSpec:
|
||||
pos: list[Function.Parameter] = field(default_factory=list)
|
||||
mixed: list[Function.Parameter] = field(default_factory=list)
|
||||
kw: list[Function.Parameter] = field(default_factory=list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
params: list[str] = []
|
||||
if len(self.pos) != 0:
|
||||
params += list(map(str, self.pos))
|
||||
params.append("/")
|
||||
|
||||
if len(self.mixed) != 0:
|
||||
params += list(map(str, self.mixed))
|
||||
|
||||
if len(self.kw) != 0:
|
||||
params.append("*")
|
||||
params += list(map(str, self.kw))
|
||||
|
||||
return f"({', '.join(params)})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadedFunction:
|
||||
overloads: list[Type]
|
||||
@@ -204,12 +212,19 @@ class ColumnGroupBy:
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
pos=arg.pos,
|
||||
name=arg.name,
|
||||
type=substitute_typevars(arg.type, substitutions),
|
||||
required=arg.required,
|
||||
def sub_parameter(param: Function.Parameter):
|
||||
return Function.Parameter(
|
||||
pos=param.pos,
|
||||
name=param.name,
|
||||
type=substitute_typevars(param.type, substitutions),
|
||||
required=param.required,
|
||||
)
|
||||
|
||||
def sub_param_spec(spec: ParamSpec):
|
||||
return ParamSpec(
|
||||
pos=list(map(sub_parameter, spec.pos)),
|
||||
mixed=list(map(sub_parameter, spec.mixed)),
|
||||
kw=list(map(sub_parameter, spec.kw)),
|
||||
)
|
||||
|
||||
def sub_column(col: DataFrameType.Column):
|
||||
@@ -235,15 +250,11 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
)
|
||||
|
||||
case Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
params=params,
|
||||
returns=returns,
|
||||
):
|
||||
return Function(
|
||||
pos_args=list(map(sub_argument, pos_args)),
|
||||
args=list(map(sub_argument, args)),
|
||||
kw_args=list(map(sub_argument, kw_args)),
|
||||
params=sub_param_spec(params),
|
||||
returns=substitute_typevars(returns, substitutions),
|
||||
)
|
||||
|
||||
@@ -351,14 +362,14 @@ def unfold_type(type: Type) -> Type:
|
||||
|
||||
|
||||
def to_annotation(type: Type) -> str:
|
||||
def _args_annotation(func: Function) -> str:
|
||||
if len(func.kw_args) != 0:
|
||||
def _params_annotation(spec: ParamSpec) -> str:
|
||||
if len(spec.kw) != 0:
|
||||
return "..."
|
||||
|
||||
args: str = ", ".join(
|
||||
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||
params: str = ", ".join(
|
||||
to_annotation(param.type) for param in spec.pos + spec.mixed
|
||||
)
|
||||
return f"[{args}]"
|
||||
return f"[{params}]"
|
||||
|
||||
match type:
|
||||
case TopType():
|
||||
@@ -376,8 +387,8 @@ def to_annotation(type: Type) -> str:
|
||||
case UnitType():
|
||||
return "None"
|
||||
|
||||
case Function(returns=returns):
|
||||
params_annot: str = _args_annotation(type)
|
||||
case Function(params=params, returns=returns):
|
||||
params_annot: str = _params_annotation(params)
|
||||
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||
|
||||
case OverloadedFunction():
|
||||
|
||||
@@ -8,6 +8,7 @@ from midas.checker.types import (
|
||||
DataFrameType,
|
||||
Function,
|
||||
GenericType,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -29,8 +30,9 @@ class Unifier:
|
||||
keywords: dict[str, Type],
|
||||
) -> Optional[Type]:
|
||||
concrete_func: Function = Function(
|
||||
pos_args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
pos=[
|
||||
Function.Parameter(
|
||||
pos=i,
|
||||
name=str(i),
|
||||
type=arg,
|
||||
@@ -38,9 +40,8 @@ class Unifier:
|
||||
)
|
||||
for i, arg in enumerate(positional)
|
||||
],
|
||||
args=[],
|
||||
kw_args=[
|
||||
Function.Argument(
|
||||
kw=[
|
||||
Function.Parameter(
|
||||
pos=len(positional) + i,
|
||||
name=name,
|
||||
type=arg,
|
||||
@@ -48,6 +49,7 @@ class Unifier:
|
||||
)
|
||||
for i, (name, arg) in enumerate(keywords.items())
|
||||
],
|
||||
),
|
||||
returns=TopType(), # TODO: use expected type
|
||||
)
|
||||
return self.unify_generic(type, concrete_func, match_return=False)
|
||||
@@ -125,7 +127,7 @@ class Unifier:
|
||||
return self.match(template_column, concrete_column)
|
||||
|
||||
case (Function(), Function()):
|
||||
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||
mapped: list[tuple[Function.Parameter, Function.Parameter]] = (
|
||||
self.map_params(template, concrete)
|
||||
)
|
||||
substitutions: dict[str, Type] = {}
|
||||
@@ -161,19 +163,23 @@ class Unifier:
|
||||
|
||||
def map_params(
|
||||
self, func1: Function, func2: Function
|
||||
) -> list[tuple[Function.Argument, Function.Argument]]:
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: list[Function.Argument] = func1.kw_args
|
||||
) -> list[tuple[Function.Parameter, Function.Parameter]]:
|
||||
pos1: list[Function.Parameter] = func1.params.pos
|
||||
mixed1: list[Function.Parameter] = func1.params.mixed
|
||||
kw1: list[Function.Parameter] = func1.params.kw
|
||||
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: list[Function.Argument] = func2.kw_args
|
||||
pos2: list[Function.Parameter] = func2.params.pos
|
||||
mixed2: list[Function.Parameter] = func2.params.mixed
|
||||
kw2: list[Function.Parameter] = func2.params.kw
|
||||
|
||||
mapped: list[tuple[Function.Argument, Function.Argument]] = []
|
||||
mapped: list[tuple[Function.Parameter, Function.Parameter]] = []
|
||||
|
||||
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
|
||||
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
|
||||
by_pos2: dict[int, Function.Parameter] = {
|
||||
param.pos: param for param in pos2 + mixed2
|
||||
}
|
||||
by_name2: dict[str, Function.Parameter] = {
|
||||
param.name: param for param in mixed2 + kw2
|
||||
}
|
||||
|
||||
for arg1 in pos1:
|
||||
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||
|
||||
@@ -77,14 +77,14 @@ class VarianceInferrer:
|
||||
match type:
|
||||
# Arguments are negative positions -> flip polarity
|
||||
# Return is positive position -> keep polarity
|
||||
case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args):
|
||||
all_args: list[Function.Argument] = pos_args + mixed_args + kw_args
|
||||
for arg in all_args:
|
||||
case Function(params=spec):
|
||||
all_params: list[Function.Parameter] = spec.pos + spec.mixed + spec.kw
|
||||
for param in all_params:
|
||||
self.walk(
|
||||
arg.type,
|
||||
param.type,
|
||||
-polarity,
|
||||
base_name,
|
||||
path + [f"arg:'{arg.name}'"],
|
||||
path + [f"param:'{param.name}'"],
|
||||
)
|
||||
|
||||
self.walk(type.returns, polarity, base_name, path + ["return"])
|
||||
@@ -109,10 +109,10 @@ class VarianceInferrer:
|
||||
Variance.COVARIANT: 1,
|
||||
Variance.CONTRAVARIANT: -1,
|
||||
}
|
||||
for arg, param in zip(args, params):
|
||||
for param, param in zip(args, params):
|
||||
param_polarity: Polarity = polarities[param.variance]
|
||||
self.walk(
|
||||
arg,
|
||||
param,
|
||||
cast(Polarity, polarity * param_polarity),
|
||||
base_name,
|
||||
path + [f"applied:'{name}'"],
|
||||
|
||||
@@ -157,15 +157,18 @@ class PythonHighlighter(
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self.wrap(stmt, "function")
|
||||
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||
self._highlight_function_argument(arg)
|
||||
self._highlight_param_spec(stmt.params)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||
self.wrap(arg, "argument")
|
||||
if arg.type is not None:
|
||||
arg.type.accept(self)
|
||||
def _highlight_param_spec(self, spec: p.ParamSpec) -> None:
|
||||
for param in spec.all:
|
||||
self._highlight_function_param(param)
|
||||
|
||||
def _highlight_function_param(self, param: p.Function.Parameter) -> None:
|
||||
self.wrap(param, "parameter")
|
||||
if param.type is not None:
|
||||
param.type.accept(self)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
stmt.type.accept(self)
|
||||
|
||||
@@ -23,7 +23,7 @@ span {
|
||||
--col: 215, 103, 224;
|
||||
}
|
||||
|
||||
&.argument {
|
||||
&.parameter {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import midas.ast.midas as m
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
ParamSpec,
|
||||
Predicate,
|
||||
Type,
|
||||
to_annotation,
|
||||
@@ -54,16 +55,16 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
return expr.accept(self)
|
||||
case _:
|
||||
func = Function(
|
||||
pos_args=[],
|
||||
args=[
|
||||
Function.Argument(
|
||||
params=ParamSpec(
|
||||
mixed=[
|
||||
Function.Parameter(
|
||||
pos=0,
|
||||
name="_",
|
||||
type=self.types.get_type("Any"),
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
kw_args=[],
|
||||
),
|
||||
returns=self.types.get_type("bool"),
|
||||
)
|
||||
alias: str = self.make_alias(None)
|
||||
@@ -94,28 +95,28 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
)
|
||||
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||
|
||||
def make_args(self, func: Function) -> ast.arguments:
|
||||
def make_args(self, params: ParamSpec) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for arg in func.pos_args
|
||||
for param in params.pos
|
||||
],
|
||||
args=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for arg in func.args
|
||||
for param in params.mixed
|
||||
],
|
||||
kwonlyargs=[
|
||||
ast.arg(
|
||||
arg=arg.name,
|
||||
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||
arg=param.name,
|
||||
annotation=ast.Constant(value=to_annotation(param.type)),
|
||||
)
|
||||
for arg in func.kw_args
|
||||
for param in params.kw
|
||||
],
|
||||
defaults=[],
|
||||
kw_defaults=[],
|
||||
@@ -125,11 +126,11 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||
) -> ast.stmt:
|
||||
match type:
|
||||
case Function(returns=Function()):
|
||||
case Function(params=params, returns=Function()):
|
||||
inner_name: str = f"inner{level}"
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
args=self.make_args(params),
|
||||
body=[
|
||||
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||
ast.Return(value=ast.Name(id=inner_name)),
|
||||
@@ -138,10 +139,10 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
case Function():
|
||||
case Function(params=params):
|
||||
return ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.make_args(type),
|
||||
args=self.make_args(params),
|
||||
body=inner_body,
|
||||
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||
decorator_list=[],
|
||||
|
||||
@@ -250,25 +250,26 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
value=self.convert(stmt.expr),
|
||||
)
|
||||
|
||||
def make_args(self, params: p.ParamSpec) -> ast.arguments:
|
||||
return ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=param.name) for param in params.pos],
|
||||
args=[ast.arg(arg=param.name) for param in params.mixed],
|
||||
kwonlyargs=[ast.arg(arg=param.name) for param in params.kw],
|
||||
defaults=[
|
||||
self.convert(param.default)
|
||||
for param in params.pos + params.mixed
|
||||
if param.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
self.convert(param.default) if param.default is not None else None
|
||||
for param in params.kw
|
||||
],
|
||||
)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> ast.stmt:
|
||||
return ast.FunctionDef(
|
||||
name=stmt.name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[ast.arg(arg=arg.name) for arg in stmt.posonlyargs],
|
||||
vararg=None,
|
||||
args=[ast.arg(arg=arg.name) for arg in stmt.args],
|
||||
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
|
||||
kwarg=None,
|
||||
defaults=[
|
||||
self.convert(arg.default)
|
||||
for arg in stmt.posonlyargs + stmt.args
|
||||
if arg.default is not None
|
||||
],
|
||||
kw_defaults=[
|
||||
self.convert(arg.default) if arg.default is not None else None
|
||||
for arg in stmt.kwonlyargs
|
||||
],
|
||||
),
|
||||
args=self.make_args(stmt.params),
|
||||
body=self._visit_body(stmt.body),
|
||||
decorator_list=[],
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from midas.checker.types import (
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
ParamSpec,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
@@ -328,7 +329,7 @@ class StubsGenerator:
|
||||
return [
|
||||
ast.FunctionDef(
|
||||
name=name,
|
||||
args=self.dump_args(method, with_self=True),
|
||||
args=self.dump_params(method.params, with_self=True),
|
||||
returns=self.dump_type(method.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||
@@ -348,24 +349,33 @@ class StubsGenerator:
|
||||
)
|
||||
]
|
||||
|
||||
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
|
||||
def dump_params(self, params: ParamSpec, with_self: bool = False) -> ast.arguments:
|
||||
pos: list[ast.arg] = [
|
||||
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
|
||||
for arg in func.pos_args
|
||||
ast.arg(
|
||||
arg=f"_{param.pos}",
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.pos
|
||||
]
|
||||
mixed: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.args
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.mixed
|
||||
]
|
||||
kw: list[ast.arg] = [
|
||||
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||
for arg in func.kw_args
|
||||
ast.arg(
|
||||
arg=param.name,
|
||||
annotation=self.dump_type(param.type),
|
||||
)
|
||||
for param in params.kw
|
||||
]
|
||||
defaults: list[ast.expr] = [
|
||||
Empty for arg in func.pos_args + func.args if not arg.required
|
||||
Empty for param in params.pos + params.mixed if not param.required
|
||||
]
|
||||
kw_defaults: list[Optional[ast.expr]] = [
|
||||
None if arg.required else Empty for arg in func.kw_args
|
||||
None if param.required else Empty for param in params.kw
|
||||
]
|
||||
if with_self:
|
||||
arg = ast.arg(arg="self", annotation=None)
|
||||
@@ -391,7 +401,7 @@ class StubsGenerator:
|
||||
body=[
|
||||
ast.FunctionDef(
|
||||
name="__call__",
|
||||
args=self.dump_args(func, with_self=True),
|
||||
args=self.dump_params(func.params, with_self=True),
|
||||
returns=self.dump_type(func.returns),
|
||||
body=[ast.Expr(value=Empty)],
|
||||
decorator_list=[],
|
||||
|
||||
@@ -16,9 +16,10 @@ class Lexer(ABC):
|
||||
"""An abstract lexer which provides methods to easily extend it into a concrete one
|
||||
|
||||
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
|
||||
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
|
||||
more specifically on my [previous Python implementation][2]
|
||||
|
||||
[1]: https://craftinginterpreters.com/
|
||||
[2]: https://git.kb28.ch/HEL/pebble
|
||||
"""
|
||||
|
||||
def __init__(self, source: str, file: Optional[str] = None) -> None:
|
||||
@@ -168,6 +169,6 @@ class Lexer(ABC):
|
||||
def scan_token(self) -> None:
|
||||
"""Scan a token
|
||||
|
||||
This function should (at least) consume the current character and produce the appropriate token(s), using `add_token`
|
||||
This function should (at least) consume the current character and produce the appropriate token(s), using :func:`add_token`
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -81,6 +81,12 @@ class MidasLexer(Lexer):
|
||||
return None
|
||||
|
||||
def scan_string(self, opening: str):
|
||||
"""Scan the rest of a string and add it as a token
|
||||
|
||||
Args:
|
||||
opening (str): the opening quote or double quote, to be matched
|
||||
at the end of the string
|
||||
"""
|
||||
while self.peek() != opening and not self.is_at_end():
|
||||
self.advance()
|
||||
|
||||
@@ -147,6 +153,18 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.COMMENT)
|
||||
|
||||
def is_identifier_char(self, char: str, *, start: bool) -> bool:
|
||||
"""Check whether a character is a valid as part of an identifier
|
||||
|
||||
Identifiers can contain any alphanumerical character or underscore.
|
||||
They cannot start with a digit.
|
||||
|
||||
Args:
|
||||
char (str): the character to check
|
||||
start (bool): whether this is the first character of the identifier
|
||||
|
||||
Returns:
|
||||
bool: `True` if the character is valid, `False` otherwise
|
||||
"""
|
||||
if char == "_":
|
||||
return True
|
||||
if char.isalpha():
|
||||
|
||||
@@ -104,6 +104,15 @@ class Token:
|
||||
)
|
||||
|
||||
def location_to(self, to: Token) -> Location:
|
||||
"""Create a new :class:`Location` spanning from this token to another
|
||||
|
||||
Args:
|
||||
to (Token): the end token
|
||||
|
||||
Returns:
|
||||
Location: a new :class:`Location` starting at this token and ending
|
||||
at `to`, both included
|
||||
"""
|
||||
return Location.span(self.get_location(), to.get_location())
|
||||
|
||||
@property
|
||||
|
||||
@@ -16,6 +16,9 @@ class TokenError:
|
||||
def get_report(self) -> str:
|
||||
"""Get a detailed error message
|
||||
|
||||
The error message is formatted as "(<position>) Error at <token>: <message>".
|
||||
For example: "(L2:5) Error at '3': Expected ')' after arguments."
|
||||
|
||||
Returns:
|
||||
str: the complete error message
|
||||
"""
|
||||
@@ -32,9 +35,10 @@ class Parser(ABC, Generic[T]):
|
||||
"""An abstract parser which provides methods to easily extend it into a concrete one
|
||||
|
||||
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
|
||||
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble)
|
||||
more specifically on my [previous Python implementation][2]
|
||||
|
||||
[1]: https://craftinginterpreters.com/
|
||||
[2]: https://git.kb28.ch/HEL/pebble
|
||||
"""
|
||||
|
||||
IGNORE: set[TokenType] = {
|
||||
@@ -173,7 +177,7 @@ class Parser(ABC, Generic[T]):
|
||||
error_msg (str): the error message if the token doesn't match
|
||||
|
||||
Raises:
|
||||
SyntaxError: if the current token doesn't match the given type
|
||||
ParsingError: if the current token doesn't match the given type
|
||||
|
||||
Returns:
|
||||
Token: the current token which matched the given type
|
||||
|
||||
@@ -35,10 +35,11 @@ from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser):
|
||||
class MidasParser(Parser[list[Stmt]]):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {
|
||||
TokenType.ALIAS,
|
||||
TokenType.TYPE,
|
||||
TokenType.EXTEND,
|
||||
TokenType.PREDICATE,
|
||||
@@ -73,10 +74,10 @@ class MidasParser(Parser):
|
||||
def declaration(self) -> Optional[Stmt]:
|
||||
"""Try and parse a declaration
|
||||
|
||||
Any parsing error is caught and None is returned
|
||||
Any parsing error is caught and `None` is returned
|
||||
|
||||
Returns:
|
||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||
Optional[Stmt]: the parsed Midas statement, or `None` if a ParsingError was raised
|
||||
"""
|
||||
try:
|
||||
if self.match(TokenType.TYPE):
|
||||
@@ -95,23 +96,14 @@ class MidasParser(Parser):
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration can either be a simple type alias or a new complex type.
|
||||
In either case, it can have an optional template expression after its name, wrapped in brackets.
|
||||
A simple type alias is derived from a base type expression, and can have a optional constraint expression preceded by the `where` keyword.
|
||||
A full simple type alias is thus written:
|
||||
```
|
||||
type Name[Template](TypeExpr) where Condition
|
||||
```
|
||||
A type declaration creates a named subtype of a type expression.
|
||||
It can have an optional template expression after its name, wrapped in brackets, to handle type parameters.
|
||||
|
||||
A new complex type has a set of properties which are named, have a type and an optional constraint expression (also preceded by the `where` keyword).
|
||||
A full complex type definition is thus written:
|
||||
```
|
||||
type Name[Template] {
|
||||
prop1: TypeExpr1 where Condition1
|
||||
prop2: TypeExpr2 where Condition2
|
||||
...
|
||||
}
|
||||
```
|
||||
A type statement consists of:
|
||||
- the `type` keyword
|
||||
- a name (identifier)
|
||||
- (optional) type parameters
|
||||
- a body, a type expression (see :func:`type_expr`)
|
||||
|
||||
Returns:
|
||||
TypeStmt: the parsed type declaration statement
|
||||
@@ -165,11 +157,16 @@ class MidasParser(Parser):
|
||||
def alias_declaration(self) -> AliasStmt:
|
||||
"""Parse an alias declaration
|
||||
|
||||
An alias statement consists of:
|
||||
- the `alias` keyword
|
||||
- a name (identifier)
|
||||
- a body, a type expression (see :func:`type_expr`)
|
||||
|
||||
Returns:
|
||||
AliasStmt: the parsed alias declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
name: Token = self.consume_identifier("Expected alias name")
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
|
||||
|
||||
@@ -184,8 +181,8 @@ class MidasParser(Parser):
|
||||
def type_expr(self) -> Type:
|
||||
"""Parse a type expression
|
||||
|
||||
A type is an identifier, optionally followed by a template expression.
|
||||
It can also optionally be followed by a '?' to indicate a nullable type
|
||||
A type expression can either be a function type (see :func:`function`)
|
||||
or a constraint type (see :func:`constraint_type`)
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
@@ -205,6 +202,15 @@ class MidasParser(Parser):
|
||||
return base
|
||||
|
||||
def constraint_type(self) -> Type:
|
||||
"""Parse a constraint type expression
|
||||
|
||||
A constraint type consists of a base type (see :func:`base_type`),
|
||||
optionally followed by the `where` keyword and a constraint
|
||||
expression (see :func:`constraint`)
|
||||
|
||||
Returns:
|
||||
Type: the parsed constraint type expression
|
||||
"""
|
||||
type: Type = self.base_type()
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint: Expr = self.constraint()
|
||||
@@ -216,6 +222,14 @@ class MidasParser(Parser):
|
||||
return type
|
||||
|
||||
def base_type(self) -> Type:
|
||||
"""Parse a base type expression
|
||||
|
||||
A base type is either a parenthesized type expression (see :func:`type_expr`)
|
||||
or a generic type (see :func:`generic_type`)
|
||||
|
||||
Returns:
|
||||
Type: the parsed base type expression
|
||||
"""
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
@@ -227,6 +241,17 @@ class MidasParser(Parser):
|
||||
return self.generic_type()
|
||||
|
||||
def generic_type(self) -> Type:
|
||||
"""Parse a generic type expression
|
||||
|
||||
A generic type consists of a named type (see :func:`named_type`),
|
||||
optionally followed by type arguments in brackets.
|
||||
|
||||
The special `Frame` type accepts a frame schema instead of type
|
||||
arguments (see :func:`frame_type`).
|
||||
|
||||
Returns:
|
||||
Type: the parsed generic type
|
||||
"""
|
||||
type: NamedType = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
if type.name.lexeme == "Frame":
|
||||
@@ -240,6 +265,13 @@ class MidasParser(Parser):
|
||||
return type
|
||||
|
||||
def type_args(self) -> list[Type]:
|
||||
"""Parse a list of type arguments
|
||||
|
||||
Type arguments are a comma-separated list of type expression wrapped in brackets.
|
||||
|
||||
Returns:
|
||||
list[Type]: the list of type arguments, if any, or an empty list
|
||||
"""
|
||||
args: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
@@ -250,6 +282,13 @@ class MidasParser(Parser):
|
||||
return args
|
||||
|
||||
def named_type(self) -> NamedType:
|
||||
"""Parse a named type expression
|
||||
|
||||
A named type is an identifier token
|
||||
|
||||
Returns:
|
||||
NamedType: the parsed named type expression
|
||||
"""
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
@@ -257,13 +296,13 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def complex_type(self) -> ComplexType:
|
||||
"""Parse a type definition body
|
||||
"""Parse a complex type expression
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
A complex type consists of zero or more member statements enclosed in
|
||||
curly braces
|
||||
|
||||
Returns:
|
||||
ComplexType: the parsed complex type
|
||||
ComplexType: the parsed complex type expression
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||
@@ -285,6 +324,20 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def frame_type(self) -> FrameType:
|
||||
"""Parse a frame type expression
|
||||
|
||||
A frame type consists of:
|
||||
- the `Frame` identifier
|
||||
- an opening bracket `[`
|
||||
- a list of comma-separated column expression consisting of:
|
||||
- a name (token)
|
||||
- a colon `:`
|
||||
- a type expression (see :func:`type_expr`)
|
||||
- a closing bracket `]`
|
||||
|
||||
Returns:
|
||||
FrameType: the parsed frame type
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
|
||||
|
||||
@@ -311,9 +364,9 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
"""Parse a constraint
|
||||
"""Parse a constraint expression
|
||||
|
||||
A constraint is basically a logical predicate
|
||||
A constraint is an expression (see :func:`expression`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
@@ -321,10 +374,20 @@ class MidasParser(Parser):
|
||||
return self.expression()
|
||||
|
||||
def expression(self) -> Expr:
|
||||
"""Parse an expression
|
||||
|
||||
An expression consists of a logical AND expression (see :func:`and_`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
"""Parse a logical AND expression or a simpler expression
|
||||
"""Parse a logical AND expression
|
||||
|
||||
An AND consists of one or more equality expressions (see :func:`equality`)
|
||||
separated by logical AND operators (`&`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -340,7 +403,10 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def equality(self) -> Expr:
|
||||
"""Parse a logical equality expression or a simpler expression
|
||||
"""Parse an equality expression
|
||||
|
||||
An equality consists of one or more comparison expressions (see :func:`comparison`)
|
||||
separated by equality operators (`==`, `!=`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -356,7 +422,10 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def comparison(self) -> Expr:
|
||||
"""Parse a logical comparison expression or a simpler expression
|
||||
"""Parse a comparison expression
|
||||
|
||||
A comparison consists of one or more term expressions (see :func:`term`)
|
||||
separated by comparison operators (`<`, `<=`, `>`, `>=`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -377,6 +446,14 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def term(self) -> Expr:
|
||||
"""Parse a term expression
|
||||
|
||||
A term consists of one or more factor expressions (see :func:`factor`)
|
||||
separated by weak arithmetic operators (`+`, `-`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.factor()
|
||||
while self.match(TokenType.PLUS, TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
@@ -388,6 +465,14 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def factor(self) -> Expr:
|
||||
"""Parse a factor expression
|
||||
|
||||
A factor consists of one or more unary expressions (see :func:`unary`)
|
||||
separated by strong arithmetic operators (`*`, `/`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
while self.match(TokenType.STAR, TokenType.SLASH):
|
||||
operator: Token = self.previous()
|
||||
@@ -399,12 +484,15 @@ class MidasParser(Parser):
|
||||
return expr
|
||||
|
||||
def unary(self) -> Expr:
|
||||
"""Parse a unary expression or a simpler expression
|
||||
"""Parse a unary expression
|
||||
|
||||
A unary consists of a call expression (see :func:`call`) optionally
|
||||
preceded by zero or more unary operators (`+`, `-`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
if self.match(TokenType.MINUS):
|
||||
if self.match(TokenType.PLUS, TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
@@ -412,12 +500,44 @@ class MidasParser(Parser):
|
||||
return self.call()
|
||||
|
||||
def call(self) -> Expr:
|
||||
"""Parse a call expression
|
||||
|
||||
A call consists of a reference expression (see :func:`reference`)
|
||||
optionally followed by zero or more argument groups.
|
||||
|
||||
Argument groups are parenthesize, comma-separated list of arguments (see :func:`finish_call`)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.reference()
|
||||
while self.match(TokenType.LEFT_PAREN):
|
||||
expr = self.finish_call(expr)
|
||||
return expr
|
||||
|
||||
def finish_call(self, callee: Expr) -> Expr:
|
||||
"""Parse an argument group, i.e. the arguments of a call
|
||||
|
||||
Arguments are either passed positionally or by name (keyword argument).
|
||||
All positional arguments must come before any keyword argument and
|
||||
vice-versa. Arguments are separated by commas.
|
||||
|
||||
A positional argument simply consists of an expression (see :func:`expression`)
|
||||
|
||||
A keyword argument consists of and identifier, followed by the equal `=`
|
||||
token and an expression (see :func:`expression`).
|
||||
|
||||
Args:
|
||||
callee (Expr): the callee expression
|
||||
|
||||
Raises:
|
||||
ParsingError: if a positional argument is passed after a keyword
|
||||
argument or if a keyword argument's name is invalid (i.e. not
|
||||
an identifier)
|
||||
|
||||
Returns:
|
||||
Expr: the parsed call expression
|
||||
"""
|
||||
pos_args: list[Expr] = []
|
||||
kw_args: dict[str, Expr] = {}
|
||||
keywords: bool = False
|
||||
@@ -437,13 +557,14 @@ class MidasParser(Parser):
|
||||
else:
|
||||
value = self.expression()
|
||||
if self.check(TokenType.EQUAL):
|
||||
error_msg: str
|
||||
if keywords:
|
||||
raise self.error(self.peek(), "Invalid keyword argument name")
|
||||
error_msg = "Invalid keyword argument name"
|
||||
else:
|
||||
raise self.error(
|
||||
self.peek(),
|
||||
"Cannot pass positional arguments after a keyword argument",
|
||||
error_msg = (
|
||||
"Cannot pass positional arguments after a keyword argument"
|
||||
)
|
||||
raise self.error(self.peek(), error_msg)
|
||||
pos_args.append(value)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
@@ -460,7 +581,12 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def reference(self) -> Expr:
|
||||
"""Parse an attribute access expression or a simpler expression
|
||||
"""Parse a reference expression
|
||||
|
||||
A reference consists of a primary expression (see :func:`primary`)
|
||||
optionally followed by zero or more attribute accesses.
|
||||
|
||||
An attribute access consists of a dot `.` token followed by an identifier
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -475,7 +601,12 @@ class MidasParser(Parser):
|
||||
def primary(self) -> Expr:
|
||||
"""Parse a primary expression
|
||||
|
||||
This includes literals (booleans, numbers, etc.), wildcards, identifiers and grouped expressions
|
||||
This includes literals (booleans, numbers, etc.), wildcards, identifiers
|
||||
and grouped expressions
|
||||
|
||||
Raises:
|
||||
ParsingError: if a primary expressions cannot be parsed from the
|
||||
following tokens
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
@@ -508,14 +639,41 @@ class MidasParser(Parser):
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||
"""Consume the current token if it is a valid identifier or raise an error (see :func:`check_identifier`)
|
||||
|
||||
If the current token is not a valid identifier, an error is raised
|
||||
with the provided message
|
||||
|
||||
Args:
|
||||
message (str, optional): the error message. Defaults to "Expected identifier".
|
||||
|
||||
Raises:
|
||||
ParsingError: if the current token is not a valid identifier
|
||||
|
||||
Returns:
|
||||
Token: the current token which is a valid identifier
|
||||
"""
|
||||
if not self.match_identifier():
|
||||
raise self.error(self.peek(), message)
|
||||
return self.previous()
|
||||
|
||||
def match_identifier(self) -> bool:
|
||||
"""Consume the next token if it is a valid identifier (see :func:`check_identifier`)
|
||||
|
||||
Returns:
|
||||
bool: whether a token was matched and consumed
|
||||
"""
|
||||
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||
|
||||
def check_identifier(self) -> bool:
|
||||
"""Check whether the current token is a valid identifier
|
||||
|
||||
A valid identifier is either an identifier token or a keyword token.
|
||||
This function always returns False if the parser is at the EOF token
|
||||
|
||||
Returns:
|
||||
bool: True if the current token is a valid identifier and not EOF
|
||||
"""
|
||||
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||
if self.check(tt):
|
||||
return True
|
||||
@@ -524,7 +682,14 @@ class MidasParser(Parser):
|
||||
def member_stmt(self) -> MemberStmt:
|
||||
"""Parse a member statement
|
||||
|
||||
A type member statement is written `prop name: Type` or `def name: Type`
|
||||
A member statement is written consists of:
|
||||
- the `prop` (for a property) or `def` (for a method) keyword
|
||||
- an name (identifier)
|
||||
- a colon `:`
|
||||
- a type expression (see :func:`type_expr`)
|
||||
|
||||
Raises:
|
||||
ParsingError: if the first token is neither `prop` nor `def`
|
||||
|
||||
Returns:
|
||||
MemberStmt: the parsed member statement
|
||||
@@ -551,7 +716,13 @@ class MidasParser(Parser):
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
"""Parse an extension definition
|
||||
|
||||
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||
An extension statement consists of:
|
||||
- the `extend` keyword
|
||||
- a type name (identifier)
|
||||
- (optional) type parameters (see :func:`type_params`)
|
||||
- an opening brace `{`
|
||||
- zero or more member statements (see :func:`member_stmt`)
|
||||
- a closing brace `}`
|
||||
|
||||
Returns:
|
||||
ExtendStmt: the parsed extension statement
|
||||
@@ -576,7 +747,12 @@ class MidasParser(Parser):
|
||||
def predicate_declaration(self) -> PredicateStmt:
|
||||
"""Parse a predicate declaration
|
||||
|
||||
A predicate is written `predicate Name(subject: Type) = constraint_expression`
|
||||
A predicate statement consists of:
|
||||
- the `predicate` keyword
|
||||
- a name (identifier)
|
||||
- (optional) zero or more parameter specs (see :func:`function_params`)
|
||||
- an equal sign `=`
|
||||
- a body, a constraint expression (see :func:`constraint`)
|
||||
|
||||
Returns:
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
@@ -587,7 +763,7 @@ class MidasParser(Parser):
|
||||
|
||||
params: list[ParamSpec] = []
|
||||
while self.check(TokenType.LEFT_PAREN):
|
||||
params.append(self.function_args())
|
||||
params.append(self.function_params())
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
body: Expr = self.constraint()
|
||||
@@ -599,7 +775,18 @@ class MidasParser(Parser):
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
params: ParamSpec = self.function_args()
|
||||
"""Parse a function type expression
|
||||
|
||||
A function consists of:
|
||||
- the `fn` keyword
|
||||
- a parameter spec (see :func:`function_params`)
|
||||
- the arrow keyword `->`
|
||||
- a result type expression (see :func:`type_expr`)
|
||||
|
||||
Returns:
|
||||
FunctionType: the parsed function type expression
|
||||
"""
|
||||
params: ParamSpec = self.function_params()
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
@@ -610,36 +797,53 @@ class MidasParser(Parser):
|
||||
returns=result,
|
||||
)
|
||||
|
||||
def function_args(self) -> ParamSpec:
|
||||
def function_params(self) -> ParamSpec:
|
||||
"""Parse a parameter spec
|
||||
|
||||
A parameter spec consists of zero or more comma-separated parameters,
|
||||
wrapped in parentheses.
|
||||
|
||||
Like in Python, it can contain positional-only, mixed and keyword-only
|
||||
parameters (separated by `/` and `*`).
|
||||
|
||||
Each parameter has a type (see :func:`type_expr`),
|
||||
preceded by a name (identifier) and a colon `:` (not required for
|
||||
positional-only parameters).
|
||||
|
||||
Returns:
|
||||
ParamSpec: the parsed parameter spec
|
||||
"""
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
pos_args: list[FunctionType.Argument] = []
|
||||
args: list[FunctionType.Argument] = []
|
||||
kw_args: list[FunctionType.Argument] = []
|
||||
pos: list[FunctionType.Parameter] = []
|
||||
mixed: list[FunctionType.Parameter] = []
|
||||
kw: list[FunctionType.Parameter] = []
|
||||
|
||||
args_first_tokens: list[Token] = []
|
||||
mixed_first_tokens: list[Token] = []
|
||||
|
||||
section: int = 0
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||
match section:
|
||||
case 0 if self.match(TokenType.SLASH):
|
||||
pos_args = args
|
||||
args = []
|
||||
args_first_tokens = []
|
||||
pos = mixed
|
||||
mixed = []
|
||||
mixed_first_tokens = []
|
||||
section = 1
|
||||
case 0 | 1 if self.match(TokenType.STAR):
|
||||
section = 2
|
||||
case _:
|
||||
# Record first token of mixed argument for errors if unnamed
|
||||
# Record first token of mixed parameters for errors if unnamed
|
||||
if section != 2:
|
||||
args_first_tokens.append(self.peek())
|
||||
mixed_first_tokens.append(self.peek())
|
||||
|
||||
name: Optional[Token] = None
|
||||
if section == 2:
|
||||
name = self.consume_identifier("Expected keyword argument name")
|
||||
name = self.consume_identifier(
|
||||
"Expected keyword parameter name"
|
||||
)
|
||||
self.consume(
|
||||
TokenType.COLON, "Expected ':' after argument name"
|
||||
TokenType.COLON, "Expected ':' after parameter name"
|
||||
)
|
||||
elif self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
@@ -647,24 +851,24 @@ class MidasParser(Parser):
|
||||
|
||||
type: Type = self.type_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
param = FunctionType.Parameter(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=not optional,
|
||||
)
|
||||
if section == 2:
|
||||
kw_args.append(arg)
|
||||
kw.append(param)
|
||||
else:
|
||||
args.append(arg)
|
||||
mixed.append(param)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
for arg, token in zip(args, args_first_tokens):
|
||||
if arg.name is None:
|
||||
for param, token in zip(mixed, mixed_first_tokens):
|
||||
if param.name is None:
|
||||
# Not raised because we can keep parsing
|
||||
self.error(token, "Unnamed mixed argument")
|
||||
self.error(token, "Unnamed mixed parameter")
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||
return ParamSpec(l_paren=l_paren, pos=pos, mixed=mixed, kw=kw)
|
||||
|
||||
@@ -23,6 +23,7 @@ from midas.ast.python import (
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ParamSpec,
|
||||
RawExpr,
|
||||
RawStmt,
|
||||
ReturnStmt,
|
||||
@@ -49,6 +50,8 @@ class UnsupportedSyntaxError(Exception):
|
||||
|
||||
|
||||
class PythonParser:
|
||||
"""A parser to convert raw Python `ast` nodes in custom IR nodes"""
|
||||
|
||||
CAST_FUNCTION = "cast"
|
||||
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||
|
||||
@@ -212,27 +215,10 @@ class PythonParser:
|
||||
match node:
|
||||
case ast.FunctionDef(
|
||||
name=name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=posonlyargs,
|
||||
args=args,
|
||||
vararg=sink,
|
||||
kwonlyargs=kwonlyargs,
|
||||
kwarg=kw_sink,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
),
|
||||
returns=returns,
|
||||
body=raw_body,
|
||||
):
|
||||
|
||||
def parse_args(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Argument]:
|
||||
return [
|
||||
self._parse_function_argument(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
body: list[Stmt] = []
|
||||
for stmt in raw_body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
@@ -241,54 +227,58 @@ class PythonParser:
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_posargs: int = len(posonlyargs)
|
||||
n_args: int = len(args)
|
||||
n_all_posargs = n_posargs + n_args
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_posargs - len(defaults)) + parsed_defaults
|
||||
|
||||
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
|
||||
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
|
||||
kwargs_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in kw_defaults
|
||||
]
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
posonlyargs=parse_args(posonlyargs, posargs_defaults),
|
||||
args=parse_args(args, args_defaults),
|
||||
sink=(
|
||||
self._parse_function_argument(sink, None)
|
||||
if sink is not None
|
||||
else None
|
||||
),
|
||||
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
|
||||
kw_sink=(
|
||||
self._parse_function_argument(kw_sink, None)
|
||||
if kw_sink is not None
|
||||
else None
|
||||
),
|
||||
params=self._parse_param_spec(args),
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
body=body,
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||
|
||||
def _parse_function_argument(
|
||||
def _parse_param_spec(self, args: ast.arguments) -> ParamSpec:
|
||||
def parse_params(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Parameter]:
|
||||
return [
|
||||
self._parse_function_parameter(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
defaults: list[ast.expr] = args.defaults
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_pos: int = len(args.posonlyargs)
|
||||
n_mixed: int = len(args.args)
|
||||
n_all_pos = n_pos + n_mixed
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_pos - len(defaults)) + parsed_defaults
|
||||
|
||||
pos_defaults: list[Optional[Expr]] = parsed_defaults[:n_pos]
|
||||
mixed_defaults: list[Optional[Expr]] = parsed_defaults[n_pos:]
|
||||
kw_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in args.kw_defaults
|
||||
]
|
||||
|
||||
return ParamSpec(
|
||||
pos=parse_params(args.posonlyargs, pos_defaults),
|
||||
mixed=parse_params(args.args, mixed_defaults),
|
||||
kw=parse_params(args.kwonlyargs, kw_defaults),
|
||||
)
|
||||
|
||||
def _parse_function_parameter(
|
||||
self, arg: ast.arg, default: Optional[Expr]
|
||||
) -> Function.Argument:
|
||||
) -> Function.Parameter:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
if arg.annotation is not None:
|
||||
type = self._parse_type(arg.annotation)
|
||||
return Function.Argument(
|
||||
return Function.Parameter(
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
|
||||
43
tests/__main__.py
Normal file
43
tests/__main__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Type
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
from tests.base import Tester
|
||||
from tests.checker import CheckerTester
|
||||
from tests.generator import GeneratorTester
|
||||
from tests.midas import MidasTester
|
||||
from tests.python import PythonTester
|
||||
|
||||
|
||||
def print_banner(name: str):
|
||||
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
|
||||
print(horizontal)
|
||||
print(f"| {name} |")
|
||||
print(horizontal)
|
||||
|
||||
|
||||
def run_tests(tester_cls: Type[Tester]) -> bool:
|
||||
print_banner(tester_cls.__name__)
|
||||
tester: Tester = tester_cls()
|
||||
success: bool = tester.run_all_tests()
|
||||
print()
|
||||
return success
|
||||
|
||||
|
||||
def main():
|
||||
testers: list[Type[Tester]] = [
|
||||
PythonTester,
|
||||
MidasTester,
|
||||
CheckerTester,
|
||||
GeneratorTester,
|
||||
]
|
||||
|
||||
success: bool = all(map(run_tests, testers))
|
||||
|
||||
if success:
|
||||
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
|
||||
else:
|
||||
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
from midas.cli.ansi import Ansi
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
@@ -44,8 +46,11 @@ class Tester(ABC):
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
path: Path = test.resolve().relative_to(self.CASES_DIR)
|
||||
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
|
||||
print(Ansi.DIM, end="")
|
||||
success: bool = self._run_test(test)
|
||||
print(Ansi.RESET, end="")
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
@@ -146,7 +151,8 @@ class Tester(ABC):
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case None:
|
||||
print("No subcommand provided. Available subcommands: run, update")
|
||||
success: bool = tester.run_all_tests()
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case _:
|
||||
print(f"Unknown subcommand '{args.subcommand}'")
|
||||
|
||||
@@ -124,7 +124,7 @@
|
||||
22
|
||||
]
|
||||
},
|
||||
"message": "Multiple values for argument 'b'"
|
||||
"message": "Multiple values for parameter 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
@@ -152,7 +152,7 @@
|
||||
12
|
||||
]
|
||||
},
|
||||
"message": "Unknown keyword argument 'a'"
|
||||
"message": "Unknown keyword parameter 'a'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
@@ -194,7 +194,7 @@
|
||||
17
|
||||
]
|
||||
},
|
||||
"message": "Unknown keyword argument 'g'"
|
||||
"message": "Unknown keyword parameter 'g'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
@@ -277,7 +277,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -287,7 +288,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -297,7 +298,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -306,7 +307,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -351,7 +353,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -361,7 +364,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -371,7 +374,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -380,7 +383,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -443,7 +447,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -453,7 +458,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -463,7 +468,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -472,7 +477,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -539,7 +545,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -549,7 +556,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -559,7 +566,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -568,7 +575,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -649,7 +657,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -659,7 +668,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -669,7 +678,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -678,7 +687,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -762,7 +772,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -772,7 +783,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -782,7 +793,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -791,7 +802,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -850,7 +862,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -860,7 +873,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -870,7 +883,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -879,7 +892,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -929,7 +943,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -939,7 +954,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -949,7 +964,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -958,7 +973,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -1034,7 +1050,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -1044,7 +1061,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -1054,7 +1071,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -1063,7 +1080,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -1150,7 +1168,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -1160,7 +1179,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -1170,7 +1189,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -1179,7 +1198,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -1266,7 +1286,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -1276,7 +1297,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -1286,7 +1307,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -1295,7 +1316,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -1382,7 +1404,8 @@
|
||||
"name": "foo"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -1392,7 +1415,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
@@ -1402,7 +1425,7 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [
|
||||
"kw": [
|
||||
{
|
||||
"pos": 2,
|
||||
"name": "c",
|
||||
@@ -1411,7 +1434,8 @@
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
|
||||
@@ -136,8 +136,9 @@
|
||||
"name": "maximum"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
@@ -155,7 +156,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
|
||||
@@ -312,7 +312,8 @@
|
||||
"name": "print"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
@@ -320,8 +321,9 @@
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -120,7 +120,8 @@
|
||||
"name": "bool"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "object",
|
||||
@@ -128,8 +129,9 @@
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -377,8 +379,9 @@
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
@@ -388,7 +391,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
@@ -439,12 +443,14 @@
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
@@ -456,8 +462,9 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
@@ -485,8 +492,9 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
@@ -548,8 +556,9 @@
|
||||
"name": "double"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
@@ -559,7 +568,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
@@ -610,12 +620,14 @@
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
@@ -627,8 +639,9 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
@@ -656,8 +669,9 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
@@ -709,8 +723,9 @@
|
||||
"name": "is_odd"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
"params": {
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "value",
|
||||
@@ -720,7 +735,8 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "bool"
|
||||
}
|
||||
@@ -771,12 +787,14 @@
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "transform",
|
||||
"type": {
|
||||
"pos_args": [
|
||||
"params": {
|
||||
"pos": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "v",
|
||||
@@ -788,8 +806,9 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "U",
|
||||
"bound": null,
|
||||
@@ -817,8 +836,9 @@
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"args": [],
|
||||
"kw_args": [],
|
||||
"mixed": [],
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"name": "list",
|
||||
"args": [
|
||||
|
||||
117
tests/cases/checker/09_frame_ops.py
Normal file
117
tests/cases/checker/09_frame_ops.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
df1: Frame[a:int, b:float]
|
||||
df2: Frame[a:int, b:float]
|
||||
|
||||
_: Any
|
||||
|
||||
# Arithmetic
|
||||
_ = df1 + df2
|
||||
_ = df1 - df2
|
||||
_ = df1 * df2
|
||||
_ = df1 / df2
|
||||
_ = df1 // df2
|
||||
_ = df1 % df2
|
||||
_ = df1**df2
|
||||
|
||||
# Comparisons
|
||||
_ = df1 < df2
|
||||
_ = df1 > df2
|
||||
_ = df1 <= df2
|
||||
_ = df1 >= df2
|
||||
_ = df1 != df2
|
||||
_ = df1 == df2
|
||||
|
||||
# Aggregate
|
||||
_ = df1.kurt()
|
||||
_ = df1.kurtosis()
|
||||
_ = df1.max()
|
||||
_ = df1.mean()
|
||||
_ = df1.median()
|
||||
_ = df1.min()
|
||||
_ = df1.mode()
|
||||
_ = df1.prod()
|
||||
_ = df1.product()
|
||||
_ = df1.std()
|
||||
_ = df1.sum()
|
||||
_ = df1.var()
|
||||
|
||||
# Groupby
|
||||
df_gb = df1.groupby(by="a")
|
||||
|
||||
_ = df_gb.kurt()
|
||||
_ = df_gb.max()
|
||||
_ = df_gb.mean()
|
||||
_ = df_gb.median()
|
||||
_ = df_gb.min()
|
||||
_ = df_gb.prod()
|
||||
_ = df_gb.std()
|
||||
_ = df_gb.sum()
|
||||
_ = df_gb.var()
|
||||
|
||||
|
||||
# Columns
|
||||
|
||||
col1 = df1["a"]
|
||||
col2 = df1["a"]
|
||||
|
||||
# Arithmetic
|
||||
_ = col1 + col2
|
||||
_ = col1 - col2
|
||||
_ = col1 * col2
|
||||
_ = col1 / col2
|
||||
_ = col1 // col2
|
||||
_ = col1 % col2
|
||||
_ = col1**col2
|
||||
|
||||
# Comparisons
|
||||
_ = col1 < col2
|
||||
_ = col1 > col2
|
||||
_ = col1 <= col2
|
||||
_ = col1 >= col2
|
||||
_ = col1 != col2
|
||||
_ = col1 == col2
|
||||
|
||||
# Aggregate
|
||||
_ = col1.kurt()
|
||||
_ = col1.kurtosis()
|
||||
_ = col1.max()
|
||||
_ = col1.mean()
|
||||
_ = col1.median()
|
||||
_ = col1.min()
|
||||
_ = col1.mode()
|
||||
_ = col1.prod()
|
||||
_ = col1.product()
|
||||
_ = col1.std()
|
||||
_ = col1.sum()
|
||||
_ = col1.var()
|
||||
|
||||
# Groupby
|
||||
col_gb = col1.groupby(level=0)
|
||||
|
||||
_ = col_gb.kurt()
|
||||
_ = col_gb.max()
|
||||
_ = col_gb.mean()
|
||||
_ = col_gb.median()
|
||||
_ = col_gb.min()
|
||||
_ = col_gb.prod()
|
||||
_ = col_gb.std()
|
||||
_ = col_gb.sum()
|
||||
_ = col_gb.var()
|
||||
|
||||
# Attributes
|
||||
_ = df1.ndim # int
|
||||
_ = df1.size # int
|
||||
_ = df1.shape # (int, int)
|
||||
_ = col1.ndim # int
|
||||
_ = col1.size # int
|
||||
_ = col1.shape # (int)
|
||||
_ = col1.T # Column[int]
|
||||
|
||||
|
||||
# Misc
|
||||
_ = df1.head()
|
||||
_ = df1.tail()
|
||||
_ = col1.head()
|
||||
_ = col1.tail()
|
||||
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
4924
tests/cases/checker/09_frame_ops.py.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,10 @@
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func",
|
||||
"posonlyargs": [],
|
||||
"args": [
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [],
|
||||
"mixed": [
|
||||
{
|
||||
"name": "col1",
|
||||
"type": {
|
||||
@@ -48,9 +50,8 @@
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [],
|
||||
"kw_sink": null,
|
||||
"kw": []
|
||||
},
|
||||
"returns": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
@@ -119,7 +120,9 @@
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func2",
|
||||
"posonlyargs": [
|
||||
"params": {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": {
|
||||
@@ -130,7 +133,7 @@
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
"mixed": [
|
||||
{
|
||||
"name": "b",
|
||||
"type": {
|
||||
@@ -141,8 +144,7 @@
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [
|
||||
"kw": [
|
||||
{
|
||||
"name": "c",
|
||||
"type": {
|
||||
@@ -152,8 +154,8 @@
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw_sink": null,
|
||||
]
|
||||
},
|
||||
"returns": null,
|
||||
"body": []
|
||||
}
|
||||
|
||||
@@ -188,16 +188,16 @@ class MidasAstJsonSerializer(
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||
"pos": [self._serialize_func_param(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_param(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_param(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
def _serialize_func_param(self, param: FunctionType.Parameter) -> dict:
|
||||
return {
|
||||
"name": arg.name.lexeme if arg.name is not None else None,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
"name": param.name.lexeme if param.name is not None else None,
|
||||
"type": param.type.accept(self),
|
||||
"required": param.required,
|
||||
}
|
||||
|
||||
def visit_extension_type(self, type: ExtensionType) -> dict:
|
||||
|
||||
@@ -22,6 +22,7 @@ from midas.ast.python import (
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ParamSpec,
|
||||
Pass,
|
||||
RawExpr,
|
||||
RawStmt,
|
||||
@@ -128,32 +129,30 @@ class PythonAstJsonSerializer(
|
||||
"expr": stmt.expr.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_argument(self, arg: Function.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": self._serialize_optional(arg.type),
|
||||
"default": self._serialize_optional(arg.default),
|
||||
}
|
||||
|
||||
def visit_function(self, stmt: Function) -> dict:
|
||||
return {
|
||||
"_type": "Function",
|
||||
"name": stmt.name,
|
||||
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
|
||||
"args": [self._serialize_argument(arg) for arg in stmt.args],
|
||||
"sink": (
|
||||
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
|
||||
),
|
||||
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
|
||||
"kw_sink": (
|
||||
self._serialize_argument(stmt.kw_sink)
|
||||
if stmt.kw_sink is not None
|
||||
else None
|
||||
),
|
||||
"params": self._serialize_param_spec(stmt.params),
|
||||
"returns": self._serialize_optional(stmt.returns),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
}
|
||||
|
||||
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||
return {
|
||||
"_type": "ParamSpec",
|
||||
"pos": [self._serialize_func_param(arg) for arg in spec.pos],
|
||||
"mixed": [self._serialize_func_param(arg) for arg in spec.mixed],
|
||||
"kw": [self._serialize_func_param(arg) for arg in spec.kw],
|
||||
}
|
||||
|
||||
def _serialize_func_param(self, param: Function.Parameter) -> dict:
|
||||
return {
|
||||
"name": param.name,
|
||||
"type": self._serialize_optional(param.type),
|
||||
"default": self._serialize_optional(param.default),
|
||||
}
|
||||
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> dict:
|
||||
return {
|
||||
"_type": "TypeAssign",
|
||||
|
||||
Reference in New Issue
Block a user