Compare commits

...

28 Commits

Author SHA1 Message Date
9764484fd9 docs: add docstrings to midas parser 2026-07-04 01:30:14 +02:00
5b9e322c91 docs: add some docstrings in lexer classes 2026-07-03 22:41:21 +02:00
c18d9c18de tests: update with new parameter spec 2026-07-03 19:31:17 +02:00
9229f00375 refactor: rebrand function parameters and unify spec
rename function arguments to parameters where it was wrong, and add ParamSpec for Python AST, like for Midas
2026-07-03 19:24:30 +02:00
6b7a682dc5 docs: add some docstrings 2026-07-03 17:36:45 +02:00
35b97fd17b refactor(ast): restructure printers 2026-07-03 17:26:28 +02:00
03bc32400b Merge pull request 'Frame / columns in manual' (#28) from feat/frame-columns-in-manual into main
Reviewed-on: #28
2026-07-03 14:38:44 +00:00
4a93ee45d9 docs: add section about Frame type annotations 2026-07-03 16:32:35 +02:00
8197131d8d docs: add Column and Frame to manual 2026-07-03 13:31:56 +02:00
cf91187b7a fix(checker): remove bool as subtype of int 2026-07-03 12:56:47 +02:00
1b2bdf0b79 docs: add alias statements to manual 2026-07-03 12:56:20 +02:00
c6cc38bfeb Merge pull request 'Frame / column operations' (#27) from feat/simple-frame-ops into main
Reviewed-on: #27
2026-07-03 10:29:32 +00:00
4d3e3f44a1 fix(checker): correctly check length of frame/column 2026-07-03 12:28:39 +02:00
ec80b1e92e feat(checker): add head/tail methods 2026-07-03 12:13:30 +02:00
4ea15519f3 feate(checker): add some frame/column attributes 2026-07-03 12:07:36 +02:00
7a6e01cff8 fix(checker): delegate frame aggregate methods to columns 2026-07-03 11:42:35 +02:00
733c8736b8 feat(checker): add aggregation ops on column groupby 2026-07-03 11:25:06 +02:00
20173a0b07 feat(tests): add colors and run all tests in base module 2026-07-03 10:58:28 +02:00
a143972ef1 feat(checker): add aggregation ops on frame groupby 2026-07-03 02:20:51 +02:00
0c70048b62 feat(checker): add statistical ops on columns 2026-07-03 01:34:58 +02:00
1c0c917873 feat(checker): add statistical ops on frames 2026-07-03 01:27:16 +02:00
1f6189daa4 feat(checker): add comparison binary ops on columns 2026-07-03 01:05:24 +02:00
66b585c3d6 fix(checker): recursively check builtin subtypes 2026-07-03 01:04:45 +02:00
819ab3c2bf tests: add dataframe operations test 2026-07-03 00:58:29 +02:00
d8c0b17512 feat(checker): add comparison binary ops on frames 2026-07-03 00:57:27 +02:00
6e06f9078e fix(checker): improve unknown method message 2026-07-03 00:57:10 +02:00
ece2e3a6a3 feat(checker): add arithmetic binary ops on columns 2026-07-03 00:42:00 +02:00
74c07c9afb feat(checker): add arithmetic binary ops on frames 2026-07-03 00:38:56 +02:00
56 changed files with 8367 additions and 2281 deletions

@@ -198,10 +198,26 @@ python3 build/midas/script.py
In this chapter, you will find a complete reference for the Midas definition language. In this chapter, you will find a complete reference for the Midas definition language.
A `*.midas` file contains a number of statements, which can be: A `*.midas` file contains a number of statements, which can be:
- *`alias`* statements (see @alias-stmt): to define a new type alias
- *`type`* statements (see @type-stmt): to define a new type - *`type`* statements (see @type-stmt): to define a new type
- *`extend`* statements (see @extend-stmt): to define member of a type - *`extend`* statements (see @extend-stmt): to define member of a type
- *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types - *`predicate`* statements (see @predicate-stmt): to define named predicates that can be used in constraint types
== Alias Statement <alias-stmt>
An *`alias`* statement lets you define a new type alias. It requires a unique name and base type.
While a `type` statement (see @type-stmt) allows generic definitions, aliases are purely for giving an alternative name to a type.
#figure(
```midas
alias MyType = float
```,
caption: [Simple `alias` statement declaring a new type "`MyType`" equivalent to `float`],
) <midas-simple-alias>
This statement defines a new type called `MyType` which is equivalent to `float`. `MyType` and `float` can be used interchangeably.
== Type Statement <type-stmt> == Type Statement <type-stmt>
A *`type`* statement lets you define a new type. It requires a unique name and base type. A *`type`* statement lets you define a new type. It requires a unique name and base type.
@@ -212,7 +228,7 @@ The simplest form of a *`type`* statement is:
type MyType = float type MyType = float
```, ```,
caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`], caption: [Simple `type` statement declaring a new type "`MyType`" as a subtype of `float`],
) <midas-simple-alias> ) <midas-simple-type>
This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`. This statement defines a new type called `MyType` which is a subtype of `float`. `MyType` is a `float` but a `float` is not necessarily `MyType`.
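The difference between the two statements described in this hunk (an `alias` is interchangeable with its base type, a `type` is a subtype of it) has a rough analogue in plain Python typing. The sketch below is only an analogy and does not show how Midas implements either statement; the names `AliasFloat`, `Celsius` and `describe` are hypothetical.

```python
from typing import NewType

# Rough analogue of `alias MyType = float`: another name for the same type.
AliasFloat = float

# Rough analogue of `type MyType = float`: a distinct name that static
# checkers treat as a subtype of float.
Celsius = NewType("Celsius", float)


def describe(value: float) -> str:
    """Accepts any float, hence also a Celsius value."""
    return f"{value:.1f}"


reading = Celsius(21.5)
print(AliasFloat is float)  # the alias is the very same object as float
print(describe(reading))    # a Celsius is accepted wherever a float is
```

In this analogy a static checker would reject passing a bare `float` where `Celsius` is expected, which mirrors "a `float` is not necessarily `MyType`".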
@@ -291,8 +307,7 @@ To better refine a generic type, you can also bound type parameters using the fo
caption: [Generic container type definition with a bound], caption: [Generic container type definition with a bound],
) )
This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`". This can be read as "`Container` is a generic type which takes one type parameter `T` that must be a subtype of `float`".\
You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete type as arguments: You can use a generic type, i.e. instantiate it, by using a similar syntax with concrete type as arguments:
#figure( #figure(
@@ -318,6 +333,46 @@ The _body_ of a generic type, i.e. the right-hand side of the definition, can co
caption: [Type parameters in a generic type's body], caption: [Type parameters in a generic type's body],
) )
=== `Column` / `Frame` types
To provide useful type-checking for data engineers, Midas offers two special types: `Column` and `Frame`.
Their goal is to help type check Pandas' `Series` and `DataFrame` respectively.
==== `Column`
The `Column` type is a generic type used to represent a `pandas.Series` object.
You can use it like any other generic type and it will provide type checking for some common methods and attributes offered by Pandas.
#figure(
```midas
type Temperature = float
alias Temperatures = Column[Temperature]
```,
caption: [Simple column type definition],
)
==== `Frame` <frame-type>
The `Frame` type is a super-powered generic type used to represent a `pandas.DataFrame` object.
In place of type arguments, `Frame` accepts a schema, i.e. a series of column definitions.
@simple-frame shows how you can define a simple frame type with 3 columns:
- `name`: a column of `Name` values
- `age`: a column of `int` values
- `height`: a column of `float where _ >= 0` values
Notice that you don't need to specify `Column` types.
#figure(
```midas
type Name = str where len(_) != 0
alias Data = Frame[
name: Name,
age: int,
height: float where _ >= 0
]
```,
) <simple-frame>
#pagebreak() #pagebreak()
== Extend Statement <extend-stmt> == Extend Statement <extend-stmt>
@@ -503,6 +558,7 @@ A simple annotation declaration, without assigning a value, is enough to declare
) )
Because unpacking is not supported, assigning to multiple values is also not handled by the type checker. Because unpacking is not supported, assigning to multiple values is also not handled by the type checker.
For more information about type annotations, see @type-annotations
== Arithmetic == Arithmetic
@@ -578,7 +634,7 @@ Conditional statements are checked relatively strictly by Midas. The test expres
Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored. Simple forms of `for` loops can be used, that is using a single variable and iterating over an object implementing the `__getitem__` method. Like above in @if-else, leaking variables from inside the loop is ignored.
The `for`-`else` statements are not supported. `while` loops are also not not supported. `for`-`else` statements are not supported. `while` loops are also not supported.
== Functions == Functions
@@ -686,6 +742,35 @@ There may be some cases where the cost of checking a value at runtime is simply
If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated. If the value passed to `cast` or `unsafe_cast` is a literal (e.g. an integer, a string, a list of literals, etc.), the assertion is evaluated _at compile-time_ and no runtime assertion is generated.
== Annotations / Type Hints <type-annotations>
Vanilla Python already lets you use type hints to specify the type of variables and function parameters.
Midas uses them to type check your code. Additionally, it allows you to use a special syntax to define `Frame` types directly in these annotations.
Because these annotations are not interpretable by Python, your integrated type checker might complain loudly about them being invalid.
A workaround is to silence it by adding a type comment at the end of the line, as shown in @silence-errors.
#figure(
```python
var: Frame[name: str, age: float] # type: ignore # noqa: F821
```,
caption: [MyPy's and Pylance's complaints about custom type annotations can be silenced with type comments],
) <silence-errors>
=== Frame type annotation
The syntax is similar to how you can define frame types in the Midas language (see @frame-type). The only difference is that types can only be name references; you cannot inline constraint types.
The example of @python-frame-type shows how you can annotate a dataframe with some columns directly in Python.
#figure(
```python
df: Frame[name: Name, age: float, height: Length[Meter]] = ...
```,
caption: [Frame type annotation in Python],
) <python-frame-type>
= Commands <commands> = Commands <commands>
#TODO #TODO
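The `Frame[...]` annotation syntax documented in the manual changes above is rejected by type checkers, but Python's own parser accepts it: each `column: type` pair parses as a slice. The sketch below uses only the standard `ast` module (Python 3.9+) to show one way a tool could read such an annotation. It is an illustration and not necessarily how Midas parses annotations; the helper name `frame_columns` is hypothetical.

```python
import ast


def frame_columns(source: str) -> dict[str, str]:
    """Extract {column: type} pairs from an annotated assignment."""
    stmt = ast.parse(source).body[0]
    if not isinstance(stmt, ast.AnnAssign):
        raise ValueError("expected an annotated assignment")
    annotation = stmt.annotation
    if not isinstance(annotation, ast.Subscript):
        raise ValueError("expected a subscripted annotation")

    # Several columns parse as a tuple of slices, a single column as one slice.
    sl = annotation.slice
    items = sl.elts if isinstance(sl, ast.Tuple) else [sl]

    columns: dict[str, str] = {}
    for item in items:
        if not isinstance(item, ast.Slice) or item.lower is None or item.upper is None:
            raise ValueError("expected `name: type` entries")
        # `name: Type` is a slice with lower=name and upper=Type.
        columns[ast.unparse(item.lower)] = ast.unparse(item.upper)
    return columns


print(frame_columns("df: Frame[name: Name, age: float, height: Length[Meter]] = ..."))
```

Because the annotation is read as an AST and never evaluated, names such as `Frame` or `Name` do not need to exist at runtime for this to work.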

@@ -37,6 +37,9 @@ contexts:
pop: true pop: true
keywords: keywords:
- match: \balias\b
scope: keyword.declaration.midas
push: alias-stmt
- match: \btype\b - match: \btype\b
scope: keyword.declaration.midas scope: keyword.declaration.midas
push: type-stmt push: type-stmt
@@ -47,6 +50,15 @@ contexts:
scope: keyword.declaration.midas scope: keyword.declaration.midas
push: predicate-stmt push: predicate-stmt
alias-stmt:
- match: "{{identifier}}"
scope: entity.name.type
- match: "="
scope: keyword.operator.equal.midas
push: type-expr
- match: $
pop: true
type-stmt: type-stmt:
- match: "{{identifier}}" - match: "{{identifier}}"
scope: entity.name.type scope: entity.name.type
@@ -67,6 +79,13 @@ contexts:
- match: \b(where)\b - match: \b(where)\b
scope: keyword.other.midas scope: keyword.other.midas
set: constraint set: constraint
- match: "Frame"
scope: entity.name.type
push:
- match: \[
push: frame-schema
- match: $
pop: true
- match: "{{identifier}}" - match: "{{identifier}}"
scope: entity.name.type scope: entity.name.type
- match: $ - match: $
@@ -178,3 +197,15 @@ contexts:
- match: '\)' - match: '\)'
pop: true pop: true
frame-schema:
- include: frame-column
- match: \]
# scope: punctuation.section.block.end
pop: true
frame-column:
- match: "{{identifier}}"
scope: variable.other.member
- match: ":"
push: type-expr

@@ -1,3 +1,9 @@
"""
Helper script to generate AST nodes for Midas and Python.
Takes in simple templates and generates full dataclasses and a visitor interface
"""
import re import re
from pathlib import Path from pathlib import Path

@@ -29,9 +29,9 @@ class MemberKind(Enum):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class ParamSpec: class ParamSpec:
l_paren: Token l_paren: Token
pos: list[FunctionType.Argument] pos: list[FunctionType.Parameter]
mixed: list[FunctionType.Argument] mixed: list[FunctionType.Parameter]
kw: list[FunctionType.Argument] kw: list[FunctionType.Parameter]
###< ###<
@@ -150,7 +150,7 @@ class FunctionType:
returns: Type returns: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Argument: class Parameter:
location: Optional[Location] = None location: Optional[Location] = None
name: Optional[Token] name: Optional[Token]
type: Type type: Type

@@ -12,6 +12,21 @@ from midas.ast.location import Location
###< ###<
###> Preamble
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
pos: list[Function.Parameter]
mixed: list[Function.Parameter]
kw: list[Function.Parameter]
@property
def all(self) -> list[Function.Parameter]:
return self.pos + self.mixed + self.kw
###<
###> MidasType | Type annotations | node ###> MidasType | Type annotations | node
class BaseType: class BaseType:
base: str base: str
@@ -42,25 +57,17 @@ class ExpressionStmt:
class Function: class Function:
name: str name: str
posonlyargs: list[Argument] params: ParamSpec
args: list[Argument]
sink: Optional[Argument]
kwonlyargs: list[Argument]
kw_sink: Optional[Argument]
returns: Optional[MidasType] returns: Optional[MidasType]
body: list[Stmt] body: list[Stmt]
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Argument: class Parameter:
location: Optional[Location] = None location: Optional[Location] = None
name: str name: str
type: Optional[MidasType] type: Optional[MidasType]
default: Optional[Expr] default: Optional[Expr]
@property
def all_args(self) -> list[Argument]:
return self.posonlyargs + self.args + self.kwonlyargs
class TypeAssign: class TypeAssign:
name: str name: str

@@ -13,6 +13,8 @@ class HasLocation(Protocol):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Location: class Location:
"""Information about the location of an AST node"""
lineno: int lineno: int
col_offset: int col_offset: int
end_lineno: Optional[int] end_lineno: Optional[int]
@@ -29,6 +31,16 @@ class Location:
@staticmethod @staticmethod
def span(start: Location, end: Location) -> Location: def span(start: Location, end: Location) -> Location:
"""Create a new location spanning from one location to another
Args:
start (Location): the starting location
end (Location): the end location
Returns:
Location: a new location spanning from the start of `start`
to the end of `end`
"""
return Location( return Location(
lineno=start.lineno, lineno=start.lineno,
col_offset=start.col_offset, col_offset=start.col_offset,

@@ -30,9 +30,9 @@ class MemberKind(Enum):
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class ParamSpec: class ParamSpec:
l_paren: Token l_paren: Token
pos: list[FunctionType.Argument] pos: list[FunctionType.Parameter]
mixed: list[FunctionType.Argument] mixed: list[FunctionType.Parameter]
kw: list[FunctionType.Argument] kw: list[FunctionType.Parameter]
############## ##############
@@ -318,7 +318,7 @@ class FunctionType(Type):
returns: Type returns: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Argument: class Parameter:
location: Optional[Location] = None location: Optional[Location] = None
name: Optional[Token] name: Optional[Token]
type: Type type: Type

@@ -1,896 +0,0 @@
from __future__ import annotations
import ast
import io
from contextlib import contextmanager
from enum import Enum, auto
from typing import Generator, Generic, Optional, Protocol, TypeVar
import midas.ast.midas as m
import midas.ast.python as p
class _Level(Enum):
EMPTY = auto()
ACTIVE = auto()
LAST = auto()
class Expr(Protocol):
def accept(self, printer: AstPrinter) -> None: ...
T = TypeVar("T", bound=Expr)
class AstPrinter(Generic[T]):
LAST_CHILD = "└── "
CHILD = "├── "
VERTICAL = ""
EMPTY = " "
def __init__(self):
self._levels: list[_Level] = []
self._idx: Optional[int] = None
self._buf: io.StringIO = io.StringIO()
def print(self, expr: T):
self._buf = io.StringIO()
expr.accept(self)
return self._buf.getvalue()
@contextmanager
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
try:
yield
finally:
self._levels.pop()
def _mark_last(self):
if self._levels:
self._levels[-1] = _Level.LAST
def _write_line(self, text: str, *, last: bool = False):
if last:
self._mark_last()
indent: str = self._build_indent()
if self._idx is not None:
text = f"[{self._idx}] {text}"
self._idx = None
self._buf.write(indent + text + "\n")
def _build_indent(self) -> str:
parts: list[str] = []
for level in self._levels[:-1]:
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
if self._levels:
if self._levels[-1] == _Level.LAST:
parts.append(self.LAST_CHILD)
self._levels[-1] = _Level.EMPTY
else:
parts.append(self.CHILD)
return "".join(parts)
def _write_optional_child(
self, label: str, child: Optional[T], *, last: bool = False
):
if last:
self._mark_last()
if child is None:
self._write_line(f"{label}: None")
else:
self._write_line(label)
with self._child_level(single=True):
child.accept(self)
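Review note: the removed `AstPrinter` base class draws a tree by tracking a stack of levels and choosing a connector for each line. The same output style can be sketched with a plain recursive function, which may be easier to follow. This is a different, simpler technique than the stateful one in the class above; the `Node` type is hypothetical, and the `"│   "` continuation string is an assumption (the diff shows `VERTICAL = ""`).

```python
from __future__ import annotations

from dataclasses import dataclass, field

LAST_CHILD = "└── "
CHILD = "├── "
VERTICAL = "│   "  # assumed continuation marker under a non-last child
EMPTY = "    "     # padding under a last child


@dataclass
class Node:
    label: str
    children: list[Node] = field(default_factory=list)


def render(node: Node, prefix: str = "", connector: str = "") -> list[str]:
    """Return the lines of the tree rooted at `node`."""
    lines = [prefix + connector + node.label]
    # Children of the root get no extra padding; below a last child we pad
    # with blanks, below any other child we keep the vertical bar going.
    if connector == "":
        child_prefix = prefix
    elif connector == LAST_CHILD:
        child_prefix = prefix + EMPTY
    else:
        child_prefix = prefix + VERTICAL
    for i, child in enumerate(node.children):
        is_last = i == len(node.children) - 1
        lines += render(child, child_prefix, LAST_CHILD if is_last else CHILD)
    return lines


tree = Node(
    "BinaryExpr",
    [
        Node("left", [Node("LiteralExpr")]),
        Node("operator: +"),
        Node("right", [Node("LiteralExpr")]),
    ],
)
print("\n".join(render(tree)))
```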
class MidasAstPrinter(
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
):
# Statements
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("TypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("params")
with self._child_level():
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._print_type_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
self._write_line("AliasStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_member_stmt(self, stmt: m.MemberStmt):
self._write_line("MemberStmt")
with self._child_level():
self._write_line(f"kind: {stmt.kind.name}")
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
self._write_line("params")
with self._child_level():
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._print_type_param(param)
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("params")
with self._child_level():
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._print_type_param(param)
self._write_line("members", last=True)
with self._child_level():
for i, member in enumerate(stmt.members):
self._idx = i
if i == len(stmt.members) - 1:
self._mark_last()
member.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("params")
with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.body.accept(self)
# Expressions
def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_binary_expr(self, expr: m.BinaryExpr):
self._write_line("BinaryExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: m.UnaryExpr):
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_variable_expr(self, expr: m.VariableExpr):
self._write_line("VariableExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_grouping_expr(self, expr: m.GroupingExpr):
self._write_line("GroupingExpr")
with self._child_level():
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level():
self._write_line(f"value: {expr.value}", last=True)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("NamedType")
with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level():
type.type.accept(self)
self._write_line("args", last=True)
with self._child_level():
for i, param in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_line("members", last=True)
with self._child_level():
for i, member in enumerate(type.members):
self._idx = i
if i == len(type.members) - 1:
self._mark_last()
member.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
self._write_line("ExtensionType")
with self._child_level():
self._write_line("base")
with self._child_level(single=True):
type.base.accept(self)
self._write_line("extension", last=True)
with self._child_level(single=True):
type.extension.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
name: str = "None"
if arg.name is not None:
name = f'"{arg.name.lexeme}"'
self._write_line(f"name: {name}")
self._write_line("type")
with self._child_level(single=True):
arg.type.accept(self)
self._write_line(f"required: {arg.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_line("columns")
with self._child_level():
for i, column in enumerate(type.columns):
self._idx = i
if i == len(type.columns) - 1:
self._mark_last()
self._print_frame_column(column)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
self._write_line("Column")
with self._child_level():
self._write_line(f'name: "{column.name.lexeme}"')
self._write_line("type")
with self._child_level(single=True):
column.type.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0
return expr.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
return res
def visit_member_stmt(self, stmt: m.MemberStmt):
keyword: str = {
m.MemberKind.PROPERTY: "prop",
m.MemberKind.METHOD: "def",
}.get(stmt.kind, "")
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for member in stmt.members:
res += member.accept(self) + "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{left} {operator} {right}"
def visit_binary_expr(self, expr: m.BinaryExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{left} {operator} {right}"
def visit_unary_expr(self, expr: m.UnaryExpr):
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
return f"{expr_}.{name}"
def visit_variable_expr(self, expr: m.VariableExpr):
return expr.name.lexeme
def visit_grouping_expr(self, expr: m.GroupingExpr):
expr_: str = expr.expr.accept(self)
return f"({expr_})"
def visit_literal_expr(self, expr: m.LiteralExpr):
return str(expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_"
def visit_named_type(self, type: m.NamedType) -> str:
return type.name.lexeme
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
if len(type.args) != 0:
args: list[str] = [param.accept(self) for param in type.args]
res += f"[{', '.join(args)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for member in type.members:
res += member.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_extension_type(self, type: m.ExtensionType) -> str:
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
args: list[str] = pos_args
if len(pos_args) != 0:
args.append("/")
args += mixed_args
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"({', '.join(args)})"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""
if arg.name is not None:
res += arg.name.lexeme
res += ": "
res += arg.type.accept(self)
if not arg.required:
res += "?"
return res
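Review note: `_visit_param_spec` above renders a parameter list the way Python signatures do, with `/` closing the positional-only group and `*` opening the keyword-only group. A standalone sketch of that rule, working on already-formatted strings, is given below; the function name `format_params` is hypothetical.

```python
def format_params(pos: list[str], mixed: list[str], kw: list[str]) -> str:
    """Join three parameter groups using Python-style `/` and `*` markers."""
    parts: list[str] = list(pos)  # copy so the caller's list is not mutated
    if pos:
        parts.append("/")  # everything before `/` is positional-only
    parts += mixed
    if kw:
        parts.append("*")  # everything after `*` is keyword-only
    parts += kw
    return f"({', '.join(parts)})"


print(format_params(["a: int"], ["b: str"], ["c: float?"]))
print(format_params([], ["x: int"], []))
```

The explicit copy is the one deliberate difference from the method above, which appends to `pos_args` directly; that is harmless there because the list is freshly built by a comprehension.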
def visit_frame_type(self, type: m.FrameType) -> str:
res: str = self.indented("Frame[")
if len(type.columns) != 0:
res += "\n"
self.level += 1
columns: list[str] = []
for column in type.columns:
columns.append(self.indented(self._print_frame_column(column)))
res += ",\n".join(columns)
self.level -= 1
res += "\n"
res += "]"
return res
def _print_frame_column(self, column: m.FrameType.Column) -> str:
return f"{column.name.lexeme}: {column.type.accept(self)}"
class PythonAstPrinter(
AstPrinter,
p.MidasType.Visitor[None],
p.Stmt.Visitor[None],
p.Expr.Visitor[None],
):
def visit_base_type(self, node: p.BaseType) -> None:
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_line("args:", last=True)
with self._child_level():
for i, arg in enumerate(node.args):
self._idx = i
if i == len(node.args) - 1:
self._mark_last()
arg.accept(self)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
node.type.accept(self)
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
def visit_frame_column(self, node: p.FrameColumn) -> None:
self._write_line("FrameColumn")
with self._child_level():
self._write_line(f"name: {node.name}")
self._write_optional_child("type", node.type, last=True)
def visit_frame_type(self, node: p.FrameType) -> None:
self._write_line("FrameType")
with self._child_level():
self._write_line("columns", last=True)
with self._child_level():
for i, col in enumerate(node.columns):
self._idx = i
if i == len(node.columns) - 1:
self._mark_last()
col.accept(self)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
stmt.expr.accept(self)
def visit_function(self, stmt: p.Function) -> None:
self._write_line("Function")
with self._child_level():
self._write_line(f"name: {stmt.name}")
self._write_line("posonlyargs")
with self._child_level():
for i, arg in enumerate(stmt.posonlyargs):
self._idx = i
if i == len(stmt.posonlyargs) - 1:
self._mark_last()
self._print_argument(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(stmt.args):
self._idx = i
if i == len(stmt.args) - 1:
self._mark_last()
self._print_argument(arg)
self._write_line("kwonlyargs")
with self._child_level():
for i, arg in enumerate(stmt.kwonlyargs):
self._idx = i
if i == len(stmt.kwonlyargs) - 1:
self._mark_last()
self._print_argument(arg)
self._write_optional_child("returns", stmt.returns)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument")
with self._child_level():
self._write_line(f"name: {arg.name}")
self._write_optional_child("type", arg.type, last=True)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
self._write_line("TypeAssign")
with self._child_level():
self._write_line(f"name: {stmt.name}")
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self._write_line("AssignStmt")
with self._child_level():
self._write_line("targets")
with self._child_level():
for i, target in enumerate(stmt.targets):
self._idx = i
if i == len(stmt.targets) - 1:
self._mark_last()
target.accept(self)
self._write_line("value", last=True)
with self._child_level(single=True):
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self._write_line("ReturnStmt")
with self._child_level():
self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
stmt.test.accept(self)
self._write_line("body")
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
self._write_line("orelse", last=True)
with self._child_level():
for i, else_stmt in enumerate(stmt.orelse):
self._idx = i
if i == len(stmt.orelse) - 1:
self._mark_last()
else_stmt.accept(self)
def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass")
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self._write_line("ForStmt")
with self._child_level():
self._write_line("target")
with self._child_level(single=True):
stmt.target.accept(self)
self._write_line("iterator")
with self._child_level(single=True):
stmt.iterator.accept(self)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
self._write_line("RawStmt")
with self._child_level(single=True):
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
self._write_line("CompareExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: p.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None:
self._write_line("GetExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line(f"name: {expr.name}", last=True)
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level(single=True):
self._write_line(f"value: {expr.value!r}")
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
self._write_line("VariableExpr")
with self._child_level(single=True):
self._write_line(f"name: {expr.name}")
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
self._write_line("LogicalExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
expr.test.accept(self)
self._write_line("if_true")
with self._child_level(single=True):
expr.if_true.accept(self)
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)
def visit_list_expr(self, expr: p.ListExpr) -> None:
self._write_line("ListExpr")
with self._child_level():
self._write_line("items", last=True)
with self._child_level():
for i, item in enumerate(expr.items):
self._idx = i
if i == len(expr.items) - 1:
self._mark_last()
item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
self._write_line("DictExpr")
with self._child_level():
self._write_line("keys")
with self._child_level():
for i, key in enumerate(expr.keys):
self._idx = i
if i == len(expr.keys) - 1:
self._mark_last()
if key is None:
self._write_line("None")
else:
key.accept(self)
self._write_line("values", last=True)
with self._child_level():
for i, value in enumerate(expr.values):
self._idx = i
if i == len(expr.values) - 1:
self._mark_last()
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self._write_line("SubscriptExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line("index", last=True)
with self._child_level(single=True):
expr.index.accept(self)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
self._write_line("SliceExpr")
with self._child_level():
self._write_optional_child("lower", expr.lower)
self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
self._write_line("TupleExpr")
with self._child_level():
self._write_line("items", last=True)
with self._child_level():
for i, item in enumerate(expr.items):
self._idx = i
if i == len(expr.items) - 1:
self._mark_last()
item.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):
self._write_line(f"expr: {ast.unparse(expr.expr)}")


@@ -0,0 +1,3 @@
from .midas import MidasPrinter as MidasPrinter
from .midas_ast import MidasAstPrinter as MidasAstPrinter
from .python_ast import PythonAstPrinter as PythonAstPrinter

midas/ast/printer/base.py (new file, 103 lines)

@@ -0,0 +1,103 @@
from __future__ import annotations

import io
from contextlib import contextmanager
from enum import Enum, auto
from typing import Callable, Generator, Generic, Optional, Protocol, Sequence, TypeVar


class _Level(Enum):
    EMPTY = auto()
    ACTIVE = auto()
    LAST = auto()


class Expr(Protocol):
    def accept(self, printer: AstPrinter) -> None: ...


T = TypeVar("T", bound=Expr)


class AstPrinter(Generic[T]):
    LAST_CHILD = "└── "
    CHILD = "├── "
    VERTICAL = "│   "
    EMPTY = "    "

    def __init__(self):
        self._levels: list[_Level] = []
        self._idx: Optional[int] = None
        self._buf: io.StringIO = io.StringIO()

    def print(self, expr: T):
        self._buf = io.StringIO()
        expr.accept(self)
        return self._buf.getvalue()

    @contextmanager
    def _child_level(self, single: bool = False) -> Generator[None, None, None]:
        self._levels.append(_Level.LAST if single else _Level.ACTIVE)
        try:
            yield
        finally:
            self._levels.pop()

    def _mark_last(self):
        if self._levels:
            self._levels[-1] = _Level.LAST

    def _write_line(self, text: str, *, last: bool = False):
        if last:
            self._mark_last()
        indent: str = self._build_indent()
        if self._idx is not None:
            text = f"[{self._idx}] {text}"
            self._idx = None
        self._buf.write(indent + text + "\n")

    def _build_indent(self) -> str:
        parts: list[str] = []
        for level in self._levels[:-1]:
            parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
        if self._levels:
            if self._levels[-1] == _Level.LAST:
                parts.append(self.LAST_CHILD)
                self._levels[-1] = _Level.EMPTY
            else:
                parts.append(self.CHILD)
        return "".join(parts)

    def _write_optional_child(
        self, label: str, child: Optional[T], *, last: bool = False
    ):
        if last:
            self._mark_last()
        if child is None:
            self._write_line(f"{label}: None")
        else:
            self._write_line(label)
            with self._child_level(single=True):
                child.accept(self)

    def _write_sequence(
        self,
        label: str,
        list_: Sequence[T],
        *,
        last: bool = False,
        print_func: Optional[Callable[[T], None]] = None,
    ):
        if last:
            self._mark_last()
        self._write_line(label)
        with self._child_level():
            for i, item in enumerate(list_):
                self._idx = i
                if i == len(list_) - 1:
                    self._mark_last()
                if print_func is not None:
                    print_func(item)
                else:
                    item.accept(self)
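
The level stack in `AstPrinter` can be hard to follow from the diff alone. Below is a minimal standalone sketch of the same idea: each nesting level remembers whether more siblings follow, and that state picks the connector drawn for the next line. `TreeWriter` and `Level` are hypothetical names used only for illustration, and the 4-character connector strings are an assumption about the intended layout; this is not the project's class.

```python
import io
from contextlib import contextmanager
from enum import Enum, auto


class Level(Enum):
    EMPTY = auto()   # branch already closed: pad with spaces
    ACTIVE = auto()  # more siblings may follow: draw a vertical bar
    LAST = auto()    # the next line written is the final sibling


class TreeWriter:
    """Hypothetical, reduced version of the level-stack technique."""

    def __init__(self) -> None:
        self.levels: list[Level] = []
        self.buf = io.StringIO()

    @contextmanager
    def child_level(self, single: bool = False):
        # A level holding exactly one child starts out as LAST.
        self.levels.append(Level.LAST if single else Level.ACTIVE)
        try:
            yield
        finally:
            self.levels.pop()

    def write_line(self, text: str, *, last: bool = False) -> None:
        if last and self.levels:
            self.levels[-1] = Level.LAST
        # Ancestors contribute either padding or a vertical bar.
        parts = ["    " if lvl is Level.EMPTY else "│   " for lvl in self.levels[:-1]]
        if self.levels:
            if self.levels[-1] is Level.LAST:
                parts.append("└── ")
                self.levels[-1] = Level.EMPTY  # nothing more to draw below
            else:
                parts.append("├── ")
        self.buf.write("".join(parts) + text + "\n")


w = TreeWriter()
w.write_line("BinaryExpr")
with w.child_level():
    w.write_line("left")
    with w.child_level(single=True):
        w.write_line("VariableExpr")
    w.write_line("operator: Add")
    w.write_line("right", last=True)
    with w.child_level(single=True):
        w.write_line("LiteralExpr")
output = w.buf.getvalue()
print(output)
```

Marking a level as `LAST` before writing, then flipping it to `EMPTY` afterwards, is what lets deeper lines under the final sibling be padded with spaces instead of a dangling vertical bar.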

midas/ast/printer/midas.py (new file, 183 lines)

@@ -0,0 +1,183 @@
import midas.ast.midas as m
class MidasPrinter(
m.Expr.Visitor[str],
m.Stmt.Visitor[str],
m.Type.Visitor[str],
):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0
return expr.accept(self)
# Statements
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
return res
def visit_member_stmt(self, stmt: m.MemberStmt):
keyword: str = {
m.MemberKind.PROPERTY: "prop",
m.MemberKind.METHOD: "def",
}.get(stmt.kind, "")
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for member in stmt.members:
res += member.accept(self) + "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
# Expressions
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{left} {operator} {right}"
def visit_binary_expr(self, expr: m.BinaryExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{left} {operator} {right}"
def visit_unary_expr(self, expr: m.UnaryExpr):
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
return f"{expr_}.{name}"
def visit_variable_expr(self, expr: m.VariableExpr):
return expr.name.lexeme
def visit_grouping_expr(self, expr: m.GroupingExpr):
expr_: str = expr.expr.accept(self)
return f"({expr_})"
def visit_literal_expr(self, expr: m.LiteralExpr):
return str(expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_"
# Types
def visit_named_type(self, type: m.NamedType) -> str:
return type.name.lexeme
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
if len(type.args) != 0:
args: list[str] = [param.accept(self) for param in type.args]
res += f"[{', '.join(args)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for member in type.members:
res += member.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
def visit_extension_type(self, type: m.ExtensionType) -> str:
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos: list[str] = [self._print_param(param) for param in spec.pos]
mixed: list[str] = [self._print_param(param) for param in spec.mixed]
kw: list[str] = [self._print_param(param) for param in spec.kw]
params: list[str] = pos
if len(pos) != 0:
params.append("/")
params += mixed
if len(kw) != 0:
params.append("*")
params += kw
return f"({', '.join(params)})"
def _print_param(self, param: m.FunctionType.Parameter) -> str:
res: str = ""
if param.name is not None:
res += param.name.lexeme
res += ": "
res += param.type.accept(self)
if not param.required:
res += "?"
return res
def visit_frame_type(self, type: m.FrameType) -> str:
res: str = self.indented("Frame[")
if len(type.columns) != 0:
res += "\n"
self.level += 1
columns: list[str] = []
for column in type.columns:
columns.append(self.indented(self._print_frame_column(column)))
res += ",\n".join(columns)
self.level -= 1
res += "\n"
res += "]"
return res
def _print_frame_column(self, column: m.FrameType.Column) -> str:
return f"{column.name.lexeme}: {column.type.accept(self)}"
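
As a reading aid for `_visit_param_spec` above, here is a small sketch of the separator logic on plain strings: positional-only parameters are followed by `/`, and keyword-only parameters are preceded by `*`. The function name `format_param_spec` is hypothetical, and the sketch copies `pos` instead of aliasing it, which is a deliberate difference from the method in the diff.

```python
def format_param_spec(pos: list[str], mixed: list[str], kw: list[str]) -> str:
    """Join already-rendered parameters into a Python-like signature."""
    params = list(pos)  # copy, so the caller's list is not mutated
    if pos:
        params.append("/")  # everything before "/" is positional-only
    params += mixed
    if kw:
        params.append("*")  # everything after "*" is keyword-only
    params += kw
    return f"({', '.join(params)})"


full = format_param_spec(["a: int"], ["b: str"], ["c: bool?"])
empty = format_param_spec([], [], [])
print(full)
print(empty)
```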


@@ -0,0 +1,253 @@
import midas.ast.midas as m
from midas.ast.printer.base import AstPrinter
class MidasAstPrinter(
AstPrinter,
m.Expr.Visitor[None],
m.Stmt.Visitor[None],
m.Type.Visitor[None],
):
# Statements
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("TypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_sequence(
"params",
stmt.params,
print_func=self._print_type_param,
)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
self._write_line("AliasStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_member_stmt(self, stmt: m.MemberStmt):
self._write_line("MemberStmt")
with self._child_level():
self._write_line(f"kind: {stmt.kind.name}")
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_sequence(
"params",
stmt.params,
print_func=self._print_type_param,
)
self._write_sequence("members", stmt.members, last=True)
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_sequence(
"params",
stmt.params,
print_func=self._visit_param_spec,
)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.body.accept(self)
# Expressions
def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_binary_expr(self, expr: m.BinaryExpr):
self._write_line("BinaryExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: m.UnaryExpr):
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_sequence("arguments", expr.arguments)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_variable_expr(self, expr: m.VariableExpr):
self._write_line("VariableExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_grouping_expr(self, expr: m.GroupingExpr):
self._write_line("GroupingExpr")
with self._child_level():
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level():
self._write_line(f"value: {expr.value}", last=True)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
# Types
def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("NamedType")
with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
type.type.accept(self)
self._write_sequence("args", type.args, last=True)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_sequence("members", type.members, last=True)
def visit_extension_type(self, type: m.ExtensionType) -> None:
self._write_line("ExtensionType")
with self._child_level():
self._write_line("base")
with self._child_level(single=True):
type.base.accept(self)
self._write_line("extension", last=True)
with self._child_level(single=True):
type.extension.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_sequence(
"pos",
spec.pos,
print_func=self._print_param,
)
self._write_sequence(
"mixed",
spec.mixed,
print_func=self._print_param,
)
self._write_sequence(
"kw",
spec.kw,
print_func=self._print_param,
last=True,
)
def _print_param(self, param: m.FunctionType.Parameter) -> None:
self._write_line("Parameter")
with self._child_level():
name: str = "None"
if param.name is not None:
name = f'"{param.name.lexeme}"'
self._write_line(f"name: {name}")
self._write_line("type")
with self._child_level(single=True):
param.type.accept(self)
self._write_line(f"required: {param.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_sequence(
"columns",
type.columns,
print_func=self._print_frame_column,
)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
self._write_line("Column")
with self._child_level():
self._write_line(f'name: "{column.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
column.type.accept(self)


@@ -0,0 +1,285 @@
import ast
import midas.ast.python as p
from midas.ast.printer.base import AstPrinter
class PythonAstPrinter(
AstPrinter,
p.MidasType.Visitor[None],
p.Stmt.Visitor[None],
p.Expr.Visitor[None],
):
# Types
def visit_base_type(self, node: p.BaseType) -> None:
self._write_line("BaseType")
with self._child_level():
self._write_line(f"base: {node.base}")
self._write_sequence("args", node.args, last=True)
def visit_constraint_type(self, node: p.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
node.type.accept(self)
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
def visit_frame_column(self, node: p.FrameColumn) -> None:
self._write_line("FrameColumn")
with self._child_level():
self._write_line(f"name: {node.name}")
self._write_optional_child("type", node.type, last=True)
def visit_frame_type(self, node: p.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_sequence("columns", node.columns)
# Statements
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
stmt.expr.accept(self)
def visit_function(self, stmt: p.Function) -> None:
self._write_line("Function")
with self._child_level():
self._write_line(f"name: {stmt.name}")
self._write_line("params")
with self._child_level():
self._print_param_spec(stmt.params)
self._write_optional_child("returns", stmt.returns)
self._write_sequence("body", stmt.body, last=True)
def _print_param_spec(self, spec: p.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_sequence(
"pos",
spec.pos,
print_func=self._print_param,
)
self._write_sequence(
"mixed",
spec.mixed,
print_func=self._print_param,
)
self._write_sequence(
"kw",
spec.kw,
print_func=self._print_param,
last=True,
)
def _print_param(self, param: p.Function.Parameter) -> None:
self._write_line("Parameter")
with self._child_level():
self._write_line(f"name: {param.name}")
self._write_optional_child("type", param.type, last=True)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
self._write_line("TypeAssign")
with self._child_level():
self._write_line(f"name: {stmt.name}")
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
self._write_line("AssignStmt")
with self._child_level():
self._write_sequence("targets", stmt.targets)
self._write_line("value", last=True)
with self._child_level(single=True):
stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self._write_line("ReturnStmt")
with self._child_level():
self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
stmt.test.accept(self)
self._write_sequence("body", stmt.body)
self._write_sequence("orelse", stmt.orelse, last=True)
def visit_pass(self, stmt: p.Pass) -> None:
self._write_line("Pass")
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
self._write_line("ForStmt")
with self._child_level():
self._write_line("target")
with self._child_level(single=True):
stmt.target.accept(self)
self._write_line("iterator")
with self._child_level(single=True):
stmt.iterator.accept(self)
self._write_sequence("body", stmt.body, last=True)
def visit_raw_stmt(self, stmt: p.RawStmt) -> None:
self._write_line("RawStmt")
with self._child_level(single=True):
self._write_line(f"stmt: {ast.unparse(stmt.stmt)}")
# Expressions
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
self._write_line("CompareExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: p.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_sequence("arguments", expr.arguments)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None:
self._write_line("GetExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line(f"name: {expr.name}", last=True)
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level(single=True):
self._write_line(f"value: {expr.value!r}")
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
self._write_line("VariableExpr")
with self._child_level(single=True):
self._write_line(f"name: {expr.name}")
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
self._write_line("LogicalExpr")
with self._child_level():
self._write_line("left")
with self._child_level(single=True):
expr.left.accept(self)
self._write_line(f"operator: {expr.operator.__class__.__name__}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
expr.test.accept(self)
self._write_line("if_true")
with self._child_level(single=True):
expr.if_true.accept(self)
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)
def visit_list_expr(self, expr: p.ListExpr) -> None:
self._write_line("ListExpr")
with self._child_level():
self._write_sequence("items", expr.items, last=True)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
self._write_line("DictExpr")
with self._child_level():
self._write_sequence(
"keys",
expr.keys,
print_func=lambda k: (
self._write_line("None") if k is None else k.accept(self)
),
)
self._write_sequence("values", expr.values, last=True)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self._write_line("SubscriptExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line("index", last=True)
with self._child_level(single=True):
expr.index.accept(self)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
self._write_line("SliceExpr")
with self._child_level():
self._write_optional_child("lower", expr.lower)
self._write_optional_child("upper", expr.upper)
self._write_optional_child("step", expr.step, last=True)
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
self._write_line("TupleExpr")
with self._child_level():
self._write_sequence("items", expr.items, last=True)
def visit_raw_expr(self, expr: p.RawExpr) -> None:
self._write_line("RawExpr")
with self._child_level(single=True):
self._write_line(f"expr: {ast.unparse(expr.expr)}")
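
`PythonAstPrinter` relies on two standard-library behaviours: `ast.unparse` (Python 3.9+) turns a raw node back into source text, and the class name of an operator node serves as the operator label. A short sketch, independent of the Midas classes:

```python
import ast

expr = ast.parse("x+1", mode="eval").body  # an ast.BinOp node
source = ast.unparse(expr)                  # regenerated, normalized source text
op_name = expr.op.__class__.__name__        # class name of the operator node

print(source)
print(op_name)
```

Because `ast.unparse` regenerates the text from the tree, the output is normalized and may differ in spacing from what the user originally wrote.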


@@ -14,6 +14,16 @@ from midas.ast.location import Location
 T = TypeVar("T")
+@dataclass(frozen=True, kw_only=True)
+class ParamSpec:
+    pos: list[Function.Parameter]
+    mixed: list[Function.Parameter]
+    kw: list[Function.Parameter]
+    @property
+    def all(self) -> list[Function.Parameter]:
+        return self.pos + self.mixed + self.kw
 ####################
 # Type annotations #
@@ -128,25 +138,17 @@ class ExpressionStmt(Stmt):
 @dataclass(frozen=True)
 class Function(Stmt):
     name: str
-    posonlyargs: list[Argument]
-    args: list[Argument]
-    sink: Optional[Argument]
-    kwonlyargs: list[Argument]
-    kw_sink: Optional[Argument]
+    params: ParamSpec
     returns: Optional[MidasType]
     body: list[Stmt]
     @dataclass(frozen=True, kw_only=True)
-    class Argument:
+    class Parameter:
         location: Optional[Location] = None
         name: str
         type: Optional[MidasType]
         default: Optional[Expr]
-    @property
-    def all_args(self) -> list[Argument]:
-        return self.posonlyargs + self.args + self.kwonlyargs
     def accept(self, visitor: Stmt.Visitor[T]) -> T:
         return visitor.visit_function(self)

@@ -17,8 +17,12 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = { BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str", "bytes", "tuple"}, "object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"}, "float": {"int"},
"int": {"bool"},
} }
"""
Hard-coded subtype relationships between builtin types
Circular dependencies and diamond inheritance MUST be avoided
"""
def define_builtins(reg: TypesRegistry): def define_builtins(reg: TypesRegistry):
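Review note: the effect of dropping `"int": {"bool"}` can be checked with a small transitive lookup over the table. `is_builtin_subtype` is a hypothetical helper written for this note, not the registry's actual `is_subtype`; it only terminates because the table contains no cycles, which is what the new docstring requires.

```python
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "object": {"float", "list", "dict", "str", "bytes", "tuple"},
    "float": {"int"},
}


def is_builtin_subtype(sub: str, sup: str) -> bool:
    """Hypothetical helper: True if `sub` is `sup` or is reachable below it."""
    if sub == sup:
        return True
    # Recurse into the direct subtypes of `sup`; unknown names have none
    return any(
        is_builtin_subtype(sub, child) for child in BUILTIN_SUBTYPES.get(sup, set())
    )


print(is_builtin_subtype("int", "object"))  # int -> float -> object
print(is_builtin_subtype("bool", "int"))    # no longer listed under int
```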

@@ -10,6 +10,11 @@ from midas.utils import TypedAST
class TypeChecker: class TypeChecker:
"""Type checking dispatcher
Contains a typer for Midas and one for Python, as well as the types registry
"""
def __init__(self): def __init__(self):
self.types: TypesRegistry = TypesRegistry() self.types: TypesRegistry = TypesRegistry()
self.reporter: Reporter = Reporter() self.reporter: Reporter = Reporter()

@@ -14,6 +14,15 @@ class DiagnosticType(StrEnum):
@dataclass(frozen=True) @dataclass(frozen=True)
class Diagnostic: class Diagnostic:
"""Information about a diagnostic (warning, error, etc.)
Holds a location, a diagnostic type and a message.
Optionally bound to a file path.
"""
file_path: Optional[str] file_path: Optional[str]
location: Location location: Location
type: DiagnosticType type: DiagnosticType
@@ -21,6 +30,18 @@ class Diagnostic:
@property @property
def location_str(self) -> str: def location_str(self) -> str:
"""The diagnostic type and location as a human readable string
The location is formatted as "<Type> in <file> from L<start_line>:<start_col> to L<end_line>:<end_col>",
for example: "Error in /home/user/Desktop/script.py from L12:5 to L12:8"
If the file is `None`, the "in ..." section is excluded from the result.<br>
If the location's end is not specified, the formulation "at L<start_line>:<start_col>" is used.
Returns:
str: the formatted diagnostic type and location
"""
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}" start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
end_loc: Optional[str] = "" end_loc: Optional[str] = ""
if ( if (
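Review note: the format described in the `location_str` docstring can be sketched as a standalone function. The `Location` fields `end_lineno` and `end_col_offset` and the 1-based shift of the end column are assumptions (they mirror Python's `ast` node attributes); only the start formatting is visible in this hunk.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Location:
    # Field names for the end position are assumed, not taken from the diff
    lineno: int
    col_offset: int
    end_lineno: Optional[int] = None
    end_col_offset: Optional[int] = None


def location_str(kind: str, file_path: Optional[str], loc: Location) -> str:
    # Columns are stored 0-based and displayed 1-based
    start = f"L{loc.lineno}:{loc.col_offset + 1}"
    where = f" in {file_path}" if file_path is not None else ""
    if loc.end_lineno is None or loc.end_col_offset is None:
        return f"{kind}{where} at {start}"
    end = f"L{loc.end_lineno}:{loc.end_col_offset + 1}"
    return f"{kind}{where} from {start} to {end}"


full = location_str("Error", "/home/user/Desktop/script.py", Location(12, 4, 12, 7))
short = location_str("Warning", None, Location(3, 0))
print(full)
print(short)
```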

@@ -30,9 +30,9 @@ TypedExpr = tuple[E, Type]
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class MappedArgument(Generic[E]): class MappedArgument(Generic[E]):
expr: E arg_expr: E
type: Type arg_type: Type
argument: Function.Argument parameter: Function.Parameter
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -219,11 +219,11 @@ class CallDispatcher(Generic[E]):
""" """
valid: bool = True valid: bool = True
for arg in arguments: for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type): if not self.types.is_subtype(arg.arg_type, arg.parameter.type):
if report_errors: if report_errors:
self.reporter.error( self.reporter.error(
arg.expr.location, arg.arg_expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", f"Wrong type for argument '{arg.parameter.name}', expected {arg.parameter.type}, got {arg.arg_type}",
) )
valid = False valid = False
return valid return valid
@@ -347,28 +347,30 @@ class CallDispatcher(Generic[E]):
tuple[bool, list[MappedArgument]]: a boolean reporting whether tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments the call is valid and the list of mapped arguments
""" """
set_args: set[str] = set() set_params: set[str] = set()
required_positional: list[str] = [ required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required param.name
for param in function.params.pos + function.params.mixed
if param.required
] ]
required_keyword: list[str] = [ required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required param.name for param in function.params.kw if param.required
] ]
mapped: list[MappedArgument[E]] = [] mapped: list[MappedArgument[E]] = []
pos_params: list[Function.Argument] = list(function.pos_args) pos_params: list[Function.Parameter] = list(function.params.pos)
mixed_params: list[Function.Argument] = list(function.args) mixed_params: list[Function.Parameter] = list(function.params.mixed)
kw_params: dict[str, Function.Argument] = { kw_params: dict[str, Function.Parameter] = {
arg.name: arg for arg in function.kw_args param.name: param for param in function.params.kw
} }
valid_call: bool = True valid_call: bool = True
# TODO: handle *args and **kwargs sinks # TODO: handle *args and **kwargs sinks
for arg in positional: for arg in positional:
param: Function.Argument param: Function.Parameter
if len(pos_params) != 0: if len(pos_params) != 0:
param = pos_params.pop(0) param = pos_params.pop(0)
elif len(mixed_params) != 0: elif len(mixed_params) != 0:
@@ -385,27 +387,27 @@ class CallDispatcher(Generic[E]):
required_positional.remove(name) required_positional.remove(name)
if name in required_keyword: if name in required_keyword:
required_keyword.remove(name) required_keyword.remove(name)
set_args.add(name) set_params.add(name)
mapped.append( mapped.append(
MappedArgument( MappedArgument(
expr=arg[0], arg_expr=arg[0],
type=arg[1], arg_type=arg[1],
argument=param, parameter=param,
) )
) )
kw_params.update({arg.name: arg for arg in mixed_params}) kw_params.update({param.name: param for param in mixed_params})
for name, arg in keywords.items(): for name, arg in keywords.items():
param: Function.Argument param: Function.Parameter
if name not in kw_params: if name not in kw_params:
if report_errors: if report_errors:
if name in set_args: if name in set_params:
self.reporter.error( self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'" arg[0].location, f"Multiple values for parameter '{name}'"
) )
else: else:
self.reporter.error( self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'" arg[0].location, f"Unknown keyword parameter '{name}'"
) )
valid_call = False valid_call = False
continue continue
@@ -414,40 +416,40 @@ class CallDispatcher(Generic[E]):
required_positional.remove(name) required_positional.remove(name)
if name in required_keyword: if name in required_keyword:
required_keyword.remove(name) required_keyword.remove(name)
set_args.add(name) set_params.add(name)
mapped.append( mapped.append(
MappedArgument( MappedArgument(
expr=arg[0], arg_expr=arg[0],
type=arg[1], arg_type=arg[1],
argument=param, parameter=param,
) )
) )
def join_args(args: list[str]) -> str: def join_params(params: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args)) params = list(map(lambda p: f"'{p}'", params))
if len(args) == 0: if len(params) == 0:
return "" return ""
if len(args) == 1: if len(params) == 1:
return args[0] return params[0]
return ", ".join(args[:-1]) + " and " + args[-1] return ", ".join(params[:-1]) + " and " + params[-1]
if len(required_positional) != 0: if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s" plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional) params: str = join_params(required_positional)
if report_errors: if report_errors:
self.reporter.error( self.reporter.error(
location, location,
f"Missing required positional argument{plural}: {args}", f"Missing required positional argument{plural}: {params}",
) )
valid_call = False valid_call = False
if len(required_keyword) != 0: if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s" plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword) params: str = join_params(required_keyword)
if report_errors: if report_errors:
self.reporter.error( self.reporter.error(
location, location,
f"Missing required keyword argument{plural}: {args}", f"Missing required keyword argument{plural}: {params}",
) )
valid_call = False valid_call = False
@@ -474,11 +476,11 @@ class CallDispatcher(Generic[E]):
""" """
by_expr: dict[E, Type] = {} by_expr: dict[E, Type] = {}
for arg in mapped1: for arg in mapped1:
by_expr[arg.expr] = arg.argument.type by_expr[arg.arg_expr] = arg.parameter.type
for arg in mapped2: for arg in mapped2:
type2: Type = arg.argument.type type2: Type = arg.parameter.type
type1: Type = by_expr[arg.expr] type1: Type = by_expr[arg.arg_expr]
if not self.types.is_subtype(type1, type2): if not self.types.is_subtype(type1, type2):
return False return False
return True return True

@@ -158,15 +158,17 @@ class Evaluator(m.Expr.Visitor[Any]):
return res return res
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]): def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
positional: list[Function.Argument] = function.pos_args + function.args positional: list[Function.Parameter] = (
keywords: dict[str, Function.Argument] = { function.params.pos + function.params.mixed
arg.name: arg for arg in function.args + function.kw_args )
keywords: dict[str, Function.Parameter] = {
param.name: param for param in function.params.mixed + function.params.kw
} }
for i, arg in enumerate(args): for i, arg in enumerate(args):
param: Function.Argument = positional[i] param: Function.Parameter = positional[i]
self.set_value(param.name, arg) self.set_value(param.name, arg)
for name, arg in kwargs.items(): for name, arg in kwargs.items():
param: Function.Argument = keywords[name] param: Function.Parameter = keywords[name]
self.set_value(param.name, arg) self.set_value(param.name, arg)

@@ -7,7 +7,14 @@ import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.dispatcher import CallResult from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import ColumnGroupBy, Function, Type from midas.checker.types import (
ColumnGroupBy,
ColumnType,
Function,
ParamSpec,
TopType,
Type,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.checker.python import TypedExpr from midas.checker.python import TypedExpr
@@ -22,39 +29,52 @@ class Call:
positional: list[TypedExpr] positional: list[TypedExpr]
keywords: dict[str, TypedExpr] keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class ColumnGroupByMethodRegistry(MethodRegistry[Call]): class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
@method() NAMED_ARGS: dict[str, str] = {
def mean(self, call: Call) -> Type: "numeric_only": "bool",
bool_ = self.types.get_type("bool") "skipna": "bool",
"engine": "str",
"engine_kwargs": "dict",
}
def _aggregate(
self,
call: Call,
params: list[str | tuple[str, str, bool]] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
real_params: list[Function.Parameter] = []
for i, param in enumerate(params):
match param:
case str() as name:
param = Function.Parameter(
pos=i,
name=name,
type=self.types.get_type(self.NAMED_ARGS[name]),
required=False,
)
case (name, type, required):
param = Function.Parameter(
pos=i,
name=name,
type=self.types.get_type(type),
required=required,
)
real_params.append(param)
signature = Function( signature = Function(
args=[ params=ParamSpec(mixed=real_params),
Function.Argument( returns=(
pos=0, call.groupby.column
name="numeric_only", if preserve_inner_type
type=bool_, else ColumnType(type=TopType())
required=False,
), ),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
returns=call.groupby.column,
) )
result: CallResult = self.dispatcher.get_result( result: CallResult = self.dispatcher.get_result(
@@ -64,3 +84,127 @@ class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
keywords=call.keywords, keywords=call.keywords,
) )
return result.result return result.result
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(
call,
["skipna", "numeric_only"],
)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna", "engine", "engine_kwargs"],
)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(
call,
["numeric_only", "skipna"],
preserve_inner_type=True,
)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
preserve_inner_type=True,
)
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
],
)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(
call,
[
"numeric_only",
(
"min_count",
"int",
False,
),
"skipna",
"engine",
"engine_kwargs",
],
)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
(
"ddof",
"int",
False,
),
"engine",
"engine_kwargs",
"numeric_only",
"skipna",
],
)

@@ -1,12 +1,13 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.frames.column_groupby_methods import Call as GroupByCall from midas.checker.frames.column_groupby_methods import Call as GroupByCall
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.types import ColumnGroupBy, ColumnType, Type from midas.checker.types import ColumnGroupBy, ColumnType, Type
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -60,3 +61,18 @@ class ColumnManager:
keywords=keywords, keywords=keywords,
) )
return self.groupby_method_resolver.call(method, call) return self.groupby_method_resolver.call(method, call)
def get_attribute(self, column: ColumnType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int")
case "T":
return column
case _:
return None

@@ -13,6 +13,7 @@ from midas.checker.types import (
ColumnType, ColumnType,
Function, Function,
GenericType, GenericType,
ParamSpec,
TopType, TopType,
Type, Type,
TypeVar, TypeVar,
@@ -33,6 +34,10 @@ class Call:
positional: list[TypedExpr] positional: list[TypedExpr]
keywords: dict[str, TypedExpr] keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.column_expr, self.column)
class ColumnMethodRegistry(MethodRegistry[Call]): class ColumnMethodRegistry(MethodRegistry[Call]):
def _element_binary_op(self, call: Call, method: str) -> ColumnType: def _element_binary_op(self, call: Call, method: str) -> ColumnType:
@@ -69,8 +74,7 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
new_column = ColumnType(type=new_inner_type) new_column = ColumnType(type=new_inner_type)
return new_column return new_column
@method("add", "__add__") def _element_wise(self, call: Call, method: str) -> Type:
def add(self, call: Call) -> Type:
# TODO: support add with scalar # TODO: support add with scalar
# Build signature with new column type and generic operand # Build signature with new column type and generic operand
@@ -79,15 +83,17 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
name="add", name="add",
params=[param_type], params=[param_type],
body=Function( body=Function(
args=[ params=ParamSpec(
Function.Argument( mixed=[
Function.Parameter(
pos=0, pos=0,
name="other", name="other",
type=ColumnType(type=param_type), type=ColumnType(type=param_type),
required=True, required=True,
), ),
], ],
returns=self._element_binary_op(call, "__add__"), ),
returns=self._element_binary_op(call, method),
), ),
) )
@@ -105,18 +111,186 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
return result.result return result.result
@method() @method("add", "__add__")
def mean(self, call: Call) -> Type: def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(
self,
call: Call,
kwargs: list[Function.Parameter] = [],
*,
preserve_inner_type: bool = False,
) -> Type:
signature = Function( signature = Function(
kw_args=[ params=ParamSpec(
Function.Argument( kw=[
Function.Parameter(
pos=0, pos=0,
name="axis", name="axis",
type=TopType(), type=TopType(),
required=False, required=False,
),
*kwargs,
],
),
returns=call.column if preserve_inner_type else ColumnType(type=TopType()),
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call, preserve_inner_type=True)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
) )
], ],
returns=ColumnType(type=TopType()), )
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.column,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.column,
) )
result: CallResult = self.dispatcher.get_result( result: CallResult = self.dispatcher.get_result(
@@ -131,52 +305,33 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
def groupby(self, call: Call) -> Type: def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool") bool_: Type = self.types.get_type("bool")
function: Function = Function( function: Function = Function(
args=[ params=ParamSpec(
Function.Argument( mixed=[
Function.Parameter(
pos=0, pos=0,
name="by", name="by",
type=TopType(), type=TopType(),
required=False, required=False,
), ),
Function.Argument( Function.Parameter(
pos=1, pos=1,
name="level", name="level",
type=TopType(), type=TopType(),
required=False, required=False,
), ),
], ],
kw_args=[ kw=[
Function.Argument( Function.Parameter(
pos=2, pos=i + 2,
name="as_index", name=name,
type=bool_, type=bool_,
required=False, required=False,
), )
Function.Argument( for i, name in enumerate(
pos=3, ["as_index", "sort", "group_keys", "observed", "dropna"]
name="sort", )
type=bool_,
required=False,
),
Function.Argument(
pos=4,
name="group_keys",
type=bool_,
required=False,
),
Function.Argument(
pos=5,
name="observed",
type=bool_,
required=False,
),
Function.Argument(
pos=6,
name="dropna",
type=bool_,
required=False,
),
], ],
),
returns=ColumnGroupBy(column=call.column), returns=ColumnGroupBy(column=call.column),
) )
@@ -190,6 +345,21 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr): def _assert_same_length(self, call_expr: p.Expr, column1: p.Expr, column2: p.Expr):
func_name: str = "__midas_column_same_length__" func_name: str = "__midas_column_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_col(col: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=col,
attr="index",
)
],
keywords=[],
)
self.assertions.define( self.assertions.define(
func_name, func_name,
ast.FunctionDef( ast.FunctionDef(
@@ -207,16 +377,10 @@ class ColumnMethodRegistry(MethodRegistry[Call]):
body=[ body=[
ast.Return( ast.Return(
value=ast.Compare( value=ast.Compare(
left=ast.Attribute( left=len_of_col(ast.Name(id="column1")),
value=ast.Name(id="column1"),
attr="size",
),
ops=[ast.Eq()], ops=[ast.Eq()],
comparators=[ comparators=[
ast.Attribute( len_of_col(ast.Name(id="column2")),
value=ast.Name(id="column2"),
attr="size",
)
], ],
) )
) )
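Review note: the generated comparison can be inspected with the standard `ast` module alone. The sketch below rebuilds the `len_of_col` helper from this hunk, adds explicit `ctx=ast.Load()` arguments, and evaluates the resulting expression against `SimpleNamespace` objects that stand in for columns (an assumption made so the example runs without pandas).

```python
import ast
from types import SimpleNamespace


def len_of_col(col: ast.expr) -> ast.expr:
    # Builds the expression `len(<col>.index)`
    return ast.Call(
        func=ast.Name(id="len", ctx=ast.Load()),
        args=[ast.Attribute(value=col, attr="index", ctx=ast.Load())],
        keywords=[],
    )


check = ast.Compare(
    left=len_of_col(ast.Name(id="column1", ctx=ast.Load())),
    ops=[ast.Eq()],
    comparators=[len_of_col(ast.Name(id="column2", ctx=ast.Load()))],
)

source = ast.unparse(check)
print(source)

# Compile and evaluate the expression with stand-in "columns"
code = compile(ast.fix_missing_locations(ast.Expression(body=check)), "<check>", "eval")
same = eval(code, {"column1": SimpleNamespace(index=[0, 1, 2]),
                   "column2": SimpleNamespace(index=[0, 1, 2])})
different = eval(code, {"column1": SimpleNamespace(index=[0, 1]),
                        "column2": SimpleNamespace(index=[0, 1, 2])})
print(same, different)
```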

@@ -5,9 +5,15 @@ from typing import TYPE_CHECKING
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import FrameGroupBy, Function, Type from midas.checker.types import (
ColumnGroupBy,
ColumnType,
DataFrameType,
FrameGroupBy,
Type,
UnknownType,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.checker.python import TypedExpr from midas.checker.python import TypedExpr
@@ -22,45 +28,76 @@ class Call:
positional: list[TypedExpr] positional: list[TypedExpr]
keywords: dict[str, TypedExpr] keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.groupby_expr, self.groupby)
class FrameGroupByMethodRegistry(MethodRegistry[Call]): class FrameGroupByMethodRegistry(MethodRegistry[Call]):
@method() NAMED_ARGS: dict[str, str] = {
def mean(self, call: Call) -> Type: "numeric_only": "bool",
bool_ = self.types.get_type("bool") "skipna": "bool",
signature = Function( "engine": "str",
args=[ "engine_kwargs": "dict",
Function.Argument( }
pos=0,
name="numeric_only",
type=bool_,
required=False,
),
Function.Argument(
pos=1,
name="skipna",
type=bool_,
required=False,
),
Function.Argument(
pos=2,
name="engine",
type=self.types.get_type("str"),
required=False,
),
Function.Argument(
pos=3,
name="engine_kwargs",
type=self.types.get_type("dict"),
required=False,
),
],
returns=call.groupby.frame,
)
result: CallResult = self.dispatcher.get_result( def _aggregate(self, call: Call, method: str) -> Type:
new_columns: list[DataFrameType.Column] = []
for column in call.groupby.frame.columns:
column_groupby: ColumnGroupBy = ColumnGroupBy(column=column.type)
result_type: Type = self.typer.call_method(
location=call.location, location=call.location,
callee=signature, call_expr=call.call_expr,
obj=(call.groupby_expr, column_groupby),
method_name=method,
positional=call.positional, positional=call.positional,
keywords=call.keywords, keywords=call.keywords,
) )
return result.result if not isinstance(result_type, ColumnType):
result_type = ColumnType(type=UnknownType())
new_columns.append(
DataFrameType.Column(
index=column.index,
name=column.name,
type=result_type,
)
)
return DataFrameType(columns=new_columns)
@method()
def kurt(self, call: Call) -> Type:
return self._aggregate(call, "kurt")
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call, "max")
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call, "mean")
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call, "median")
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call, "min")
@method()
def prod(self, call: Call) -> Type:
return self._aggregate(call, "prod")
@method()
def std(self, call: Call) -> Type:
return self._aggregate(call, "std")
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call, "sum")
@method()
def var(self, call: Call) -> Type:
return self._aggregate(call, "var")
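Review note: the delegation pattern used by `FrameGroupByMethodRegistry._aggregate` (apply the column-level group-by method to each column, then rebuild the frame) can be sketched without the checker. Types are plain strings, `"Top"` stands in for `TopType`, and the `PRESERVING` set reflects the column group-by methods that pass `preserve_inner_type=True` in this diff; all names below are hypothetical.

```python
from dataclasses import dataclass

# Column group-by methods that keep the column's inner type in this diff
PRESERVING = {"max", "median", "min"}


@dataclass(frozen=True)
class Column:
    name: str
    inner: str  # inner type name, as a plain string for this sketch


def aggregate_column(column: Column, method: str) -> str:
    # Stand-in for the column-level group-by method
    return column.inner if method in PRESERVING else "Top"


def aggregate_frame(columns: list[Column], method: str) -> list[Column]:
    # Delegate to each column, then rebuild the frame with the result types
    return [
        Column(name=column.name, inner=aggregate_column(column, method))
        for column in columns
    ]


frame = [Column("age", "int"), Column("name", "str")]
print(aggregate_frame(frame, "max"))
print(aggregate_frame(frame, "mean"))
```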

@@ -7,6 +7,7 @@ from midas.ast.location import Location
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter from midas.checker.reporter import FileReporter
from midas.checker.types import ( from midas.checker.types import (
ColumnGroupBy, ColumnGroupBy,
@@ -240,3 +241,15 @@ class FrameManager:
keywords=keywords, keywords=keywords,
) )
return self.groupby_method_resolver.call(method, call) return self.groupby_method_resolver.call(method, call)
def get_attribute(self, frame: DataFrameType, name: str) -> Optional[Type]:
types: TypesRegistry = self.typer.types
match name:
case "ndim" | "size":
return types.get_type("int")
case "shape":
return types.tuple_of("int", "int")
case _:
return None

@@ -14,6 +14,7 @@ from midas.checker.types import (
FrameGroupBy, FrameGroupBy,
Function, Function,
OverloadedFunction, OverloadedFunction,
ParamSpec,
TopType, TopType,
Type, Type,
UnknownType, UnknownType,
@@ -33,6 +34,10 @@ class Call:
positional: list[TypedExpr] positional: list[TypedExpr]
keywords: dict[str, TypedExpr] keywords: dict[str, TypedExpr]
@property
def subject(self) -> TypedExpr:
return (self.frame_expr, self.frame)
class FrameMethodRegistry(MethodRegistry[Call]): class FrameMethodRegistry(MethodRegistry[Call]):
def _get_method_result( def _get_method_result(
@@ -142,21 +147,21 @@ class FrameMethodRegistry(MethodRegistry[Call]):
return DataFrameType(columns=new_columns) return DataFrameType(columns=new_columns)
@method("add", "__add__") def _element_wise(self, call: Call, method: str) -> Type:
def add(self, call: Call) -> Type: # TODO: support scalar, sequence, Series, dict operand
# TODO: support add with scalar, sequence, Series, dict
# Build signature with new schema and generic operand # Build signature with new schema and generic operand
signature = Function( signature = Function(
args=[ params=ParamSpec(
Function.Argument( mixed=[
Function.Parameter(
pos=0, pos=0,
name="other", name="other",
type=DataFrameType(columns=[]), type=DataFrameType(columns=[]),
required=True, required=True,
), ),
], ],
returns=self._element_binary_op(call, "__add__"), ),
returns=self._element_binary_op(call, method),
) )
# Map arguments and compute result type # Map arguments and compute result type
@@ -173,28 +178,85 @@ class FrameMethodRegistry(MethodRegistry[Call]):
return result.result return result.result
@method() @method("add", "__add__")
def mean(self, call: Call) -> Type: def add(self, call: Call) -> Type:
return self._element_wise(call, "__add__")
@method("sub", "__sub__")
def sub(self, call: Call) -> Type:
return self._element_wise(call, "__sub__")
@method("mul", "__mul__")
def mul(self, call: Call) -> Type:
return self._element_wise(call, "__mul__")
@method("div", "truediv", "__truediv__")
def truediv(self, call: Call) -> Type:
return self._element_wise(call, "__truediv__")
@method("floordiv", "__floordiv__")
def floordiv(self, call: Call) -> Type:
return self._element_wise(call, "__floordiv__")
@method("mod", "__mod__")
def mod(self, call: Call) -> Type:
return self._element_wise(call, "__mod__")
@method("pow", "__pow__")
def pow(self, call: Call) -> Type:
return self._element_wise(call, "__pow__")
@method("lt", "__lt__")
def lt(self, call: Call) -> Type:
return self._element_wise(call, "__lt__")
@method("gt", "__gt__")
def gt(self, call: Call) -> Type:
return self._element_wise(call, "__gt__")
@method("le", "__le__")
def le(self, call: Call) -> Type:
return self._element_wise(call, "__le__")
@method("ge", "__ge__")
def ge(self, call: Call) -> Type:
return self._element_wise(call, "__ge__")
@method("ne", "__ne__")
def ne(self, call: Call) -> Type:
return self._element_wise(call, "__ne__")
@method("eq", "__eq__")
def eq(self, call: Call) -> Type:
return self._element_wise(call, "__eq__")
def _aggregate(self, call: Call, kwargs: list[Function.Parameter] = []) -> Type:
with_axis = Function( with_axis = Function(
kw_args=[ params=ParamSpec(
Function.Argument( kw=[
Function.Parameter(
pos=0, pos=0,
name="axis", name="axis",
type=self.types.get_type("int"), type=self.types.get_type("int"),
required=False, required=False,
) ),
*kwargs,
], ],
),
returns=ColumnType(type=TopType()), returns=ColumnType(type=TopType()),
) )
without_axis = Function( without_axis = Function(
kw_args=[ params=ParamSpec(
Function.Argument( kw=[
Function.Parameter(
pos=0, pos=0,
name="axis", name="axis",
type=self.types.get_type("None"), type=self.types.get_type("None"),
required=True, required=True,
) ),
*kwargs,
], ],
),
returns=TopType(), returns=TopType(),
) )
overload = OverloadedFunction( overload = OverloadedFunction(
@@ -212,56 +274,145 @@ class FrameMethodRegistry(MethodRegistry[Call]):
) )
return result.result return result.result
@method("kurtosis", "kurt")
def kurtosis(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def max(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mean(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def median(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def min(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def mode(self, call: Call) -> Type:
return self._aggregate(call)
@method("product", "prod")
def product(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def std(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def sum(self, call: Call) -> Type:
return self._aggregate(call)
@method()
def var(self, call: Call) -> Type:
return self._aggregate(
call,
[
Function.Parameter(
pos=1,
name="ddof",
type=self.types.get_type("int"),
required=False,
)
],
)
@method()
def head(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method()
def tail(self, call: Call) -> Type:
signature = Function(
params=ParamSpec(
mixed=[
Function.Parameter(
pos=0,
name="n",
type=self.types.get_type("int"),
required=False,
),
],
),
returns=call.frame,
)
result: CallResult = self.dispatcher.get_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
return result.result
@method() @method()
def groupby(self, call: Call) -> Type: def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool") bool_: Type = self.types.get_type("bool")
function: Function = Function( function: Function = Function(
args=[ params=ParamSpec(
Function.Argument( mixed=[
Function.Parameter(
pos=0, pos=0,
name="by", name="by",
type=TopType(), type=TopType(),
required=False, required=False,
), ),
Function.Argument( Function.Parameter(
pos=1, pos=1,
name="level", name="level",
type=TopType(), type=TopType(),
required=False, required=False,
), ),
], ],
kw_args=[ kw=[
Function.Argument( Function.Parameter(
pos=2, pos=i + 2,
name="as_index", name=name,
type=bool_, type=bool_,
required=False, required=False,
), )
Function.Argument( for i, name in enumerate(
pos=3, ["as_index", "sort", "group_keys", "observed", "dropna"]
name="sort", )
type=bool_,
required=False,
),
Function.Argument(
pos=4,
name="group_keys",
type=bool_,
required=False,
),
Function.Argument(
pos=5,
name="observed",
type=bool_,
required=False,
),
Function.Argument(
pos=6,
name="dropna",
type=bool_,
required=False,
),
], ],
),
returns=FrameGroupBy(frame=call.frame), returns=FrameGroupBy(frame=call.frame),
) )
@@ -275,6 +426,21 @@ class FrameMethodRegistry(MethodRegistry[Call]):
def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr): def _assert_same_length(self, call_expr: p.Expr, frame1: p.Expr, frame2: p.Expr):
func_name: str = "__midas_frame_same_length__" func_name: str = "__midas_frame_same_length__"
# Efficiently compute length
# https://stackoverflow.com/a/15943975/11109181
def len_of_df(df: ast.expr) -> ast.expr:
return ast.Call(
func=ast.Name(id="len"),
args=[
ast.Attribute(
value=df,
attr="index",
)
],
keywords=[],
)
self.assertions.define( self.assertions.define(
func_name, func_name,
ast.FunctionDef( ast.FunctionDef(
@@ -292,17 +458,9 @@ class FrameMethodRegistry(MethodRegistry[Call]):
body=[ body=[
ast.Return( ast.Return(
value=ast.Compare( value=ast.Compare(
left=ast.Attribute( left=len_of_df(ast.Name(id="frame1")),
value=ast.Name(id="frame1"),
attr="size",
),
ops=[ast.Eq()], ops=[ast.Eq()],
comparators=[ comparators=[len_of_df(ast.Name(id="frame2"))],
ast.Attribute(
value=ast.Name(id="frame2"),
attr="size",
)
],
) )
) )
], ],
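
Note on the length fix in this hunk: on a pandas DataFrame, `size` is the number of cells (rows times columns), so two frames with different row counts can have the same `size`, whereas `len(df.index)` counts rows only. The sketch below rebuilds the generated comparison with the standard `ast` module and evaluates it against stand-in objects. The `same_length` helper and the `SimpleNamespace` stand-ins are assumptions made for illustration; they are not part of Midas or pandas.

```python
import ast
from types import SimpleNamespace


def len_of_df(df: ast.expr) -> ast.expr:
    """Build the AST for `len(<df>.index)`."""
    return ast.Call(
        func=ast.Name(id="len", ctx=ast.Load()),
        args=[ast.Attribute(value=df, attr="index", ctx=ast.Load())],
        keywords=[],
    )


# Expression equivalent to the return value of the generated assertion.
comparison = ast.Expression(
    body=ast.Compare(
        left=len_of_df(ast.Name(id="frame1", ctx=ast.Load())),
        ops=[ast.Eq()],
        comparators=[len_of_df(ast.Name(id="frame2", ctx=ast.Load()))],
    )
)
ast.fix_missing_locations(comparison)
source = ast.unparse(comparison.body)
code = compile(comparison, "<same_length>", "eval")


def same_length(frame1, frame2) -> bool:
    """Hypothetical helper: evaluate the generated comparison."""
    return eval(code, {"len": len}, {"frame1": frame1, "frame2": frame2})


# Stand-ins for a 3x2 and a 2x3 frame: same cell count, different row count.
tall = SimpleNamespace(index=range(3), size=6)
wide = SimpleNamespace(index=range(2), size=6)

print(source)                   # len(frame1.index) == len(frame2.index)
print(same_length(tall, wide))  # False: row counts differ
print(tall.size == wide.size)   # True: the old `size` check would pass
```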

@@ -20,7 +20,7 @@ from midas.checker.types import Type, UnknownType
from midas.generator.collector import AssertionCollector from midas.generator.collector import AssertionCollector
if TYPE_CHECKING: if TYPE_CHECKING:
from midas.checker.python import PythonTyper from midas.checker.python import PythonTyper, TypedExpr
class _MethodRegistryMeta(type): class _MethodRegistryMeta(type):
@@ -41,12 +41,18 @@ class _MethodRegistryMeta(type):
return new_class return new_class
class HasLocation(Protocol): class MethodCall(Protocol):
@property @property
def location(self) -> Location: ... def location(self) -> Location: ...
@property
def call_expr(self) -> p.Expr: ...
T = TypeVar("T", bound=HasLocation) @property
def subject(self) -> TypedExpr: ...
T = TypeVar("T", bound=MethodCall)
class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta): class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
@@ -72,7 +78,9 @@ class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
def call(self, method: str, call: T) -> Type: def call(self, method: str, call: T) -> Type:
func: Optional[Callable[[Self, T], Type]] = self._methods.get(method) func: Optional[Callable[[Self, T], Type]] = self._methods.get(method)
if func is None: if func is None:
self.reporter.warning(call.location, f"Unknown method {method}") self.reporter.warning(
call.location, f"Unknown method {method} on {call.subject[1]}"
)
return UnknownType() return UnknownType()
return func(self, call) return func(self, call)

@@ -21,6 +21,7 @@ from midas.checker.types import (
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
ParamSpec,
Predicate, Predicate,
Type, Type,
TypeVar, TypeVar,
@@ -32,13 +33,6 @@ from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
class ReturnException(Exception): class ReturnException(Exception):
pass pass
@@ -47,7 +41,7 @@ class ReturnException(Exception):
class MappedArgument: class MappedArgument:
expr: m.Expr expr: m.Expr
type: Type type: Type
argument: Function.Argument argument: Function.Parameter
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -196,9 +190,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._predicate_params[param.name.lexeme] = param.type.accept(self) self._predicate_params[param.name.lexeme] = param.type.accept(self)
type: Type = self.type_of(stmt.body) type: Type = self.type_of(stmt.body)
params: list[TypedParamSpec] = [ params: list[ParamSpec] = [self._visit_param_spec(spec) for spec in stmt.params]
self._visit_param_spec(spec) for spec in stmt.params
]
if not self._is_valid_predicate(type): if not self._is_valid_predicate(type):
self.reporter.error( self.reporter.error(
@@ -209,9 +201,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
type = self._bool type = self._bool
for spec in reversed(params): for spec in reversed(params):
type = Function( type = Function(
pos_args=spec.pos, params=spec,
args=spec.mixed,
kw_args=spec.kw,
returns=type, returns=type,
) )
self._predicate_params = {} self._predicate_params = {}
@@ -386,30 +376,34 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
) )
def visit_function_type(self, type: m.FunctionType) -> Type: def visit_function_type(self, type: m.FunctionType) -> Type:
params: TypedParamSpec = self._visit_param_spec(type.params)
return Function( return Function(
pos_args=params.pos, params=self._visit_param_spec(type.params),
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self), returns=type.returns.accept(self),
) )
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec: def _visit_param_spec(self, spec: m.ParamSpec) -> ParamSpec:
n_pos: int = len(spec.pos) n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed) n_mixed: int = len(spec.mixed)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument: def process_param(
return Function.Argument( param: m.FunctionType.Parameter, i: int
) -> Function.Parameter:
return Function.Parameter(
pos=i, pos=i,
name=arg.name.lexeme if arg.name is not None else str(i), name=param.name.lexeme if param.name is not None else str(i),
type=arg.type.accept(self), type=param.type.accept(self),
required=arg.required, required=param.required,
) )
return TypedParamSpec( return ParamSpec(
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)], pos=[process_param(param, i) for i, param in enumerate(spec.pos)],
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)], mixed=[
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)], process_param(param, i + n_pos) for i, param in enumerate(spec.mixed)
],
kw=[
process_param(param, i + n_pos + n_mixed)
for i, param in enumerate(spec.kw)
],
) )
def visit_frame_type(self, type: m.FrameType) -> Type: def visit_frame_type(self, type: m.FrameType) -> Type:

@@ -7,6 +7,7 @@ from midas.checker.types import (
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
ParamSpec,
TopType, TopType,
Type, Type,
TypeVar, TypeVar,
@@ -108,8 +109,8 @@ class Preamble(Environment):
], ],
) )
def _list_of(self, item_type: Type) -> Type: def _list_of(self, item_type: str | Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type]) return self._types.list_of(item_type)
def _def_type_constructor( def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None self, name: str, py_function: Optional[Callable[..., Any]] = None
@@ -132,9 +133,9 @@ class Preamble(Environment):
returns: Type = UnitType(), returns: Type = UnitType(),
type_vars: list[TypeVar] = [], type_vars: list[TypeVar] = [],
) -> Type: ) -> Type:
def map_args(params: list[Param], offset: int) -> list[Function.Argument]: def map_params(params: list[Param], offset: int) -> list[Function.Parameter]:
return [ return [
Function.Argument( Function.Parameter(
pos=i + offset, pos=i + offset,
name=param.name, name=param.name,
type=param.type, type=param.type,
@@ -144,9 +145,11 @@ class Preamble(Environment):
] ]
function = Function( function = Function(
pos_args=map_args(pos, 0), params=ParamSpec(
args=map_args(mixed, len(pos)), pos=map_params(pos, 0),
kw_args=map_args(kw, len(pos) + len(mixed)), mixed=map_params(mixed, len(pos)),
kw=map_params(kw, len(pos) + len(mixed)),
),
returns=returns, returns=returns,
) )
if len(type_vars) != 0: if len(type_vars) != 0:

@@ -31,6 +31,7 @@ from midas.checker.types import (
FrameGroupBy, FrameGroupBy,
Function, Function,
GenericType, GenericType,
ParamSpec,
TopType, TopType,
TupleType, TupleType,
Type, Type,
@@ -59,7 +60,7 @@ class UndefinedMethodException(Exception):
class MappedArgument: class MappedArgument:
expr: p.Expr expr: p.Expr
type: Type type: Type
argument: Function.Argument argument: Function.Parameter
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -222,7 +223,7 @@ class PythonTyper(
method_name: str, method_name: str,
positional: list[TypedExpr], positional: list[TypedExpr],
keywords: dict[str, TypedExpr], keywords: dict[str, TypedExpr],
) -> Optional[Type]: ) -> Type:
unfolded: Type = unfold_type(obj[1]) unfolded: Type = unfold_type(obj[1])
match unfolded: match unfolded:
case DataFrameType(): case DataFrameType():
@@ -289,61 +290,64 @@ class PythonTyper(
def visit_function(self, stmt: p.Function) -> None: def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env) env: Environment = Environment(self.env)
pos_args: list[Function.Argument] = [] pos: list[Function.Parameter] = []
args: list[Function.Argument] = [] mixed: list[Function.Parameter] = []
kw_args: list[Function.Argument] = [] kw: list[Function.Parameter] = []
def eval_arg_type(arg: p.Function.Argument) -> Type: def eval_param_type(param: p.Function.Parameter) -> Type:
if arg.type is not None: if param.type is not None:
return self.resolve_type_expr(arg.type) return self.resolve_type_expr(param.type)
if arg.default is not None: if param.default is not None:
return self.type_of(arg.default) return self.type_of(param.default)
return UnknownType() return UnknownType()
pos: int = 0 position: int = 0
for arg in stmt.posonlyargs: for param in stmt.params.pos:
pos_args.append( pos.append(
Function.Argument( Function.Parameter(
pos=pos, pos=position,
name=arg.name, name=param.name,
type=eval_arg_type(arg), type=eval_param_type(param),
required=arg.default is None, required=param.default is None,
) )
) )
pos += 1 position += 1
for arg in stmt.args: for param in stmt.params.mixed:
args.append( mixed.append(
Function.Argument( Function.Parameter(
pos=pos, pos=position,
name=arg.name, name=param.name,
type=eval_arg_type(arg), type=eval_param_type(param),
required=arg.default is None, required=param.default is None,
) )
) )
pos += 1 position += 1
for arg in stmt.kwonlyargs: for param in stmt.params.kw:
kw_args.append( kw.append(
Function.Argument( Function.Parameter(
pos=pos, # not relevant pos=position, # not relevant
name=arg.name, name=param.name,
type=eval_arg_type(arg), type=eval_param_type(param),
required=arg.default is None, required=param.default is None,
) )
) )
pos += 1 position += 1
all_args: list[Function.Argument] = pos_args + args + kw_args param_spec: ParamSpec = ParamSpec(
for arg in all_args: pos=pos,
env.define(arg.name, arg.type) mixed=mixed,
kw=kw,
)
all_params: list[Function.Parameter] = pos + mixed + kw
for param in all_params:
env.define(param.name, param.type)
returns_hint: Optional[Type] = None returns_hint: Optional[Type] = None
if stmt.returns is not None: if stmt.returns is not None:
returns_hint = self.resolve_type_expr(stmt.returns) returns_hint = self.resolve_type_expr(stmt.returns)
# Early define to handle simple fully-typed recursion # Early define to handle simple fully-typed recursion
inside_function: Function = Function( inside_function: Function = Function(
pos_args=pos_args, params=param_spec,
args=args,
kw_args=kw_args,
returns=returns_hint, returns=returns_hint,
) )
self.env.define(stmt.name, inside_function) self.env.define(stmt.name, inside_function)
@@ -375,13 +379,11 @@ class PythonTyper(
# TODO: handle *args and **kwargs sinks # TODO: handle *args and **kwargs sinks
function: Type = Function( function: Type = Function(
pos_args=pos_args, params=param_spec,
args=args,
kw_args=kw_args,
returns=returns, returns=returns,
) )
generic_params: list[TypeVar] = [] generic_params: list[TypeVar] = []
all_types: list[Type] = [arg.type for arg in all_args] + [returns] all_types: list[Type] = [param.type for param in all_params] + [returns]
for type in all_types: for type in all_types:
if isinstance(type, TypeVar): if isinstance(type, TypeVar):
if type not in generic_params: if type not in generic_params:
@@ -580,9 +582,8 @@ class PythonTyper(
right: TypedExpr, right: TypedExpr,
method: str, method: str,
) -> Type: ) -> Type:
result: Optional[Type]
try: try:
result = self.call_method( return self.call_method(
location=location, location=location,
call_expr=expr, call_expr=expr,
obj=left, obj=left,
@@ -597,8 +598,6 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__) method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
if method is None: if method is None:
@@ -610,9 +609,8 @@ class PythonTyper(
operand: Type = self.type_of(expr.right) operand: Type = self.type_of(expr.right)
result: Optional[Type]
try: try:
result = self.call_method( return self.call_method(
location=expr.location, location=expr.location,
call_expr=expr, call_expr=expr,
obj=(expr.right, operand), obj=(expr.right, operand),
@@ -627,8 +625,6 @@ class PythonTyper(
) )
return UnknownType() return UnknownType()
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type: def visit_call_expr(self, expr: p.CallExpr) -> Type:
match expr.callee: match expr.callee:
case p.VariableExpr(name="TypeVar"): case p.VariableExpr(name="TypeVar"):
@@ -644,8 +640,7 @@ class PythonTyper(
match expr.callee: match expr.callee:
case p.GetExpr(object=obj, name=method): case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj) obj_type: Type = self.type_of(obj)
return ( return self.call_method(
self.call_method(
location=expr.location, location=expr.location,
call_expr=expr, call_expr=expr,
obj=(obj, obj_type), obj=(obj, obj_type),
@@ -653,8 +648,6 @@ class PythonTyper(
positional=positional, positional=positional,
keywords=keywords, keywords=keywords,
) )
or UnknownType()
)
callee: Type = self.type_of(expr.callee) callee: Type = self.type_of(expr.callee)
result: CallResult = self.dispatcher.get_result( result: CallResult = self.dispatcher.get_result(
@@ -668,6 +661,14 @@ class PythonTyper(
def visit_get_expr(self, expr: p.GetExpr) -> Type: def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
member: Optional[Type] = self.types.lookup_member(object, expr.name) member: Optional[Type] = self.types.lookup_member(object, expr.name)
if member is None:
match object:
case DataFrameType():
member = self.frame_mgr.get_attribute(object, expr.name)
case ColumnType():
member = self.column_mgr.get_attribute(object, expr.name)
if member is None: if member is None:
self.reporter.warning( self.reporter.warning(
expr.location, f"Unknown member '{expr.name}' of {object}" expr.location, f"Unknown member '{expr.name}' of {object}"

@@ -113,6 +113,15 @@ class TypesRegistry:
raise ValueError(f"Predicate {name} already defined") raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate self._predicates[name] = predicate
def is_builtin_subtype(self, name1: str, name2: str) -> bool:
subtypes: set[str] = BUILTIN_SUBTYPES.get(name2, set())
if name1 in subtypes:
return True
for subtype in subtypes:
if self.is_builtin_subtype(name1, subtype):
return True
return False
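
Note on `is_builtin_subtype`: the previous lookup only consulted direct entries of `BUILTIN_SUBTYPES`, while the new method also follows the table transitively. A minimal sketch of the difference, assuming a hypothetical table (the real contents of `BUILTIN_SUBTYPES` are not shown in this diff) and assuming the table has no cycles:

```python
# Hypothetical subtype table: maps a type name to its *direct* subtypes.
BUILTIN_SUBTYPES: dict[str, set[str]] = {
    "complex": {"float"},
    "float": {"int"},
}


def is_direct_subtype(name1: str, name2: str) -> bool:
    """Non-transitive lookup, as in the replaced line."""
    return name1 in BUILTIN_SUBTYPES.get(name2, set())


def is_builtin_subtype(name1: str, name2: str) -> bool:
    """Return True if `name1` is a direct or indirect subtype of `name2`."""
    subtypes = BUILTIN_SUBTYPES.get(name2, set())
    if name1 in subtypes:
        return True
    # Recurse through each direct subtype of `name2`.
    return any(is_builtin_subtype(name1, subtype) for subtype in subtypes)


print(is_direct_subtype("int", "complex"))   # False
print(is_builtin_subtype("int", "complex"))  # True, via float
```

The relation is not reflexive in this sketch: `is_builtin_subtype("int", "int")` is `False`, so equality of base types has to be handled elsewhere.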
def is_subtype(self, type1: Type, type2: Type) -> bool: def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2` """Check whether `type1` is a subtype of `type2`
@@ -150,7 +159,7 @@ class TypesRegistry:
return self.is_subtype(base1, type2) return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)): case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set()) return self.is_builtin_subtype(name1, name2)
case (ComplexType(properties=props1), ComplexType(properties=props2)): case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items(): for k, t in props2.items():
@@ -225,92 +234,100 @@ class TypesRegistry:
if not self.is_subtype(func1.returns, func2.returns): if not self.is_subtype(func1.returns, func2.returns):
return False return False
pos1: list[Function.Argument] = func1.pos_args pos1: list[Function.Parameter] = func1.params.pos
mixed1: list[Function.Argument] = func1.args mixed1: list[Function.Parameter] = func1.params.mixed
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args} kw1: dict[str, Function.Parameter] = {
pos2: list[Function.Argument] = func2.pos_args param.name: param for param in func1.params.kw
mixed2: list[Function.Argument] = func2.args }
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args} pos2: list[Function.Parameter] = func2.params.pos
mixed2: list[Function.Parameter] = func2.params.mixed
kw2: dict[str, Function.Parameter] = {
param.name: param for param in func2.params.kw
}
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2} mixed_by_pos: dict[int, Function.Parameter] = {
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2} param.pos: param for param in mixed2
}
mixed_by_name: dict[str, Function.Parameter] = {
param.name: param for param in mixed2
}
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool: def is_arg_subtype(sub: Function.Parameter, sup: Function.Parameter) -> bool:
if not self.is_subtype(sub.type, sup.type): if not self.is_subtype(sub.type, sup.type):
return False return False
if not sup.required and sub.required: if not sup.required and sub.required:
return False return False
return True return True
for arg1 in pos1: for param1 in pos1:
arg2: Function.Argument param2: Function.Parameter
if arg1.pos < len(pos2): if param1.pos < len(pos2):
arg2 = pos2[arg1.pos] param2 = pos2[param1.pos]
elif arg1.pos in mixed_by_pos: elif param1.pos in mixed_by_pos:
arg2 = mixed_by_pos[arg1.pos] param2 = mixed_by_pos[param1.pos]
elif not arg1.required: elif not param1.required:
continue continue
else: else:
return False return False
if not is_arg_subtype(arg2, arg1): if not is_arg_subtype(param2, param1):
return False return False
for name, arg1 in kw1.items(): for name, param1 in kw1.items():
arg2: Function.Argument param2: Function.Parameter
if name in kw2: if name in kw2:
arg2 = kw2[name] param2 = kw2[name]
elif name in mixed_by_name: elif name in mixed_by_name:
arg2 = mixed_by_name[name] param2 = mixed_by_name[name]
elif not arg1.required: elif not param1.required:
continue continue
else: else:
return False return False
if not is_arg_subtype(arg2, arg1): if not is_arg_subtype(param2, param1):
return False return False
for arg1 in mixed1: for param1 in mixed1:
pos_arg2: Optional[Function.Argument] = None pos_param2: Optional[Function.Parameter] = None
kw_arg2: Optional[Function.Argument] = None kw_param2: Optional[Function.Parameter] = None
if arg1.name in kw2: if param1.name in kw2:
kw_arg2 = kw2[arg1.name] kw_param2 = kw2[param1.name]
elif arg1.name in mixed_by_name: elif param1.name in mixed_by_name:
kw_arg2 = mixed_by_name[arg1.name] kw_param2 = mixed_by_name[param1.name]
if arg1.pos < len(pos2): if param1.pos < len(pos2):
pos_arg2 = pos2[arg1.pos] pos_param2 = pos2[param1.pos]
elif arg1.pos in mixed_by_pos: elif param1.pos in mixed_by_pos:
pos_arg2 = mixed_by_pos[arg1.pos] pos_param2 = mixed_by_pos[param1.pos]
# No match in func2 and arg is required # No match in func2 and arg is required
if pos_arg2 is None and kw_arg2 is None and arg1.required: if pos_param2 is None and kw_param2 is None and param1.required:
return False return False
# Matching keyword argument # Matching keyword argument
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1): if kw_param2 is not None and not is_arg_subtype(kw_param2, param1):
return False return False
# Matching positional argument # Matching positional argument
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1): if pos_param2 is not None and not is_arg_subtype(pos_param2, param1):
return False return False
mixed_positions: set[int] = {a.pos for a in mixed1} mixed_positions: set[int] = {param.pos for param in mixed1}
mixed_names: set[str] = {a.name for a in mixed1} mixed_names: set[str] = {param.name for param in mixed1}
for arg2 in pos2: for param2 in pos2:
if not arg2.required: if not param2.required:
continue continue
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions: if param2.pos >= len(pos1) and param2.pos not in mixed_positions:
return False return False
for name, arg2 in kw2.items(): for name, param2 in kw2.items():
if not arg2.required: if not param2.required:
continue continue
if name not in kw1 and name not in mixed_names: if name not in kw1 and name not in mixed_names:
return False return False
for arg2 in mixed2: for param2 in mixed2:
if arg2.required: if param2.required:
continue continue
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions pos_match: bool = param2.pos < len(pos1) or param2.pos in mixed_positions
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names kw_match: bool = param2.name in kw1 or param2.name in mixed_names
if not pos_match or not kw_match: if not pos_match or not kw_match:
return False return False
@@ -443,3 +460,29 @@ class TypesRegistry:
def lookup_predicate(self, name: str) -> Optional[Predicate]: def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name) return self._predicates.get(name)
def _by_name_or_type(self, name_or_type: str | Type) -> Type:
if isinstance(name_or_type, str):
return self.get_type(name_or_type)
return name_or_type
def list_of(self, item_type: str | Type) -> Type:
list_ = self.get_type("list")
return self.apply_generic(list_, [self._by_name_or_type(item_type)])
def tuple_of(self, *item_types: str | Type) -> Type:
tuple_ = self.get_type("tuple")
return self.apply_generic(
tuple_,
[self._by_name_or_type(item_type) for item_type in item_types],
)
def dict_of(self, key_type: str | Type, value_type: str | Type) -> Type:
dict_ = self.get_type("dict")
return self.apply_generic(
dict_,
[
self._by_name_or_type(key_type),
self._by_name_or_type(value_type),
],
)

@@ -93,7 +93,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
function (p.Function): the function to resolve function (p.Function): the function to resolve
""" """
self.begin_scope() self.begin_scope()
for param in function.all_args: for param in function.params.all:
self.declare(param.name) self.declare(param.name)
self.define(param.name) self.define(param.name)
self.resolve(*function.body) self.resolve(*function.body)

@@ -45,28 +45,14 @@ class UnitType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Function: class Function:
pos_args: list[Argument] = field(default_factory=list) params: ParamSpec
args: list[Argument] = field(default_factory=list)
kw_args: list[Argument] = field(default_factory=list)
returns: Type returns: Type
def __str__(self) -> str: def __str__(self) -> str:
args: list[str] = [] return f"{self.params} -> {self.returns}"
if len(self.pos_args) != 0:
args += list(map(str, self.pos_args))
args.append("/")
if len(self.args) != 0:
args += list(map(str, self.args))
if len(self.kw_args) != 0:
args.append("*")
args += list(map(str, self.kw_args))
return f"({', '.join(args)}) -> {self.returns}"
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class Argument: class Parameter:
pos: int pos: int
name: str name: str
type: Type type: Type
@@ -77,6 +63,28 @@ class Function:
return f"{self.name}: {self.type}{opt}" return f"{self.name}: {self.type}{opt}"
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
pos: list[Function.Parameter] = field(default_factory=list)
mixed: list[Function.Parameter] = field(default_factory=list)
kw: list[Function.Parameter] = field(default_factory=list)
def __str__(self) -> str:
params: list[str] = []
if len(self.pos) != 0:
params += list(map(str, self.pos))
params.append("/")
if len(self.mixed) != 0:
params += list(map(str, self.mixed))
if len(self.kw) != 0:
params.append("*")
params += list(map(str, self.kw))
return f"({', '.join(params)})"
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class OverloadedFunction: class OverloadedFunction:
overloads: list[Type] overloads: list[Type]
@@ -204,12 +212,19 @@ class ColumnGroupBy:
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type: def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument): def sub_parameter(param: Function.Parameter):
return Function.Argument( return Function.Parameter(
pos=arg.pos, pos=param.pos,
name=arg.name, name=param.name,
type=substitute_typevars(arg.type, substitutions), type=substitute_typevars(param.type, substitutions),
required=arg.required, required=param.required,
)
def sub_param_spec(spec: ParamSpec):
return ParamSpec(
pos=list(map(sub_parameter, spec.pos)),
mixed=list(map(sub_parameter, spec.mixed)),
kw=list(map(sub_parameter, spec.kw)),
) )
def sub_column(col: DataFrameType.Column): def sub_column(col: DataFrameType.Column):
@@ -235,15 +250,11 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
) )
case Function( case Function(
pos_args=pos_args, params=params,
args=args,
kw_args=kw_args,
returns=returns, returns=returns,
): ):
return Function( return Function(
pos_args=list(map(sub_argument, pos_args)), params=sub_param_spec(params),
args=list(map(sub_argument, args)),
kw_args=list(map(sub_argument, kw_args)),
returns=substitute_typevars(returns, substitutions), returns=substitute_typevars(returns, substitutions),
) )
@@ -351,14 +362,14 @@ def unfold_type(type: Type) -> Type:
def to_annotation(type: Type) -> str: def to_annotation(type: Type) -> str:
def _args_annotation(func: Function) -> str: def _params_annotation(spec: ParamSpec) -> str:
if len(func.kw_args) != 0: if len(spec.kw) != 0:
return "..." return "..."
args: str = ", ".join( params: str = ", ".join(
to_annotation(arg.type) for arg in func.pos_args + func.args to_annotation(param.type) for param in spec.pos + spec.mixed
) )
return f"[{args}]" return f"[{params}]"
match type: match type:
case TopType(): case TopType():
@@ -376,8 +387,8 @@ def to_annotation(type: Type) -> str:
case UnitType(): case UnitType():
return "None" return "None"
case Function(returns=returns): case Function(params=params, returns=returns):
params_annot: str = _args_annotation(type) params_annot: str = _params_annotation(params)
return f"Callable[{params_annot}, {to_annotation(returns)}]" return f"Callable[{params_annot}, {to_annotation(returns)}]"
case OverloadedFunction(): case OverloadedFunction():

@@ -8,6 +8,7 @@ from midas.checker.types import (
DataFrameType, DataFrameType,
Function, Function,
GenericType, GenericType,
ParamSpec,
TopType, TopType,
Type, Type,
TypeVar, TypeVar,
@@ -29,8 +30,9 @@ class Unifier:
keywords: dict[str, Type], keywords: dict[str, Type],
) -> Optional[Type]: ) -> Optional[Type]:
concrete_func: Function = Function( concrete_func: Function = Function(
pos_args=[ params=ParamSpec(
Function.Argument( pos=[
Function.Parameter(
pos=i, pos=i,
name=str(i), name=str(i),
type=arg, type=arg,
@@ -38,9 +40,8 @@ class Unifier:
) )
for i, arg in enumerate(positional) for i, arg in enumerate(positional)
], ],
args=[], kw=[
kw_args=[ Function.Parameter(
Function.Argument(
pos=len(positional) + i, pos=len(positional) + i,
name=name, name=name,
type=arg, type=arg,
@@ -48,6 +49,7 @@ class Unifier:
) )
for i, (name, arg) in enumerate(keywords.items()) for i, (name, arg) in enumerate(keywords.items())
], ],
),
returns=TopType(), # TODO: use expected type returns=TopType(), # TODO: use expected type
) )
return self.unify_generic(type, concrete_func, match_return=False) return self.unify_generic(type, concrete_func, match_return=False)
@@ -125,7 +127,7 @@ class Unifier:
return self.match(template_column, concrete_column) return self.match(template_column, concrete_column)
case (Function(), Function()): case (Function(), Function()):
mapped: list[tuple[Function.Argument, Function.Argument]] = ( mapped: list[tuple[Function.Parameter, Function.Parameter]] = (
self.map_params(template, concrete) self.map_params(template, concrete)
) )
substitutions: dict[str, Type] = {} substitutions: dict[str, Type] = {}
@@ -161,19 +163,23 @@ class Unifier:
def map_params( def map_params(
self, func1: Function, func2: Function self, func1: Function, func2: Function
) -> list[tuple[Function.Argument, Function.Argument]]: ) -> list[tuple[Function.Parameter, Function.Parameter]]:
pos1: list[Function.Argument] = func1.pos_args pos1: list[Function.Parameter] = func1.params.pos
mixed1: list[Function.Argument] = func1.args mixed1: list[Function.Parameter] = func1.params.mixed
kw1: list[Function.Argument] = func1.kw_args kw1: list[Function.Parameter] = func1.params.kw
pos2: list[Function.Argument] = func2.pos_args pos2: list[Function.Parameter] = func2.params.pos
mixed2: list[Function.Argument] = func2.args mixed2: list[Function.Parameter] = func2.params.mixed
kw2: list[Function.Argument] = func2.kw_args kw2: list[Function.Parameter] = func2.params.kw
mapped: list[tuple[Function.Argument, Function.Argument]] = [] mapped: list[tuple[Function.Parameter, Function.Parameter]] = []
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2} by_pos2: dict[int, Function.Parameter] = {
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2} param.pos: param for param in pos2 + mixed2
}
by_name2: dict[str, Function.Parameter] = {
param.name: param for param in mixed2 + kw2
}
for arg1 in pos1: for arg1 in pos1:
if (arg2 := by_pos2.get(arg1.pos)) is not None: if (arg2 := by_pos2.get(arg1.pos)) is not None:
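The parameter mapping in `map_params` pairs each parameter of one function with its counterpart in another, first by position and then by name. The sketch below illustrates that idea in isolation. `Param` is a hypothetical stand-in for `Function.Parameter`, and the handling of mixed and keyword parameters is an assumption, since that part of the method is not visible in this hunk.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for Function.Parameter (only pos and name)."""

    pos: int
    name: str


def map_params(pos1, mixed1, kw1, pos2, mixed2, kw2):
    """Pair parameters of function 1 with matching parameters of function 2."""
    # Positional-only and mixed parameters can be reached by position,
    # mixed and keyword-only parameters can be reached by name.
    by_pos2 = {p.pos: p for p in pos2 + mixed2}
    by_name2 = {p.name: p for p in mixed2 + kw2}

    mapped = []
    for p1 in pos1:
        if (p2 := by_pos2.get(p1.pos)) is not None:
            mapped.append((p1, p2))
    # Assumption: mixed parameters try the position first, then the name.
    for p1 in mixed1:
        p2 = by_pos2.get(p1.pos)
        if p2 is None:
            p2 = by_name2.get(p1.name)
        if p2 is not None:
            mapped.append((p1, p2))
    # Assumption: keyword-only parameters are matched by name only.
    for p1 in kw1:
        if (p2 := by_name2.get(p1.name)) is not None:
            mapped.append((p1, p2))
    return mapped


# Template shaped like f(a, /, b, *, c), matched against a call-like
# signature with one positional argument and two keyword arguments.
pairs = map_params(
    pos1=[Param(0, "a")], mixed1=[Param(1, "b")], kw1=[Param(2, "c")],
    pos2=[Param(0, "0")], mixed2=[], kw2=[Param(1, "b"), Param(2, "c")],
)
names = [(p1.name, p2.name) for p1, p2 in pairs]
```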


@@ -77,14 +77,14 @@ class VarianceInferrer:
match type: match type:
# Arguments are negative positions -> flip polarity # Arguments are negative positions -> flip polarity
# Return is positive position -> keep polarity # Return is positive position -> keep polarity
case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args): case Function(params=spec):
all_args: list[Function.Argument] = pos_args + mixed_args + kw_args all_params: list[Function.Parameter] = spec.pos + spec.mixed + spec.kw
for arg in all_args: for param in all_params:
self.walk( self.walk(
arg.type, param.type,
-polarity, -polarity,
base_name, base_name,
path + [f"arg:'{arg.name}'"], path + [f"param:'{param.name}'"],
) )
self.walk(type.returns, polarity, base_name, path + ["return"]) self.walk(type.returns, polarity, base_name, path + ["return"])
@@ -109,10 +109,10 @@ class VarianceInferrer:
Variance.COVARIANT: 1, Variance.COVARIANT: 1,
Variance.CONTRAVARIANT: -1, Variance.CONTRAVARIANT: -1,
} }
for arg, param in zip(args, params): for arg, param in zip(args, params):
param_polarity: Polarity = polarities[param.variance] param_polarity: Polarity = polarities[param.variance]
self.walk( self.walk(
arg, arg,
cast(Polarity, polarity * param_polarity), cast(Polarity, polarity * param_polarity),
base_name, base_name,
path + [f"applied:'{name}'"], path + [f"applied:'{name}'"],


@@ -157,15 +157,18 @@ class PythonHighlighter(
def visit_function(self, stmt: p.Function) -> None: def visit_function(self, stmt: p.Function) -> None:
self.wrap(stmt, "function") self.wrap(stmt, "function")
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs: self._highlight_param_spec(stmt.params)
self._highlight_function_argument(arg)
for body_stmt in stmt.body: for body_stmt in stmt.body:
body_stmt.accept(self) body_stmt.accept(self)
def _highlight_function_argument(self, arg: p.Function.Argument) -> None: def _highlight_param_spec(self, spec: p.ParamSpec) -> None:
self.wrap(arg, "argument") for param in spec.all:
if arg.type is not None: self._highlight_function_param(param)
arg.type.accept(self)
def _highlight_function_param(self, param: p.Function.Parameter) -> None:
self.wrap(param, "parameter")
if param.type is not None:
param.type.accept(self)
def visit_type_assign(self, stmt: p.TypeAssign) -> None: def visit_type_assign(self, stmt: p.TypeAssign) -> None:
stmt.type.accept(self) stmt.type.accept(self)


@@ -23,7 +23,7 @@ span {
--col: 215, 103, 224; --col: 215, 103, 224;
} }
&.argument { &.parameter {
--col: 103, 192, 224; --col: 103, 192, 224;
} }
} }


@@ -5,6 +5,7 @@ import midas.ast.midas as m
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
Function, Function,
ParamSpec,
Predicate, Predicate,
Type, Type,
to_annotation, to_annotation,
@@ -54,16 +55,16 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
return expr.accept(self) return expr.accept(self)
case _: case _:
func = Function( func = Function(
pos_args=[], params=ParamSpec(
args=[ mixed=[
Function.Argument( Function.Parameter(
pos=0, pos=0,
name="_", name="_",
type=self.types.get_type("Any"), type=self.types.get_type("Any"),
required=True, required=True,
) )
], ],
kw_args=[], ),
returns=self.types.get_type("bool"), returns=self.types.get_type("bool"),
) )
alias: str = self.make_alias(None) alias: str = self.make_alias(None)
@@ -94,28 +95,28 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
) )
return self.make_func(name, [ast.Return(value=body)], predicate.type) return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, func: Function) -> ast.arguments: def make_args(self, params: ParamSpec) -> ast.arguments:
return ast.arguments( return ast.arguments(
posonlyargs=[ posonlyargs=[
ast.arg( ast.arg(
arg=arg.name, arg=param.name,
annotation=ast.Constant(value=to_annotation(arg.type)), annotation=ast.Constant(value=to_annotation(param.type)),
) )
for arg in func.pos_args for param in params.pos
], ],
args=[ args=[
ast.arg( ast.arg(
arg=arg.name, arg=param.name,
annotation=ast.Constant(value=to_annotation(arg.type)), annotation=ast.Constant(value=to_annotation(param.type)),
) )
for arg in func.args for param in params.mixed
], ],
kwonlyargs=[ kwonlyargs=[
ast.arg( ast.arg(
arg=arg.name, arg=param.name,
annotation=ast.Constant(value=to_annotation(arg.type)), annotation=ast.Constant(value=to_annotation(param.type)),
) )
for arg in func.kw_args for param in params.kw
], ],
defaults=[], defaults=[],
kw_defaults=[], kw_defaults=[],
@@ -125,11 +126,11 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0 self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
) -> ast.stmt: ) -> ast.stmt:
match type: match type:
case Function(returns=Function()): case Function(params=params, returns=Function()):
inner_name: str = f"inner{level}" inner_name: str = f"inner{level}"
return ast.FunctionDef( return ast.FunctionDef(
name=name, name=name,
args=self.make_args(type), args=self.make_args(params),
body=[ body=[
self.make_func(inner_name, inner_body, type.returns, level + 1), self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)), ast.Return(value=ast.Name(id=inner_name)),
@@ -138,10 +139,10 @@ class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
decorator_list=[], decorator_list=[],
) )
case Function(): case Function(params=params):
return ast.FunctionDef( return ast.FunctionDef(
name=name, name=name,
args=self.make_args(type), args=self.make_args(params),
body=inner_body, body=inner_body,
returns=ast.Constant(value=to_annotation(type.returns)), returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[], decorator_list=[],


@@ -250,25 +250,26 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
value=self.convert(stmt.expr), value=self.convert(stmt.expr),
) )
def make_args(self, params: p.ParamSpec) -> ast.arguments:
return ast.arguments(
posonlyargs=[ast.arg(arg=param.name) for param in params.pos],
args=[ast.arg(arg=param.name) for param in params.mixed],
kwonlyargs=[ast.arg(arg=param.name) for param in params.kw],
defaults=[
self.convert(param.default)
for param in params.pos + params.mixed
if param.default is not None
],
kw_defaults=[
self.convert(param.default) if param.default is not None else None
for param in params.kw
],
)
def visit_function(self, stmt: p.Function) -> ast.stmt: def visit_function(self, stmt: p.Function) -> ast.stmt:
return ast.FunctionDef( return ast.FunctionDef(
name=stmt.name, name=stmt.name,
args=ast.arguments( args=self.make_args(stmt.params),
posonlyargs=[ast.arg(arg=arg.name) for arg in stmt.posonlyargs],
vararg=None,
args=[ast.arg(arg=arg.name) for arg in stmt.args],
kwonlyargs=[ast.arg(arg=arg.name) for arg in stmt.kwonlyargs],
kwarg=None,
defaults=[
self.convert(arg.default)
for arg in stmt.posonlyargs + stmt.args
if arg.default is not None
],
kw_defaults=[
self.convert(arg.default) if arg.default is not None else None
for arg in stmt.kwonlyargs
],
),
body=self._visit_body(stmt.body), body=self._visit_body(stmt.body),
decorator_list=[], decorator_list=[],
) )
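The `make_args` helper above maps a parameter spec onto the fields of `ast.arguments`. The sketch below shows, using only the standard `ast` module, how those fields line up: `defaults` holds only the defaults of the trailing positional or mixed parameters, while `kw_defaults` needs one entry per keyword-only parameter, with `None` marking a required one. `Param` is a hypothetical stand-in for the project's parameter class, and `vararg`, `kwarg`, and `returns` are passed explicitly here so the example does not rely on version-specific constructor defaults.

```python
import ast
from dataclasses import dataclass
from typing import Optional


@dataclass
class Param:
    """Hypothetical stand-in for a parsed function parameter."""

    name: str
    default: Optional[ast.expr] = None


def make_args(pos: list, mixed: list, kw: list) -> ast.arguments:
    return ast.arguments(
        posonlyargs=[ast.arg(arg=p.name) for p in pos],
        args=[ast.arg(arg=p.name) for p in mixed],
        vararg=None,
        kwonlyargs=[ast.arg(arg=p.name) for p in kw],
        # One entry per keyword-only parameter; None means "required".
        kw_defaults=[p.default for p in kw],
        kwarg=None,
        # Only the defaults that exist, aligned to the end of pos + mixed.
        defaults=[p.default for p in pos + mixed if p.default is not None],
    )


func = ast.FunctionDef(
    name="f",
    args=make_args(
        pos=[Param("a")],
        mixed=[Param("b", ast.Constant(value=1))],
        kw=[Param("c"), Param("d", ast.Constant(value=2))],
    ),
    body=[ast.Pass()],
    decorator_list=[],
    returns=None,
)
module = ast.fix_missing_locations(ast.Module(body=[func], type_ignores=[]))
source = ast.unparse(module)
print(source)
```

Because `defaults` is aligned to the end of the positional parameters, a parameter without a default placed after one with a default would silently shift the defaults. Python itself rejects such signatures, so the list comprehension is safe as long as the input comes from valid source.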


@@ -17,6 +17,7 @@ from midas.checker.types import (
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
ParamSpec,
TopType, TopType,
TupleType, TupleType,
Type, Type,
@@ -328,7 +329,7 @@ class StubsGenerator:
return [ return [
ast.FunctionDef( ast.FunctionDef(
name=name, name=name,
args=self.dump_args(method, with_self=True), args=self.dump_params(method.params, with_self=True),
returns=self.dump_type(method.returns), returns=self.dump_type(method.returns),
body=[ast.Expr(value=Empty)], body=[ast.Expr(value=Empty)],
decorator_list=[ast.Name(id="overload")] if overloaded else [], decorator_list=[ast.Name(id="overload")] if overloaded else [],
@@ -348,24 +349,33 @@ class StubsGenerator:
) )
] ]
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments: def dump_params(self, params: ParamSpec, with_self: bool = False) -> ast.arguments:
pos: list[ast.arg] = [ pos: list[ast.arg] = [
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type)) ast.arg(
for arg in func.pos_args arg=f"_{param.pos}",
annotation=self.dump_type(param.type),
)
for param in params.pos
] ]
mixed: list[ast.arg] = [ mixed: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type)) ast.arg(
for arg in func.args arg=param.name,
annotation=self.dump_type(param.type),
)
for param in params.mixed
] ]
kw: list[ast.arg] = [ kw: list[ast.arg] = [
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type)) ast.arg(
for arg in func.kw_args arg=param.name,
annotation=self.dump_type(param.type),
)
for param in params.kw
] ]
defaults: list[ast.expr] = [ defaults: list[ast.expr] = [
Empty for arg in func.pos_args + func.args if not arg.required Empty for param in params.pos + params.mixed if not param.required
] ]
kw_defaults: list[Optional[ast.expr]] = [ kw_defaults: list[Optional[ast.expr]] = [
None if arg.required else Empty for arg in func.kw_args None if param.required else Empty for param in params.kw
] ]
if with_self: if with_self:
arg = ast.arg(arg="self", annotation=None) arg = ast.arg(arg="self", annotation=None)
@@ -391,7 +401,7 @@ class StubsGenerator:
body=[ body=[
ast.FunctionDef( ast.FunctionDef(
name="__call__", name="__call__",
args=self.dump_args(func, with_self=True), args=self.dump_params(func.params, with_self=True),
returns=self.dump_type(func.returns), returns=self.dump_type(func.returns),
body=[ast.Expr(value=Empty)], body=[ast.Expr(value=Empty)],
decorator_list=[], decorator_list=[],


@@ -16,9 +16,10 @@ class Lexer(ABC):
"""An abstract lexer which provides methods to easily extend it into a concrete one """An abstract lexer which provides methods to easily extend it into a concrete one
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom, This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble) more specifically on my [previous Python implementation][2]
[1]: https://craftinginterpreters.com/ [1]: https://craftinginterpreters.com/
[2]: https://git.kb28.ch/HEL/pebble
""" """
def __init__(self, source: str, file: Optional[str] = None) -> None: def __init__(self, source: str, file: Optional[str] = None) -> None:
@@ -168,6 +169,6 @@ class Lexer(ABC):
def scan_token(self) -> None: def scan_token(self) -> None:
"""Scan a token """Scan a token
This function should (at least) consume the current character and produce the appropriate token(s), using `add_token` This function should (at least) consume the current character and produce the appropriate token(s), using :func:`add_token`
""" """
pass pass


@@ -81,6 +81,12 @@ class MidasLexer(Lexer):
return None return None
def scan_string(self, opening: str): def scan_string(self, opening: str):
"""Scan the rest of a string and add it as a token
Args:
opening (str): the opening quote or double quote, to be matched
at the end of the string
"""
while self.peek() != opening and not self.is_at_end(): while self.peek() != opening and not self.is_at_end():
self.advance() self.advance()
@@ -147,6 +153,18 @@ class MidasLexer(Lexer):
self.add_token(TokenType.COMMENT) self.add_token(TokenType.COMMENT)
def is_identifier_char(self, char: str, *, start: bool) -> bool: def is_identifier_char(self, char: str, *, start: bool) -> bool:
"""Check whether a character is valid as part of an identifier
Identifiers can contain any alphanumerical character or underscore.
They cannot start with a digit.
Args:
char (str): the character to check
start (bool): whether this is the first character of the identifier
Returns:
bool: `True` if the character is valid, `False` otherwise
"""
if char == "_": if char == "_":
return True return True
if char.isalpha(): if char.isalpha():
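The rule stated in the docstring above (alphanumerical characters or underscores, no leading digit) can be sketched as a standalone function. This is an illustration under assumptions, not the method's actual body, which is cut off in this hunk. `is_identifier` is a hypothetical helper added only to exercise the check.

```python
def is_identifier_char(char: str, *, start: bool) -> bool:
    """Return True if `char` may appear in an identifier at this position."""
    if char == "_":
        return True
    if char.isalpha():
        return True
    # Digits are allowed anywhere except as the first character.
    return char.isdigit() and not start


def is_identifier(text: str) -> bool:
    """Hypothetical helper: check a whole string with the rule above."""
    return bool(text) and all(
        is_identifier_char(char, start=(i == 0)) for i, char in enumerate(text)
    )
```

Note that `str.isalpha` and `str.isdigit` accept non-ASCII characters, so this sketch is more permissive than an ASCII-only rule.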


@@ -104,6 +104,15 @@ class Token:
) )
def location_to(self, to: Token) -> Location: def location_to(self, to: Token) -> Location:
"""Create a new :class:`Location` spanning from this token to another
Args:
to (Token): the end token
Returns:
Location: a new :class:`Location` starting at this token and ending
at `to`, both included
"""
return Location.span(self.get_location(), to.get_location()) return Location.span(self.get_location(), to.get_location())
@property @property


@@ -16,6 +16,9 @@ class TokenError:
def get_report(self) -> str: def get_report(self) -> str:
"""Get a detailed error message """Get a detailed error message
The error message is formatted as "(<position>) Error at <token>: <message>".
For example: "(L2:5) Error at '3': Expected ')' after arguments."
Returns: Returns:
str: the complete error message str: the complete error message
""" """
@@ -32,9 +35,10 @@ class Parser(ABC, Generic[T]):
"""An abstract parser which provides methods to easily extend it into a concrete one """An abstract parser which provides methods to easily extend it into a concrete one
This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom, This implementation is based on the [_Crafting Interpreters_][1] book by Robert Nystrom,
more specifically on my [previous Python implementation](https://git.kb28.ch/HEL/pebble) more specifically on my [previous Python implementation][2]
[1]: https://craftinginterpreters.com/ [1]: https://craftinginterpreters.com/
[2]: https://git.kb28.ch/HEL/pebble
""" """
IGNORE: set[TokenType] = { IGNORE: set[TokenType] = {
@@ -173,7 +177,7 @@ class Parser(ABC, Generic[T]):
error_msg (str): the error message if the token doesn't match error_msg (str): the error message if the token doesn't match
Raises: Raises:
SyntaxError: if the current token doesn't match the given type ParsingError: if the current token doesn't match the given type
Returns: Returns:
Token: the current token which matched the given type Token: the current token which matched the given type


@@ -35,10 +35,11 @@ from midas.parser.base import Parser
from midas.parser.errors import ParsingError from midas.parser.errors import ParsingError
class MidasParser(Parser): class MidasParser(Parser[list[Stmt]]):
"""A simple parser for midas type definitions""" """A simple parser for midas type definitions"""
SYNC_BOUNDARY: set[TokenType] = { SYNC_BOUNDARY: set[TokenType] = {
TokenType.ALIAS,
TokenType.TYPE, TokenType.TYPE,
TokenType.EXTEND, TokenType.EXTEND,
TokenType.PREDICATE, TokenType.PREDICATE,
@@ -73,10 +74,10 @@ class MidasParser(Parser):
def declaration(self) -> Optional[Stmt]: def declaration(self) -> Optional[Stmt]:
"""Try and parse a declaration """Try and parse a declaration
Any parsing error is caught and None is returned Any parsing error is caught and `None` is returned
Returns: Returns:
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised Optional[Stmt]: the parsed Midas statement, or `None` if a ParsingError was raised
""" """
try: try:
if self.match(TokenType.TYPE): if self.match(TokenType.TYPE):
@@ -95,23 +96,14 @@ class MidasParser(Parser):
def type_declaration(self) -> TypeStmt: def type_declaration(self) -> TypeStmt:
"""Parse a type declaration """Parse a type declaration
A type declaration can either be a simple type alias or a new complex type. A type declaration creates a named subtype of a type expression.
In either case, it can have an optional template expression after its name, wrapped in brackets. It can have an optional template expression after its name, wrapped in brackets, to handle type parameters.
A simple type alias is derived from a base type expression, and can have a optional constraint expression preceded by the `where` keyword.
A full simple type alias is thus written:
```
type Name[Template](TypeExpr) where Condition
```
A new complex type has a set of properties which are named, have a type and an optional constraint expression (also preceded by the `where` keyword). A type statement consists of:
A full complex type definition is thus written: - the `type` keyword
``` - a name (identifier)
type Name[Template] { - (optional) type parameters
prop1: TypeExpr1 where Condition1 - a body, a type expression (see :func:`type_expr`)
prop2: TypeExpr2 where Condition2
...
}
```
Returns: Returns:
TypeStmt: the parsed type declaration statement TypeStmt: the parsed type declaration statement
@@ -165,11 +157,16 @@ class MidasParser(Parser):
def alias_declaration(self) -> AliasStmt: def alias_declaration(self) -> AliasStmt:
"""Parse an alias declaration """Parse an alias declaration
An alias statement consists of:
- the `alias` keyword
- a name (identifier)
- a body, a type expression (see :func:`type_expr`)
Returns: Returns:
AliasStmt: the parsed alias declaration statement AliasStmt: the parsed alias declaration statement
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name") name: Token = self.consume_identifier("Expected alias name")
self.consume(TokenType.EQUAL, "Expected '=' before alias definition") self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
@@ -184,8 +181,8 @@ class MidasParser(Parser):
def type_expr(self) -> Type: def type_expr(self) -> Type:
"""Parse a type expression """Parse a type expression
A type is an identifier, optionally followed by a template expression. A type expression can either be a function type (see :func:`function`)
It can also optionally be followed by a '?' to indicate a nullable type or a constraint type (see :func:`constraint_type`)
Returns: Returns:
TypeExpr: the parsed type expression TypeExpr: the parsed type expression
@@ -205,6 +202,15 @@ class MidasParser(Parser):
return base return base
def constraint_type(self) -> Type: def constraint_type(self) -> Type:
"""Parse a constraint type expression
A constraint type consists of a base type (see :func:`base_type`),
optionally followed by the `where` keyword and a constraint
expression (see :func:`constraint`)
Returns:
Type: the parsed constraint type expression
"""
type: Type = self.base_type() type: Type = self.base_type()
if self.match(TokenType.WHERE): if self.match(TokenType.WHERE):
constraint: Expr = self.constraint() constraint: Expr = self.constraint()
@@ -216,6 +222,14 @@ class MidasParser(Parser):
return type return type
def base_type(self) -> Type: def base_type(self) -> Type:
"""Parse a base type expression
A base type is either a parenthesized type expression (see :func:`type_expr`)
or a generic type (see :func:`generic_type`)
Returns:
Type: the parsed base type expression
"""
if self.match(TokenType.LEFT_PAREN): if self.match(TokenType.LEFT_PAREN):
type: Type = self.type_expr() type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
@@ -227,6 +241,17 @@ class MidasParser(Parser):
return self.generic_type() return self.generic_type()
def generic_type(self) -> Type: def generic_type(self) -> Type:
"""Parse a generic type expression
A generic type consists of a named type (see :func:`named_type`),
optionally followed by type arguments in brackets.
The special `Frame` type accepts a frame schema instead of type
arguments (see :func:`frame_type`).
Returns:
Type: the parsed generic type
"""
type: NamedType = self.named_type() type: NamedType = self.named_type()
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
if type.name.lexeme == "Frame": if type.name.lexeme == "Frame":
@@ -240,6 +265,13 @@ class MidasParser(Parser):
return type return type
def type_args(self) -> list[Type]: def type_args(self) -> list[Type]:
"""Parse a list of type arguments
Type arguments are a comma-separated list of type expressions wrapped in brackets.
Returns:
list[Type]: the list of type arguments, if any, or an empty list
"""
args: list[Type] = [] args: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments") self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
@@ -250,6 +282,13 @@ class MidasParser(Parser):
return args return args
def named_type(self) -> NamedType: def named_type(self) -> NamedType:
"""Parse a named type expression
A named type is an identifier token
Returns:
NamedType: the parsed named type expression
"""
name: Token = self.consume_identifier("Expected type name") name: Token = self.consume_identifier("Expected type name")
return NamedType( return NamedType(
location=name.get_location(), location=name.get_location(),
@@ -257,13 +296,13 @@ class MidasParser(Parser):
) )
def complex_type(self) -> ComplexType: def complex_type(self) -> ComplexType:
"""Parse a type definition body """Parse a complex type expression
A type definition body is a set of whitespace-separated A complex type consists of zero or more member statements enclosed in
property statements enclosed in curly braces curly braces
Returns: Returns:
ComplexType: the parsed complex type ComplexType: the parsed complex type expression
""" """
left: Token = self.consume( left: Token = self.consume(
TokenType.LEFT_BRACE, "Expected '{' to start type body" TokenType.LEFT_BRACE, "Expected '{' to start type body"
@@ -285,6 +324,20 @@ class MidasParser(Parser):
) )
def frame_type(self) -> FrameType: def frame_type(self) -> FrameType:
"""Parse a frame type expression
A frame type consists of:
- the `Frame` identifier
- an opening bracket `[`
- a comma-separated list of column expressions, each consisting of:
- a name (token)
- a colon `:`
- a type expression (see :func:`type_expr`)
- a closing bracket `]`
Returns:
FrameType: the parsed frame type
"""
keyword: Token = self.previous() keyword: Token = self.previous()
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema") self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
@@ -311,9 +364,9 @@ class MidasParser(Parser):
) )
def constraint(self) -> Expr: def constraint(self) -> Expr:
"""Parse a constraint """Parse a constraint expression
A constraint is basically a logical predicate A constraint is an expression (see :func:`expression`)
Returns: Returns:
Expr: the parsed constraint expression Expr: the parsed constraint expression
@@ -321,10 +374,20 @@ class MidasParser(Parser):
return self.expression() return self.expression()
def expression(self) -> Expr: def expression(self) -> Expr:
"""Parse an expression
An expression consists of a logical AND expression (see :func:`and_`)
Returns:
Expr: the parsed expression
"""
return self.and_() return self.and_()
def and_(self) -> Expr: def and_(self) -> Expr:
"""Parse a logical AND expression or a simpler expression """Parse a logical AND expression
An AND consists of one or more equality expressions (see :func:`equality`)
separated by logical AND operators (`&`)
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
@@ -340,7 +403,10 @@ class MidasParser(Parser):
return expr return expr
def equality(self) -> Expr: def equality(self) -> Expr:
"""Parse a logical equality expression or a simpler expression """Parse an equality expression
An equality consists of one or more comparison expressions (see :func:`comparison`)
separated by equality operators (`==`, `!=`)
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
@@ -356,7 +422,10 @@ class MidasParser(Parser):
return expr return expr
def comparison(self) -> Expr: def comparison(self) -> Expr:
"""Parse a logical comparison expression or a simpler expression """Parse a comparison expression
A comparison consists of one or more term expressions (see :func:`term`)
separated by comparison operators (`<`, `<=`, `>`, `>=`)
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
@@ -377,6 +446,14 @@ class MidasParser(Parser):
return expr return expr
def term(self) -> Expr: def term(self) -> Expr:
"""Parse a term expression
A term consists of one or more factor expressions (see :func:`factor`)
separated by weak arithmetic operators (`+`, `-`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.factor() expr: Expr = self.factor()
while self.match(TokenType.PLUS, TokenType.MINUS): while self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous() operator: Token = self.previous()
@@ -388,6 +465,14 @@ class MidasParser(Parser):
return expr return expr
def factor(self) -> Expr: def factor(self) -> Expr:
"""Parse a factor expression
A factor consists of one or more unary expressions (see :func:`unary`)
separated by strong arithmetic operators (`*`, `/`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.unary() expr: Expr = self.unary()
while self.match(TokenType.STAR, TokenType.SLASH): while self.match(TokenType.STAR, TokenType.SLASH):
operator: Token = self.previous() operator: Token = self.previous()
@@ -399,12 +484,15 @@ class MidasParser(Parser):
return expr return expr
def unary(self) -> Expr: def unary(self) -> Expr:
"""Parse a unary expression or a simpler expression """Parse a unary expression
A unary consists of a call expression (see :func:`call`)
preceded by zero or more unary operators (`+`, `-`)
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
""" """
if self.match(TokenType.MINUS): if self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location) location: Location = Location.span(operator.get_location(), right.location)
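The docstrings for `expression`, `and_`, `equality`, `comparison`, `term`, `factor` and `unary` describe a chain of precedence levels, each defined in terms of the next tighter one. A minimal sketch of that layering follows, assuming a toy token list and nested tuples in place of the project's `Token` and `Expr` classes. Calls and attribute references are left out, and `MiniParser`, `tokenize` and `binary` are hypothetical names.

```python
import re

TOKEN = re.compile(r"\s*(\d+|==|!=|<=|>=|[-+*/&<>()])")


def tokenize(source: str) -> list:
    """Split a source string into number and operator tokens."""
    source = source.rstrip()
    tokens, index = [], 0
    while index < len(source):
        found = TOKEN.match(source, index)
        if found is None:
            raise SyntaxError(f"Unexpected character at offset {index}")
        tokens.append(found.group(1))
        index = found.end()
    return tokens


class MiniParser:
    def __init__(self, tokens: list) -> None:
        self.tokens = tokens
        self.current = 0

    def match(self, *kinds: str) -> bool:
        """Consume the current token if it is one of `kinds`."""
        if self.current < len(self.tokens) and self.tokens[self.current] in kinds:
            self.current += 1
            return True
        return False

    def previous(self) -> str:
        return self.tokens[self.current - 1]

    def binary(self, operand, *operators: str):
        """Parse one or more operands separated by left-associative operators."""
        expr = operand()
        while self.match(*operators):
            operator = self.previous()
            expr = (operator, expr, operand())
        return expr

    # Loosest binding first; each level delegates to the next tighter one.
    def expression(self):
        return self.and_()

    def and_(self):
        return self.binary(self.equality, "&")

    def equality(self):
        return self.binary(self.comparison, "==", "!=")

    def comparison(self):
        return self.binary(self.term, "<", "<=", ">", ">=")

    def term(self):
        return self.binary(self.factor, "+", "-")

    def factor(self):
        return self.binary(self.unary, "*", "/")

    def unary(self):
        if self.match("+", "-"):
            operator = self.previous()
            return (operator, self.unary())
        return self.primary()

    def primary(self):
        if self.match("("):
            expr = self.expression()
            if not self.match(")"):
                raise SyntaxError("Unclosed parenthesis")
            return expr
        if self.current < len(self.tokens) and self.tokens[self.current].isdigit():
            self.current += 1
            return int(self.previous())
        raise SyntaxError("Expected expression")


tree = MiniParser(tokenize("1 + 2 * 3 < 10 & -4 == -4")).expression()
```

The resulting tree groups `2 * 3` first, then the addition, then the comparison, and combines both sides with `&` last, which mirrors the order in which the methods call each other.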
@@ -412,12 +500,44 @@ class MidasParser(Parser):
return self.call() return self.call()
def call(self) -> Expr: def call(self) -> Expr:
"""Parse a call expression
A call consists of a reference expression (see :func:`reference`)
followed by zero or more argument groups.
An argument group is a parenthesized, comma-separated list of arguments (see :func:`finish_call`)
Returns:
Expr: the parsed expression
"""
expr: Expr = self.reference() expr: Expr = self.reference()
while self.match(TokenType.LEFT_PAREN): while self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr) expr = self.finish_call(expr)
return expr return expr
def finish_call(self, callee: Expr) -> Expr: def finish_call(self, callee: Expr) -> Expr:
"""Parse an argument group, i.e. the arguments of a call
Arguments are either passed positionally or by name (keyword argument).
All positional arguments must come before any keyword argument.
Arguments are separated by commas.
A positional argument simply consists of an expression (see :func:`expression`)
A keyword argument consists of an identifier, followed by the equal `=`
token and an expression (see :func:`expression`).
Args:
callee (Expr): the callee expression
Raises:
ParsingError: if a positional argument is passed after a keyword
argument or if a keyword argument's name is invalid (i.e. not
an identifier)
Returns:
Expr: the parsed call expression
"""
pos_args: list[Expr] = [] pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {} kw_args: dict[str, Expr] = {}
keywords: bool = False keywords: bool = False
@@ -437,13 +557,14 @@ class MidasParser(Parser):
else: else:
value = self.expression() value = self.expression()
if self.check(TokenType.EQUAL): if self.check(TokenType.EQUAL):
error_msg: str
if keywords: if keywords:
raise self.error(self.peek(), "Invalid keyword argument name") error_msg = "Invalid keyword argument name"
else: else:
raise self.error( error_msg = (
self.peek(), "Cannot pass positional arguments after a keyword argument"
"Cannot pass positional arguments after a keyword argument",
) )
raise self.error(self.peek(), error_msg)
pos_args.append(value) pos_args.append(value)
if not self.match(TokenType.COMMA): if not self.match(TokenType.COMMA):
@@ -460,7 +581,12 @@ class MidasParser(Parser):
) )
def reference(self) -> Expr: def reference(self) -> Expr:
"""Parse an attribute access expression or a simpler expression """Parse a reference expression
A reference consists of a primary expression (see :func:`primary`)
optionally followed by zero or more attribute accesses.
An attribute access consists of a dot `.` token followed by an identifier
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
@@ -475,7 +601,12 @@ class MidasParser(Parser):
def primary(self) -> Expr: def primary(self) -> Expr:
"""Parse a primary expression """Parse a primary expression
This includes literals (booleans, numbers, etc.), wildcards, identifiers and grouped expressions This includes literals (booleans, numbers, etc.), wildcards, identifiers
and grouped expressions
Raises:
ParsingError: if a primary expression cannot be parsed from the
following tokens
Returns: Returns:
Expr: the parsed expression Expr: the parsed expression
@@ -508,14 +639,41 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression") raise self.error(self.peek(), "Expected expression")
def consume_identifier(self, message: str = "Expected identifier") -> Token: def consume_identifier(self, message: str = "Expected identifier") -> Token:
"""Consume the current token if it is a valid identifier or raise an error (see :func:`check_identifier`)
If the current token is not a valid identifier, an error is raised
with the provided message
Args:
message (str, optional): the error message. Defaults to "Expected identifier".
Raises:
ParsingError: if the current token is not a valid identifier
Returns:
Token: the current token which is a valid identifier
"""
if not self.match_identifier(): if not self.match_identifier():
raise self.error(self.peek(), message) raise self.error(self.peek(), message)
return self.previous() return self.previous()
def match_identifier(self) -> bool: def match_identifier(self) -> bool:
"""Consume the next token if it is a valid identifier (see :func:`check_identifier`)
Returns:
bool: whether a token was matched and consumed
"""
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values()) return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
def check_identifier(self) -> bool: def check_identifier(self) -> bool:
"""Check whether the current token is a valid identifier
A valid identifier is either an identifier token or a keyword token.
This function always returns False if the parser is at the EOF token
Returns:
bool: True if the current token is a valid identifier and not EOF
"""
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]: for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
if self.check(tt): if self.check(tt):
return True return True
@@ -524,7 +682,14 @@ class MidasParser(Parser):
def member_stmt(self) -> MemberStmt: def member_stmt(self) -> MemberStmt:
"""Parse a member statement """Parse a member statement
A type member statement is written `prop name: Type` or `def name: Type` A member statement consists of:
- the `prop` (for a property) or `def` (for a method) keyword
- a name (identifier)
- a colon `:`
- a type expression (see :func:`type_expr`)
Raises:
ParsingError: if the first token is neither `prop` nor `def`
Returns: Returns:
MemberStmt: the parsed member statement MemberStmt: the parsed member statement
@@ -551,7 +716,13 @@ class MidasParser(Parser):
def extend_declaration(self) -> ExtendStmt: def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition """Parse an extension definition
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }` An extension statement consists of:
- the `extend` keyword
- a type name (identifier)
- (optional) type parameters (see :func:`type_params`)
- an opening brace `{`
- zero or more member statements (see :func:`member_stmt`)
- a closing brace `}`
Returns: Returns:
ExtendStmt: the parsed extension statement ExtendStmt: the parsed extension statement
@@ -576,7 +747,12 @@ class MidasParser(Parser):
def predicate_declaration(self) -> PredicateStmt: def predicate_declaration(self) -> PredicateStmt:
"""Parse a predicate declaration """Parse a predicate declaration
A predicate is written `predicate Name(subject: Type) = constraint_expression` A predicate statement consists of:
- the `predicate` keyword
- a name (identifier)
- (optional) zero or more parameter specs (see :func:`function_params`)
- an equal sign `=`
- a body, a constraint expression (see :func:`constraint`)
Returns: Returns:
PredicateStmt: the parsed predicate declaration statement PredicateStmt: the parsed predicate declaration statement
@@ -587,7 +763,7 @@ class MidasParser(Parser):
params: list[ParamSpec] = [] params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN): while self.check(TokenType.LEFT_PAREN):
params.append(self.function_args()) params.append(self.function_params())
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
body: Expr = self.constraint() body: Expr = self.constraint()
@@ -599,7 +775,18 @@ class MidasParser(Parser):
) )
def function(self) -> FunctionType: def function(self) -> FunctionType:
params: ParamSpec = self.function_args() """Parse a function type expression
A function type consists of:
- the `fn` keyword
- a parameter spec (see :func:`function_params`)
- the arrow keyword `->`
- a result type expression (see :func:`type_expr`)
Returns:
FunctionType: the parsed function type expression
"""
params: ParamSpec = self.function_params()
self.consume(TokenType.ARROW, "Expected '->' before result type") self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr() result: Type = self.type_expr()
@@ -610,36 +797,53 @@ class MidasParser(Parser):
returns=result, returns=result,
) )
def function_args(self) -> ParamSpec: def function_params(self) -> ParamSpec:
"""Parse a parameter spec
A parameter spec consists of zero or more comma-separated parameters,
wrapped in parentheses.
Like in Python, it can contain positional-only, mixed and keyword-only
parameters (separated by `/` and `*`).
Each parameter has a type (see :func:`type_expr`),
preceded by a name (identifier) and a colon `:` (not required for
positional-only parameters), and optionally followed by `?` to mark
the parameter as not required.
Returns:
ParamSpec: the parsed parameter spec
"""
l_paren: Token = self.consume( l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters" TokenType.LEFT_PAREN, "Expected '(' before function parameters"
) )
pos_args: list[FunctionType.Argument] = [] pos: list[FunctionType.Parameter] = []
args: list[FunctionType.Argument] = [] mixed: list[FunctionType.Parameter] = []
kw_args: list[FunctionType.Argument] = [] kw: list[FunctionType.Parameter] = []
args_first_tokens: list[Token] = [] mixed_first_tokens: list[Token] = []
section: int = 0 section: int = 0
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN): while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
match section: match section:
case 0 if self.match(TokenType.SLASH): case 0 if self.match(TokenType.SLASH):
pos_args = args pos = mixed
args = [] mixed = []
args_first_tokens = [] mixed_first_tokens = []
section = 1 section = 1
case 0 | 1 if self.match(TokenType.STAR): case 0 | 1 if self.match(TokenType.STAR):
section = 2 section = 2
case _: case _:
# Record first token of mixed argument for errors if unnamed # Record first token of mixed parameters for errors if unnamed
if section != 2: if section != 2:
args_first_tokens.append(self.peek()) mixed_first_tokens.append(self.peek())
name: Optional[Token] = None name: Optional[Token] = None
if section == 2: if section == 2:
name = self.consume_identifier("Expected keyword argument name") name = self.consume_identifier(
"Expected keyword parameter name"
)
self.consume( self.consume(
TokenType.COLON, "Expected ':' after argument name" TokenType.COLON, "Expected ':' after parameter name"
) )
elif self.check_identifier() and self.check_next(TokenType.COLON): elif self.check_identifier() and self.check_next(TokenType.COLON):
name = self.advance() name = self.advance()
@@ -647,24 +851,24 @@ class MidasParser(Parser):
type: Type = self.type_expr() type: Type = self.type_expr()
optional: bool = self.match(TokenType.QMARK) optional: bool = self.match(TokenType.QMARK)
arg = FunctionType.Argument( param = FunctionType.Parameter(
location=None, location=None,
name=name, name=name,
type=type, type=type,
required=not optional, required=not optional,
) )
if section == 2: if section == 2:
kw_args.append(arg) kw.append(param)
else: else:
args.append(arg) mixed.append(param)
if not self.match(TokenType.COMMA): if not self.match(TokenType.COMMA):
break break
for arg, token in zip(args, args_first_tokens): for param, token in zip(mixed, mixed_first_tokens):
if arg.name is None: if param.name is None:
# Not raised because we can keep parsing # Not raised because we can keep parsing
self.error(token, "Unnamed mixed argument") self.error(token, "Unnamed mixed parameter")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args) return ParamSpec(l_paren=l_paren, pos=pos, mixed=mixed, kw=kw)
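
The section handling in `function_params` can be illustrated with a minimal sketch. It works on a pre-split list of strings instead of parser tokens, and it does no error reporting; `split_sections` and its input format are hypothetical and not part of Midas.

```python
def split_sections(items: list[str]) -> dict[str, list[str]]:
    """Split a flat parameter list on `/` and `*`, as in a Python signature."""
    pos: list[str] = []
    mixed: list[str] = []
    kw: list[str] = []
    section = 0  # 0: before `/`, 1: after `/`, 2: after `*`
    for item in items:
        if item == "/" and section == 0:
            # Everything collected so far turns out to be positional-only
            pos, mixed = mixed, []
            section = 1
        elif item == "*" and section in (0, 1):
            section = 2
        elif section == 2:
            kw.append(item)
        else:
            mixed.append(item)
    return {"pos": pos, "mixed": mixed, "kw": kw}


spec = split_sections(["a: int", "/", "b: int", "*", "c: int"])
```

As in the parser above, parameters are first collected as mixed and only reclassified as positional-only once a `/` is seen.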


@@ -23,6 +23,7 @@ from midas.ast.python import (
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MidasType, MidasType,
ParamSpec,
RawExpr, RawExpr,
RawStmt, RawStmt,
ReturnStmt, ReturnStmt,
@@ -49,6 +50,8 @@ class UnsupportedSyntaxError(Exception):
class PythonParser: class PythonParser:
"""A parser to convert raw Python `ast` nodes into custom IR nodes"""
CAST_FUNCTION = "cast" CAST_FUNCTION = "cast"
UNSAFE_CAST_FUNCTION = "unsafe_cast" UNSAFE_CAST_FUNCTION = "unsafe_cast"
@@ -212,27 +215,10 @@ class PythonParser:
match node: match node:
case ast.FunctionDef( case ast.FunctionDef(
name=name, name=name,
args=ast.arguments(
posonlyargs=posonlyargs,
args=args, args=args,
vararg=sink,
kwonlyargs=kwonlyargs,
kwarg=kw_sink,
defaults=defaults,
kw_defaults=kw_defaults,
),
returns=returns, returns=returns,
body=raw_body, body=raw_body,
): ):
def parse_args(
args_list: list[ast.arg], defaults: list[Optional[Expr]]
) -> list[Function.Argument]:
return [
self._parse_function_argument(arg, default)
for arg, default in zip(args_list, defaults)
]
body: list[Stmt] = [] body: list[Stmt] = []
for stmt in raw_body: for stmt in raw_body:
stmts = self.parse_stmt(stmt) stmts = self.parse_stmt(stmt)
@@ -241,54 +227,58 @@ class PythonParser:
elif stmts is not None: elif stmts is not None:
body.extend(stmts) body.extend(stmts)
parsed_defaults: list[Optional[Expr]] = [
self.parse_expr(default) for default in defaults
]
n_posargs: int = len(posonlyargs)
n_args: int = len(args)
n_all_posargs = n_posargs + n_args
parsed_defaults = [
None,
] * (n_all_posargs - len(defaults)) + parsed_defaults
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
kwargs_defaults: list[Optional[Expr]] = [
self.parse_expr(default) if default is not None else None
for default in kw_defaults
]
return Function( return Function(
location=loc, location=loc,
name=name, name=name,
posonlyargs=parse_args(posonlyargs, posargs_defaults), params=self._parse_param_spec(args),
args=parse_args(args, args_defaults),
sink=(
self._parse_function_argument(sink, None)
if sink is not None
else None
),
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
kw_sink=(
self._parse_function_argument(kw_sink, None)
if kw_sink is not None
else None
),
returns=self._parse_type(returns) if returns is not None else None, returns=self._parse_type(returns) if returns is not None else None,
body=body, body=body,
) )
case _: case _:
print(f"Unsupported function definition: {ast.unparse(node)}") print(f"Unsupported function definition: {ast.unparse(node)}")
def _parse_function_argument( def _parse_param_spec(self, args: ast.arguments) -> ParamSpec:
def parse_params(
args_list: list[ast.arg], defaults: list[Optional[Expr]]
) -> list[Function.Parameter]:
return [
self._parse_function_parameter(arg, default)
for arg, default in zip(args_list, defaults)
]
defaults: list[ast.expr] = args.defaults
parsed_defaults: list[Optional[Expr]] = [
self.parse_expr(default) for default in defaults
]
n_pos: int = len(args.posonlyargs)
n_mixed: int = len(args.args)
n_all_pos = n_pos + n_mixed
parsed_defaults = [
None,
] * (n_all_pos - len(defaults)) + parsed_defaults
pos_defaults: list[Optional[Expr]] = parsed_defaults[:n_pos]
mixed_defaults: list[Optional[Expr]] = parsed_defaults[n_pos:]
kw_defaults: list[Optional[Expr]] = [
self.parse_expr(default) if default is not None else None
for default in args.kw_defaults
]
return ParamSpec(
pos=parse_params(args.posonlyargs, pos_defaults),
mixed=parse_params(args.args, mixed_defaults),
kw=parse_params(args.kwonlyargs, kw_defaults),
)
def _parse_function_parameter(
self, arg: ast.arg, default: Optional[Expr] self, arg: ast.arg, default: Optional[Expr]
) -> Function.Argument: ) -> Function.Parameter:
loc: Location = Location.from_ast(arg) loc: Location = Location.from_ast(arg)
name: str = arg.arg name: str = arg.arg
type: Optional[MidasType] = None type: Optional[MidasType] = None
if arg.annotation is not None: if arg.annotation is not None:
type = self._parse_type(arg.annotation) type = self._parse_type(arg.annotation)
return Function.Argument( return Function.Parameter(
location=loc, location=loc,
name=name, name=name,
type=type, type=type,
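
The padding step in `_parse_param_spec` relies on how the standard `ast` module stores defaults: `ast.arguments.defaults` only covers the trailing positional parameters, while `kw_defaults` has one entry (possibly `None`) per keyword-only parameter. A minimal sketch of that alignment, using only the standard library; `align_defaults` is a hypothetical helper, and like the `ParamSpec` built above it ignores `*args` and `**kwargs`:

```python
import ast
from typing import Optional


def align_defaults(args: ast.arguments) -> dict[str, list[tuple[str, Optional[str]]]]:
    n_pos = len(args.posonlyargs)
    n_all_pos = n_pos + len(args.args)
    # `defaults` is right-aligned: left-pad with None for parameters without default
    padded = [None] * (n_all_pos - len(args.defaults)) + list(args.defaults)

    def pair(params, defaults):
        return [
            (p.arg, ast.unparse(d) if d is not None else None)
            for p, d in zip(params, defaults)
        ]

    return {
        "pos": pair(args.posonlyargs, padded[:n_pos]),
        "mixed": pair(args.args, padded[n_pos:]),
        "kw": pair(args.kwonlyargs, args.kw_defaults),
    }


func = ast.parse("def f(a, b=1, /, c=2, *, d, e=3): ...").body[0]
aligned = align_defaults(func.args)
```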

43
tests/__main__.py Normal file

@@ -0,0 +1,43 @@
from typing import Type
from midas.cli.ansi import Ansi
from tests.base import Tester
from tests.checker import CheckerTester
from tests.generator import GeneratorTester
from tests.midas import MidasTester
from tests.python import PythonTester
def print_banner(name: str):
horizontal: str = "+" + "-" * (len(name) + 2) + "+"
print(horizontal)
print(f"| {name} |")
print(horizontal)
def run_tests(tester_cls: Type[Tester]) -> bool:
print_banner(tester_cls.__name__)
tester: Tester = tester_cls()
success: bool = tester.run_all_tests()
print()
return success
def main():
testers: list[Type[Tester]] = [
PythonTester,
MidasTester,
CheckerTester,
GeneratorTester,
]
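# Build the list first so every tester runs; all() over a lazy map() stops at the first failure
success: bool = all([run_tests(tester_cls) for tester_cls in testers])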
if success:
print(Ansi.FG(Ansi.BRIGHT_GREEN) + "All tests passed!" + Ansi.RESET)
else:
print(Ansi.FG(Ansi.BRIGHT_RED) + "Some tests failed!" + Ansi.RESET)
if __name__ == "__main__":
main()
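
One detail worth keeping in mind when combining per-suite results with `all`: it short-circuits, so over a lazy iterator such as `map` the remaining suites are never started after the first failure. A minimal sketch with hypothetical names (`run`, `suites`), not taken from the test module:

```python
calls: list[str] = []


def run(name: str, ok: bool) -> bool:
    """Record that a suite ran and return its (pre-set) result."""
    calls.append(name)
    return ok


suites = [("first", False), ("second", True)]

# Lazy: all() stops pulling from map() at the first False
lazy_ok = all(map(lambda s: run(*s), suites))
lazy_calls = list(calls)

calls.clear()

# Eager: the list is fully built before all() inspects it
eager_ok = all([run(*s) for s in suites])
eager_calls = list(calls)
```

Both variants report failure, but only the eager one runs the second suite.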


@@ -7,6 +7,8 @@ from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Iterator, Protocol from typing import Iterator, Protocol
from midas.cli.ansi import Ansi
class CaseResult(Protocol): class CaseResult(Protocol):
def dumps(self) -> str: ... def dumps(self) -> str: ...
@@ -44,8 +46,11 @@ class Tester(ABC):
print(rule) print(rule)
for i, test in enumerate(tests): for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}") path: Path = test.resolve().relative_to(self.CASES_DIR)
print(f"{Ansi.FG(Ansi.BRIGHT_CYAN)}Case {i+1}/{n}: {path}{Ansi.RESET}")
print(Ansi.DIM, end="")
success: bool = self._run_test(test) success: bool = self._run_test(test)
print(Ansi.RESET, end="")
if success: if success:
successes += 1 successes += 1
else: else:
@@ -146,7 +151,8 @@ class Tester(ABC):
if not success: if not success:
sys.exit(1) sys.exit(1)
case None: case None:
print("No subcommand provided. Available subcommands: run, update") success: bool = tester.run_all_tests()
if not success:
sys.exit(1) sys.exit(1)
case _: case _:
print(f"Unknown subcommand '{args.subcommand}'") print(f"Unknown subcommand '{args.subcommand}'")


@@ -124,7 +124,7 @@
22 22
] ]
}, },
"message": "Multiple values for argument 'b'" "message": "Multiple values for parameter 'b'"
}, },
{ {
"type": "Error", "type": "Error",
@@ -152,7 +152,7 @@
12 12
] ]
}, },
"message": "Unknown keyword argument 'a'" "message": "Unknown keyword parameter 'a'"
}, },
{ {
"type": "Error", "type": "Error",
@@ -194,7 +194,7 @@
17 17
] ]
}, },
"message": "Unknown keyword argument 'g'" "message": "Unknown keyword parameter 'g'"
}, },
{ {
"type": "Error", "type": "Error",
@@ -277,7 +277,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -287,7 +288,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -297,7 +298,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -306,7 +307,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -351,7 +353,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -361,7 +364,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -371,7 +374,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -380,7 +383,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -443,7 +447,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -453,7 +458,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -463,7 +468,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -472,7 +477,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -539,7 +545,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -549,7 +556,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -559,7 +566,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -568,7 +575,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -649,7 +657,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -659,7 +668,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -669,7 +678,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -678,7 +687,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -762,7 +772,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -772,7 +783,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -782,7 +793,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -791,7 +802,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -850,7 +862,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -860,7 +873,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -870,7 +883,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -879,7 +892,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -929,7 +943,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -939,7 +954,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -949,7 +964,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -958,7 +973,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -1034,7 +1050,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -1044,7 +1061,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -1054,7 +1071,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -1063,7 +1080,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -1150,7 +1168,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -1160,7 +1179,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -1170,7 +1189,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -1179,7 +1198,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -1266,7 +1286,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -1276,7 +1297,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -1286,7 +1307,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -1295,7 +1316,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -1382,7 +1404,8 @@
"name": "foo" "name": "foo"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -1392,7 +1415,7 @@
"required": true "required": true
} }
], ],
"args": [ "mixed": [
{ {
"pos": 1, "pos": 1,
"name": "b", "name": "b",
@@ -1402,7 +1425,7 @@
"required": true "required": true
} }
], ],
"kw_args": [ "kw": [
{ {
"pos": 2, "pos": 2,
"name": "c", "name": "c",
@@ -1411,7 +1434,8 @@
}, },
"required": true "required": true
} }
], ]
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
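
The snapshot changes above all follow one pattern: the flat keys `pos_args`, `args` and `kw_args` of a function type move into a nested `params` object with keys `pos`, `mixed` and `kw`. A minimal sketch of that reshaping for a single function type; `migrate_function_type` is a hypothetical helper, not part of the test suite, and it is deliberately not recursive, since generic types in these snapshots also use an `args` key for their type arguments:

```python
from typing import Any

# Old flat key -> new key inside "params"
OLD_TO_NEW = {"pos_args": "pos", "args": "mixed", "kw_args": "kw"}


def migrate_function_type(old: dict[str, Any]) -> dict[str, Any]:
    """Nest the three flat parameter lists under a single `params` key."""
    rest = {key: value for key, value in old.items() if key not in OLD_TO_NEW}
    params = {new_key: old[old_key] for old_key, new_key in OLD_TO_NEW.items()}
    return {"params": params, **rest}


old_type = {
    "pos_args": [],
    "args": [{"pos": 0, "name": "value", "required": True}],
    "kw_args": [],
    "returns": {"name": "float"},
}
new_type = migrate_function_type(old_type)
```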


@@ -136,8 +136,9 @@
"name": "maximum" "name": "maximum"
}, },
"type": { "type": {
"pos_args": [], "params": {
"args": [ "pos": [],
"mixed": [
{ {
"pos": 0, "pos": 0,
"name": "a", "name": "a",
@@ -155,7 +156,8 @@
"required": true "required": true
} }
], ],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "float" "name": "float"
} }


@@ -312,7 +312,8 @@
"name": "print" "name": "print"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "object", "name": "object",
@@ -320,8 +321,9 @@
"required": false "required": false
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": {} "returns": {}
} }
}, },


@@ -120,7 +120,8 @@
"name": "bool" "name": "bool"
}, },
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "object", "name": "object",
@@ -128,8 +129,9 @@
"required": false "required": false
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -377,8 +379,9 @@
"name": "double" "name": "double"
}, },
"type": { "type": {
"pos_args": [], "params": {
"args": [ "pos": [],
"mixed": [
{ {
"pos": 0, "pos": 0,
"name": "value", "name": "value",
@@ -388,7 +391,8 @@
"required": true "required": true
} }
], ],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "float" "name": "float"
} }
@@ -439,12 +443,14 @@
} }
], ],
"body": { "body": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "transform", "name": "transform",
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "v", "name": "v",
@@ -456,8 +462,9 @@
"required": true "required": true
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "U", "name": "U",
"bound": null, "bound": null,
@@ -485,8 +492,9 @@
"required": true "required": true
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "list", "name": "list",
"args": [ "args": [
@@ -548,8 +556,9 @@
"name": "double" "name": "double"
}, },
"type": { "type": {
"pos_args": [], "params": {
"args": [ "pos": [],
"mixed": [
{ {
"pos": 0, "pos": 0,
"name": "value", "name": "value",
@@ -559,7 +568,8 @@
"required": true "required": true
} }
], ],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "float" "name": "float"
} }
@@ -610,12 +620,14 @@
} }
], ],
"body": { "body": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "transform", "name": "transform",
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "v", "name": "v",
@@ -627,8 +639,9 @@
"required": true "required": true
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "U", "name": "U",
"bound": null, "bound": null,
@@ -656,8 +669,9 @@
"required": true "required": true
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "list", "name": "list",
"args": [ "args": [
@@ -709,8 +723,9 @@
"name": "is_odd" "name": "is_odd"
}, },
"type": { "type": {
"pos_args": [], "params": {
"args": [ "pos": [],
"mixed": [
{ {
"pos": 0, "pos": 0,
"name": "value", "name": "value",
@@ -720,7 +735,8 @@
"required": true "required": true
} }
], ],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "bool" "name": "bool"
} }
@@ -771,12 +787,14 @@
} }
], ],
"body": { "body": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "transform", "name": "transform",
"type": { "type": {
"pos_args": [ "params": {
"pos": [
{ {
"pos": 0, "pos": 0,
"name": "v", "name": "v",
@@ -788,8 +806,9 @@
"required": true "required": true
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "U", "name": "U",
"bound": null, "bound": null,
@@ -817,8 +836,9 @@
"required": true "required": true
} }
], ],
"args": [], "mixed": [],
"kw_args": [], "kw": []
},
"returns": { "returns": {
"name": "list", "name": "list",
"args": [ "args": [


@@ -0,0 +1,117 @@
# type: ignore
# ruff: disable [F821]
df1: Frame[a:int, b:float]
df2: Frame[a:int, b:float]
_: Any
# Arithmetic
_ = df1 + df2
_ = df1 - df2
_ = df1 * df2
_ = df1 / df2
_ = df1 // df2
_ = df1 % df2
_ = df1**df2
# Comparisons
_ = df1 < df2
_ = df1 > df2
_ = df1 <= df2
_ = df1 >= df2
_ = df1 != df2
_ = df1 == df2
# Aggregate
_ = df1.kurt()
_ = df1.kurtosis()
_ = df1.max()
_ = df1.mean()
_ = df1.median()
_ = df1.min()
_ = df1.mode()
_ = df1.prod()
_ = df1.product()
_ = df1.std()
_ = df1.sum()
_ = df1.var()
# Groupby
df_gb = df1.groupby(by="a")
_ = df_gb.kurt()
_ = df_gb.max()
_ = df_gb.mean()
_ = df_gb.median()
_ = df_gb.min()
_ = df_gb.prod()
_ = df_gb.std()
_ = df_gb.sum()
_ = df_gb.var()
# Columns
col1 = df1["a"]
col2 = df1["a"]
# Arithmetic
_ = col1 + col2
_ = col1 - col2
_ = col1 * col2
_ = col1 / col2
_ = col1 // col2
_ = col1 % col2
_ = col1**col2
# Comparisons
_ = col1 < col2
_ = col1 > col2
_ = col1 <= col2
_ = col1 >= col2
_ = col1 != col2
_ = col1 == col2
# Aggregate
_ = col1.kurt()
_ = col1.kurtosis()
_ = col1.max()
_ = col1.mean()
_ = col1.median()
_ = col1.min()
_ = col1.mode()
_ = col1.prod()
_ = col1.product()
_ = col1.std()
_ = col1.sum()
_ = col1.var()
# Groupby
col_gb = col1.groupby(level=0)
_ = col_gb.kurt()
_ = col_gb.max()
_ = col_gb.mean()
_ = col_gb.median()
_ = col_gb.min()
_ = col_gb.prod()
_ = col_gb.std()
_ = col_gb.sum()
_ = col_gb.var()
# Attributes
_ = df1.ndim # int
_ = df1.size # int
_ = df1.shape # (int, int)
_ = col1.ndim # int
_ = col1.size # int
_ = col1.shape # (int,)
_ = col1.T # Column[int]
# Misc
_ = df1.head()
_ = df1.tail()
_ = col1.head()
_ = col1.tail()

File diff suppressed because it is too large


@@ -7,8 +7,10 @@
{ {
"_type": "Function", "_type": "Function",
"name": "func", "name": "func",
"posonlyargs": [], "params": {
"args": [ "_type": "ParamSpec",
"pos": [],
"mixed": [
{ {
"name": "col1", "name": "col1",
"type": { "type": {
@@ -48,9 +50,8 @@
"default": null "default": null
} }
], ],
"sink": null, "kw": []
"kwonlyargs": [], },
"kw_sink": null,
"returns": { "returns": {
"_type": "BaseType", "_type": "BaseType",
"base": "Column", "base": "Column",
@@ -119,7 +120,9 @@
{ {
"_type": "Function", "_type": "Function",
"name": "func2", "name": "func2",
"posonlyargs": [ "params": {
"_type": "ParamSpec",
"pos": [
{ {
"name": "a", "name": "a",
"type": { "type": {
@@ -130,7 +133,7 @@
"default": null "default": null
} }
], ],
"args": [ "mixed": [
{ {
"name": "b", "name": "b",
"type": { "type": {
@@ -141,8 +144,7 @@
"default": null "default": null
} }
], ],
"sink": null, "kw": [
"kwonlyargs": [
{ {
"name": "c", "name": "c",
"type": { "type": {
@@ -152,8 +154,8 @@
}, },
"default": null "default": null
} }
], ]
"kw_sink": null, },
"returns": null, "returns": null,
"body": [] "body": []
} }


@@ -188,16 +188,16 @@ class MidasAstJsonSerializer(
def _serialize_param_spec(self, spec: ParamSpec) -> dict: def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return { return {
"_type": "ParamSpec", "_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos], "pos": [self._serialize_func_param(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed], "mixed": [self._serialize_func_param(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw], "kw": [self._serialize_func_param(arg) for arg in spec.kw],
} }
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict: def _serialize_func_param(self, param: FunctionType.Parameter) -> dict:
return { return {
"name": arg.name.lexeme if arg.name is not None else None, "name": param.name.lexeme if param.name is not None else None,
"type": arg.type.accept(self), "type": param.type.accept(self),
"required": arg.required, "required": param.required,
} }
def visit_extension_type(self, type: ExtensionType) -> dict: def visit_extension_type(self, type: ExtensionType) -> dict:
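
The shape produced by `_serialize_param_spec` can be reproduced with a small stand-alone sketch. The `Parameter` and `ParamSpec` dataclasses below are simplified, hypothetical stand-ins for the Midas AST classes (names are plain strings instead of tokens, types are plain strings instead of visitable nodes):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Parameter:  # hypothetical stand-in for FunctionType.Parameter
    name: Optional[str]
    type: str
    required: bool = True


@dataclass
class ParamSpec:  # hypothetical stand-in for the Midas ParamSpec
    pos: list[Parameter] = field(default_factory=list)
    mixed: list[Parameter] = field(default_factory=list)
    kw: list[Parameter] = field(default_factory=list)


def serialize_param(param: Parameter) -> dict:
    return {"name": param.name, "type": param.type, "required": param.required}


def serialize_param_spec(spec: ParamSpec) -> dict:
    return {
        "_type": "ParamSpec",
        "pos": [serialize_param(p) for p in spec.pos],
        "mixed": [serialize_param(p) for p in spec.mixed],
        "kw": [serialize_param(p) for p in spec.kw],
    }


serialized = serialize_param_spec(
    ParamSpec(pos=[Parameter(None, "int")], kw=[Parameter("c", "bool", required=False)])
)
```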


@@ -22,6 +22,7 @@ from midas.ast.python import (
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
MidasType, MidasType,
ParamSpec,
Pass, Pass,
RawExpr, RawExpr,
RawStmt, RawStmt,
@@ -128,32 +129,30 @@ class PythonAstJsonSerializer(
"expr": stmt.expr.accept(self), "expr": stmt.expr.accept(self),
} }
def _serialize_argument(self, arg: Function.Argument) -> dict:
return {
"name": arg.name,
"type": self._serialize_optional(arg.type),
"default": self._serialize_optional(arg.default),
}
def visit_function(self, stmt: Function) -> dict: def visit_function(self, stmt: Function) -> dict:
return { return {
"_type": "Function", "_type": "Function",
"name": stmt.name, "name": stmt.name,
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs], "params": self._serialize_param_spec(stmt.params),
"args": [self._serialize_argument(arg) for arg in stmt.args],
"sink": (
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
),
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
"kw_sink": (
self._serialize_argument(stmt.kw_sink)
if stmt.kw_sink is not None
else None
),
"returns": self._serialize_optional(stmt.returns), "returns": self._serialize_optional(stmt.returns),
"body": self._serialize_list(stmt.body), "body": self._serialize_list(stmt.body),
} }
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_param(arg) for arg in spec.pos],
"mixed": [self._serialize_func_param(arg) for arg in spec.mixed],
"kw": [self._serialize_func_param(arg) for arg in spec.kw],
}
def _serialize_func_param(self, param: Function.Parameter) -> dict:
return {
"name": param.name,
"type": self._serialize_optional(param.type),
"default": self._serialize_optional(param.default),
}
def visit_type_assign(self, stmt: TypeAssign) -> dict: def visit_type_assign(self, stmt: TypeAssign) -> dict:
return { return {
"_type": "TypeAssign", "_type": "TypeAssign",