Compare commits
104 Commits
v0.0.1-pro
...
bea3f399ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
bea3f399ad
|
|||
|
55060bfecd
|
|||
|
dd126f2559
|
|||
|
4151f5373d
|
|||
|
bd31713ab4
|
|||
|
f4dc57cb96
|
|||
|
261fd47494
|
|||
|
1b66a8553d
|
|||
|
65164abadb
|
|||
|
9d45163d9c
|
|||
|
ab0fa1de1a
|
|||
|
5d4df7978b
|
|||
|
86ad348b99
|
|||
|
29f691e38a
|
|||
|
f2c61d24e2
|
|||
|
112ed0e816
|
|||
|
7eb1e13b70
|
|||
|
893e1ba190
|
|||
|
1a1b0e8e15
|
|||
|
4ddde364ed
|
|||
|
4a3363a3d6
|
|||
|
0a3216e07d
|
|||
|
c29c0ed3ec
|
|||
|
fa7e56cb77
|
|||
|
13c19db818
|
|||
|
95b218fbed
|
|||
|
c3722c7438
|
|||
|
9dd547d6c1
|
|||
|
e2d5943517
|
|||
|
86e4763a12
|
|||
|
89ec63cb05
|
|||
|
e6375f1aa9
|
|||
|
d16e192a3a
|
|||
|
3f61f84e5a
|
|||
|
fd5399f50a
|
|||
|
8906ac3db8
|
|||
|
022aebf55b
|
|||
|
5dc6903425
|
|||
|
1b078b832c
|
|||
|
7515716864
|
|||
|
218b0c5b78
|
|||
|
928901ef9c
|
|||
|
4b62c78874
|
|||
|
f882eebaf5
|
|||
|
a872938405
|
|||
|
146be72fd7
|
|||
|
6de54e1da1
|
|||
|
c82b41a4df
|
|||
|
8304760fe0
|
|||
|
6bf91db757
|
|||
|
3f6b650a4b
|
|||
| ec079f32ca | |||
|
6524b3591a
|
|||
|
170101aa37
|
|||
|
0b3f33d7fe
|
|||
|
8a9b4f3989
|
|||
|
bbd0e3ae8d
|
|||
|
4d23e8840e
|
|||
|
c64d626d1c
|
|||
|
ecab1b74a4
|
|||
|
0bbdf04621
|
|||
|
939e5af4ce
|
|||
|
a735113466
|
|||
|
0e0a1b26f2
|
|||
|
e94db2181f
|
|||
|
9b59058881
|
|||
|
d0c54db33a
|
|||
|
5aedddfabb
|
|||
|
8d7c115432
|
|||
|
832c350b61
|
|||
|
3d599b3462
|
|||
|
4f799caaf5
|
|||
|
f4d2be3b1b
|
|||
|
7ce2840f03
|
|||
|
e2f3cabe15
|
|||
|
5a112332f2
|
|||
|
eb79cf6dc3
|
|||
|
8a9bb6ef4e
|
|||
|
6e0190a378
|
|||
| b5969e9a2b | |||
|
409d9f8fa6
|
|||
|
12d762429d
|
|||
|
53929ee514
|
|||
|
2f6e137f1a
|
|||
|
5224e79d9f
|
|||
|
bdcb12c58a
|
|||
|
5cb4d587e3
|
|||
|
8f9ec8d73b
|
|||
|
c1c50a448e
|
|||
|
19229db0b1
|
|||
|
f3b6bd146f
|
|||
|
98c3510bd4
|
|||
|
429d0d98fe
|
|||
|
db8fe5d3ff
|
|||
|
7477ec8d70
|
|||
|
adf7f4e7a2
|
|||
|
abf6787946
|
|||
|
e282b08597
|
|||
|
0a02b9d3d9
|
|||
| 875ca589e4 | |||
|
88f92d6e1f
|
|||
|
db4ed74365
|
|||
|
7cbf4fdece
|
|||
|
1fa9a09bfe
|
5
.gitignore
vendored
5
.gitignore
vendored
@@ -3,4 +3,7 @@ __pycache__
|
|||||||
.env
|
.env
|
||||||
venv
|
venv
|
||||||
.venv
|
.venv
|
||||||
*.pyc
|
*.pyc
|
||||||
|
uv.lock
|
||||||
|
.python-version
|
||||||
|
/out
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Generic, Optional, TypeVar
|
|
||||||
|
|
||||||
from lexer.token import Token
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Stmt(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
||||||
|
|
||||||
class Visitor(ABC, Generic[T]):
|
|
||||||
@abstractmethod
|
|
||||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AnnotationStmt(Stmt):
|
|
||||||
name: Token
|
|
||||||
schema: Optional[SchemaExpr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_annotation_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Expr(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
||||||
|
|
||||||
class Visitor(ABC, Generic[T]):
|
|
||||||
@abstractmethod
|
|
||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_schema_expr(self, expr: SchemaExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class WildcardExpr(Expr):
|
|
||||||
token: Token
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_wildcard_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class LiteralExpr(Expr):
|
|
||||||
value: Any
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_literal_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TypeExpr(Expr):
|
|
||||||
name: Token
|
|
||||||
constraints: list[ConstraintExpr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_type_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConstraintExpr(Expr):
|
|
||||||
left: Expr
|
|
||||||
op: Token
|
|
||||||
right: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_constraint_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SchemaExpr(Expr):
|
|
||||||
left: Token
|
|
||||||
elements: list[Expr]
|
|
||||||
right: Token
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_schema_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SchemaElementExpr(Expr):
|
|
||||||
name: Optional[Token]
|
|
||||||
type: Optional[Expr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_schema_element_expr(self)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Generic, Optional, TypeVar
|
|
||||||
|
|
||||||
from lexer.token import Token
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
# Statements
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Stmt(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
||||||
|
|
||||||
class Visitor(ABC, Generic[T]):
|
|
||||||
@abstractmethod
|
|
||||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> T: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TypeStmt(Stmt):
|
|
||||||
name: Token
|
|
||||||
bases: list[TypeExpr]
|
|
||||||
body: Optional[TypeBodyExpr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_type_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class PropertyStmt(Stmt):
|
|
||||||
name: Token
|
|
||||||
type: TypeExpr
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_property_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class OpStmt(Stmt):
|
|
||||||
left: TypeExpr
|
|
||||||
op: Token
|
|
||||||
right: TypeExpr
|
|
||||||
result: TypeExpr
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_op_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConstraintStmt(Stmt):
|
|
||||||
name: Token
|
|
||||||
constraint: ConstraintExpr
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_constraint_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
# Expressions
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Expr(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
||||||
|
|
||||||
class Visitor(ABC, Generic[T]):
|
|
||||||
@abstractmethod
|
|
||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_type_body_expr(self, expr: TypeBodyExpr) -> T: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class WildcardExpr(Expr):
|
|
||||||
token: Token
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_wildcard_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class LiteralExpr(Expr):
|
|
||||||
value: Any
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_literal_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TypeExpr(Expr):
|
|
||||||
name: Token
|
|
||||||
constraints: list[ConstraintExpr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_type_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConstraintExpr(Expr):
|
|
||||||
left: Expr
|
|
||||||
op: Token
|
|
||||||
right: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_constraint_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TypeBodyExpr(Expr):
|
|
||||||
properties: list[PropertyStmt]
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_type_body_expr(self)
|
|
||||||
@@ -1,360 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from enum import Enum, auto
|
|
||||||
import io
|
|
||||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
|
||||||
|
|
||||||
import core.ast.annotations as a
|
|
||||||
import core.ast.midas as m
|
|
||||||
|
|
||||||
|
|
||||||
class _Level(Enum):
|
|
||||||
EMPTY = auto()
|
|
||||||
ACTIVE = auto()
|
|
||||||
LAST = auto()
|
|
||||||
|
|
||||||
|
|
||||||
class Expr(Protocol):
|
|
||||||
def accept(self, printer: AstPrinter) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Expr)
|
|
||||||
|
|
||||||
|
|
||||||
class AstPrinter(Generic[T]):
|
|
||||||
LAST_CHILD = "└── "
|
|
||||||
CHILD = "├── "
|
|
||||||
VERTICAL = "│ "
|
|
||||||
EMPTY = " "
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._levels: list[_Level] = []
|
|
||||||
self._idx: Optional[int] = None
|
|
||||||
self._buf: io.StringIO = io.StringIO()
|
|
||||||
|
|
||||||
def print(self, expr: T):
|
|
||||||
self._buf = io.StringIO()
|
|
||||||
expr.accept(self)
|
|
||||||
return self._buf.getvalue()
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _child_level(self, last: bool = False) -> Generator[None, None, None]:
|
|
||||||
self._levels.append(_Level.LAST if last else _Level.ACTIVE)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self._levels.pop()
|
|
||||||
|
|
||||||
def _mark_last(self):
|
|
||||||
if self._levels:
|
|
||||||
self._levels[-1] = _Level.LAST
|
|
||||||
|
|
||||||
def _write_line(self, text: str, *, last: bool = False):
|
|
||||||
if last:
|
|
||||||
self._mark_last()
|
|
||||||
indent: str = self._build_indent()
|
|
||||||
if self._idx is not None:
|
|
||||||
text = f"[{self._idx}] {text}"
|
|
||||||
self._idx = None
|
|
||||||
self._buf.write(indent + text + "\n")
|
|
||||||
|
|
||||||
def _build_indent(self) -> str:
|
|
||||||
parts: list[str] = []
|
|
||||||
for level in self._levels[:-1]:
|
|
||||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
|
||||||
if self._levels:
|
|
||||||
if self._levels[-1] == _Level.LAST:
|
|
||||||
parts.append(self.LAST_CHILD)
|
|
||||||
self._levels[-1] = _Level.EMPTY
|
|
||||||
else:
|
|
||||||
parts.append(self.CHILD)
|
|
||||||
return "".join(parts)
|
|
||||||
|
|
||||||
def _write_optional_child(
|
|
||||||
self, label: str, child: Optional[T], *, last: bool = False
|
|
||||||
):
|
|
||||||
if last:
|
|
||||||
self._mark_last()
|
|
||||||
if child is None:
|
|
||||||
self._write_line(f"{label}: None")
|
|
||||||
else:
|
|
||||||
self._write_line(label)
|
|
||||||
with self._child_level(last=True):
|
|
||||||
child.accept(self)
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
|
|
||||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
|
|
||||||
self._write_line("AnnotationStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
|
||||||
self._write_optional_child("schema", stmt.schema, last=True)
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: a.TypeExpr):
|
|
||||||
self._write_line("TypeExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
|
||||||
self._write_line("constraints", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
for i, constraint in enumerate(expr.constraints):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(expr.constraints) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
constraint.accept(self)
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
|
|
||||||
self._write_line("ConstraintExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("left")
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
expr.left.accept(self)
|
|
||||||
|
|
||||||
self._write_line(f"operator: {expr.op.lexeme}")
|
|
||||||
|
|
||||||
self._write_line("right", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
expr.right.accept(self)
|
|
||||||
|
|
||||||
def visit_schema_expr(self, expr: a.SchemaExpr):
|
|
||||||
self._write_line("SchemaExpr")
|
|
||||||
with self._child_level():
|
|
||||||
for i, elmt in enumerate(expr.elements):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(expr.elements) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
elmt.accept(self)
|
|
||||||
|
|
||||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
|
|
||||||
self._write_line("SchemaElementExpr")
|
|
||||||
with self._child_level():
|
|
||||||
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
|
|
||||||
self._write_line(f"name: {name_text}")
|
|
||||||
self._write_optional_child("type", expr.type, last=True)
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
|
|
||||||
self._write_line("WildcardExpr")
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
|
|
||||||
self._write_line("LiteralExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f"value: {expr.value}", last=True)
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
|
|
||||||
def print(self, expr: a.Expr | a.Stmt):
|
|
||||||
return expr.accept(self)
|
|
||||||
|
|
||||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
|
|
||||||
schema: str = ""
|
|
||||||
if stmt.schema is not None:
|
|
||||||
schema = stmt.schema.accept(self)
|
|
||||||
return f"{stmt.name.lexeme}{schema}"
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: a.TypeExpr) -> str:
|
|
||||||
parts: list[str] = [expr.name.lexeme]
|
|
||||||
for constraint in expr.constraints:
|
|
||||||
parts.append("(" + constraint.accept(self) + ")")
|
|
||||||
return " + ".join(parts)
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
|
|
||||||
parts: list[str] = [
|
|
||||||
expr.left.accept(self),
|
|
||||||
expr.op.lexeme,
|
|
||||||
expr.right.accept(self),
|
|
||||||
]
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
|
|
||||||
res: str = expr.left.lexeme
|
|
||||||
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
|
|
||||||
res += expr.right.lexeme
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
|
|
||||||
parts: list[str] = []
|
|
||||||
if expr.name is not None:
|
|
||||||
parts.append(expr.name.lexeme)
|
|
||||||
|
|
||||||
if expr.type is None:
|
|
||||||
parts.append("_")
|
|
||||||
else:
|
|
||||||
parts.append(expr.type.accept(self))
|
|
||||||
return ": ".join(parts)
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
|
|
||||||
return "_"
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
|
|
||||||
return str(expr.value)
|
|
||||||
|
|
||||||
|
|
||||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
|
||||||
self._write_line("TypeStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
|
||||||
self._write_line("bases")
|
|
||||||
with self._child_level():
|
|
||||||
for i, base in enumerate(stmt.bases):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(stmt.bases) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
base.accept(self)
|
|
||||||
self._write_optional_child("body", stmt.body, last=True)
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
|
||||||
self._write_line("PropertyStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
|
||||||
self._write_line("type", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
stmt.type.accept(self)
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
|
||||||
self._write_line("OpStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("left")
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
stmt.left.accept(self)
|
|
||||||
|
|
||||||
self._write_line(f'op: "{stmt.op.lexeme}"')
|
|
||||||
|
|
||||||
self._write_line("right")
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
stmt.right.accept(self)
|
|
||||||
|
|
||||||
self._write_line("result", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
stmt.result.accept(self)
|
|
||||||
|
|
||||||
def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
|
|
||||||
self._write_line("ConstraintStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
|
||||||
self._write_line("constraint", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
stmt.constraint.accept(self)
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: m.TypeExpr):
|
|
||||||
self._write_line("TypeExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
|
||||||
self._write_line("constraints", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
for i, constraint in enumerate(expr.constraints):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(expr.constraints) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
constraint.accept(self)
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: m.ConstraintExpr):
|
|
||||||
self._write_line("ConstraintExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("left")
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
expr.left.accept(self)
|
|
||||||
|
|
||||||
self._write_line(f"operator: {expr.op.lexeme}")
|
|
||||||
|
|
||||||
self._write_line("right", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
expr.right.accept(self)
|
|
||||||
|
|
||||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
|
||||||
self._write_line("TypeBodyExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("properties", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
for i, property in enumerate(expr.properties):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(expr.properties) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
property.accept(self)
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
|
||||||
self._write_line("WildcardExpr")
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
|
||||||
self._write_line("LiteralExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f"value: {expr.value}", last=True)
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
|
||||||
def __init__(self, indent: int = 4):
|
|
||||||
self.indent: int = indent
|
|
||||||
self.level: int = 0
|
|
||||||
|
|
||||||
def indented(self, text: str) -> str:
|
|
||||||
return " " * (self.level * self.indent) + text
|
|
||||||
|
|
||||||
def print(self, expr: m.Expr | m.Stmt):
|
|
||||||
self.level = 0
|
|
||||||
return expr.accept(self)
|
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
|
||||||
bases: list[str] = [
|
|
||||||
b.accept(self)
|
|
||||||
for b in stmt.bases
|
|
||||||
]
|
|
||||||
|
|
||||||
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
|
|
||||||
if stmt.body is not None:
|
|
||||||
res += " {\n"
|
|
||||||
self.level += 1
|
|
||||||
res += stmt.body.accept(self)
|
|
||||||
self.level -= 1
|
|
||||||
res += "\n" + self.indented("}")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
|
||||||
return f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
|
||||||
left: str = stmt.left.accept(self)
|
|
||||||
op: str = stmt.op.lexeme
|
|
||||||
right: str = stmt.right.accept(self)
|
|
||||||
result: str = stmt.result.accept(self)
|
|
||||||
return self.indented(f"op <{left}> {op} <{right}> = <{result}>")
|
|
||||||
|
|
||||||
def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
|
|
||||||
name: str = stmt.name.lexeme
|
|
||||||
constraint: str = stmt.constraint.accept(self)
|
|
||||||
return self.indented(f"constraint {name} = {constraint}")
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: m.TypeExpr):
|
|
||||||
parts: list[str] = [expr.name.lexeme]
|
|
||||||
for constraint in expr.constraints:
|
|
||||||
parts.append("(" + constraint.accept(self) + ")")
|
|
||||||
return " + ".join(parts)
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: m.ConstraintExpr):
|
|
||||||
parts: list[str] = [
|
|
||||||
expr.left.accept(self),
|
|
||||||
expr.op.lexeme,
|
|
||||||
expr.right.accept(self),
|
|
||||||
]
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
|
||||||
properties: list[str] = [
|
|
||||||
self.indented(prop.accept(self))
|
|
||||||
for prop in expr.properties
|
|
||||||
]
|
|
||||||
return "\n".join(properties)
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
|
||||||
return "_"
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
|
||||||
return str(expr.value)
|
|
||||||
150
docs/architecture.typ
Normal file
150
docs/architecture.typ
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
#import "@preview/cetz:0.5.2": canvas, draw
|
||||||
|
|
||||||
|
#let diagram-only = false
|
||||||
|
|
||||||
|
#set document(
|
||||||
|
title: [Midas Architecture],
|
||||||
|
//author: "Louis Heredero",
|
||||||
|
)
|
||||||
|
|
||||||
|
#set text(
|
||||||
|
font: "Source Sans 3",
|
||||||
|
)
|
||||||
|
|
||||||
|
#let diagram = canvas({
|
||||||
|
let framed = draw.content.with(
|
||||||
|
padding: (x: .8em, y: 1em),
|
||||||
|
frame: "rect",
|
||||||
|
stroke: black,
|
||||||
|
)
|
||||||
|
let arrow = draw.line.with(mark: (end: ">", fill: black))
|
||||||
|
framed(
|
||||||
|
(0, 0),
|
||||||
|
name: "python-parser",
|
||||||
|
)[Python parser]
|
||||||
|
|
||||||
|
draw.content(
|
||||||
|
(rel: (0, 1), to: "python-parser.north"),
|
||||||
|
padding: 5pt,
|
||||||
|
anchor: "south",
|
||||||
|
name: "source-py",
|
||||||
|
)[_`source.py`_]
|
||||||
|
arrow("source-py", "python-parser")
|
||||||
|
|
||||||
|
framed(
|
||||||
|
(rel: (3, 0), to: "python-parser.east"),
|
||||||
|
anchor: "west",
|
||||||
|
name: "custom-parser",
|
||||||
|
align(center)[Custom python\ parser],
|
||||||
|
)
|
||||||
|
|
||||||
|
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
|
||||||
|
draw.content(
|
||||||
|
"arrow-python-ast",
|
||||||
|
anchor: "south",
|
||||||
|
padding: 5pt,
|
||||||
|
)[`ast.Module`]
|
||||||
|
|
||||||
|
framed(
|
||||||
|
(rel: (-3, -2), to: "custom-parser.south"),
|
||||||
|
anchor: "east",
|
||||||
|
name: "python-resolver",
|
||||||
|
)[Python Resolver]
|
||||||
|
arrow(
|
||||||
|
"custom-parser",
|
||||||
|
((), "|-", "python-resolver.east"),
|
||||||
|
"python-resolver",
|
||||||
|
name: "arrow-python-custom-ast",
|
||||||
|
)
|
||||||
|
draw.content(
|
||||||
|
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
|
||||||
|
padding: 5pt,
|
||||||
|
anchor: "south",
|
||||||
|
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
|
||||||
|
draw.content(
|
||||||
|
"python-resolver.west",
|
||||||
|
padding: 5pt,
|
||||||
|
anchor: "south-east",
|
||||||
|
)[Resolved P-AST@fn-past]
|
||||||
|
|
||||||
|
draw.circle(
|
||||||
|
(rel: (1, -2), to: "custom-parser.south-east"),
|
||||||
|
radius: .4,
|
||||||
|
name: "midas-loader",
|
||||||
|
)
|
||||||
|
arrow(
|
||||||
|
"custom-parser",
|
||||||
|
"midas-loader",
|
||||||
|
name: "arrow-load-midas",
|
||||||
|
mark: (end: (symbol: ">", fill: black), start: "o"),
|
||||||
|
)
|
||||||
|
draw.content(
|
||||||
|
"arrow-load-midas",
|
||||||
|
anchor: "west",
|
||||||
|
padding: 5pt,
|
||||||
|
)[```python midas.using("types.midas")```]
|
||||||
|
|
||||||
|
framed(
|
||||||
|
(rel: (0, -2), to: "midas-loader.south"),
|
||||||
|
name: "midas-parser",
|
||||||
|
)[Midas lexer/parser]
|
||||||
|
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
|
||||||
|
draw.content(
|
||||||
|
"arrow-midas-source",
|
||||||
|
anchor: "west",
|
||||||
|
padding: 5pt,
|
||||||
|
)[_`types.midas`_]
|
||||||
|
|
||||||
|
|
||||||
|
framed(
|
||||||
|
(rel: (-2, 0), to: "midas-parser.west"),
|
||||||
|
anchor: "east",
|
||||||
|
name: "midas-resolver",
|
||||||
|
)[Midas Resolver]
|
||||||
|
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
|
||||||
|
draw.content(
|
||||||
|
"arrow-midas-ast",
|
||||||
|
anchor: "south",
|
||||||
|
padding: 5pt,
|
||||||
|
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
|
||||||
|
|
||||||
|
framed(
|
||||||
|
(rel: (-3, 0), to: "midas-resolver.west"),
|
||||||
|
anchor: "east",
|
||||||
|
name: "checker",
|
||||||
|
)[Checker]
|
||||||
|
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
|
||||||
|
arrow(
|
||||||
|
"python-resolver",
|
||||||
|
((), "-|", "checker.north"),
|
||||||
|
"checker",
|
||||||
|
)
|
||||||
|
draw.content(
|
||||||
|
"arrow-type-ctx",
|
||||||
|
anchor: "south",
|
||||||
|
padding: 5pt,
|
||||||
|
)[Types context]
|
||||||
|
})
|
||||||
|
|
||||||
|
#show: doc => if diagram-only {
|
||||||
|
set page(width: auto, height: auto, margin: .5cm)
|
||||||
|
diagram
|
||||||
|
} else { doc }
|
||||||
|
|
||||||
|
#align(center, title())
|
||||||
|
|
||||||
|
#v(1cm)
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
diagram,
|
||||||
|
caption: [Midas type-checker architecture],
|
||||||
|
)
|
||||||
|
|
||||||
|
== Components
|
||||||
|
|
||||||
|
- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```)
|
||||||
|
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
|
||||||
|
- *Python resolver*: resolves bindings and references, tracks binding scopes
|
||||||
|
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
|
||||||
|
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
|
||||||
|
- *Checker*: evaluates expressions and checks type coherence
|
||||||
@@ -21,7 +21,7 @@ lat + lon # Invalid operation
|
|||||||
# Registered operations are permitted
|
# Registered operations are permitted
|
||||||
lat1: Latitude = lat[0]
|
lat1: Latitude = lat[0]
|
||||||
lat2: Latitude = lat[1]
|
lat2: Latitude = lat[1]
|
||||||
lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation
|
lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation
|
||||||
|
|
||||||
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
|
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
|
||||||
df2: Frame[
|
df2: Frame[
|
||||||
|
|||||||
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
// Simple custom type derived from float
|
||||||
|
type Custom(float)
|
||||||
|
|
||||||
|
// Simple custom types with constraints
|
||||||
|
type Latitude(float) where (-90 <= _ <= 90)
|
||||||
|
type Longitude(float) where (-180 <= _ <= 180)
|
||||||
|
|
||||||
|
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||||
|
type Difference[T](T)
|
||||||
|
|
||||||
|
// Complex custom type, containing two values accessible through properties
|
||||||
|
type GeoLocation {
|
||||||
|
lat: Latitude
|
||||||
|
lon: Longitude
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define operations on our custom type
|
||||||
|
extend GeoLocation {
|
||||||
|
// This type is compatible with the `-` operation with another GeoLocation
|
||||||
|
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||||
|
// in a Difference of GeoLocations
|
||||||
|
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||||
|
}
|
||||||
|
|
||||||
|
// For complex generics, you need to specify how the genericity the properties
|
||||||
|
// are handled
|
||||||
|
type Difference[GeoLocation] {
|
||||||
|
lat: Difference[Latitude]
|
||||||
|
lon: Difference[Longitude]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple operation defined on our custom types
|
||||||
|
extend Latitude {
|
||||||
|
op __sub__(Latitude) -> Difference[Latitude]
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Longitude {
|
||||||
|
op __sub__(Longitude) -> Difference[Longitude]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Predefined custom predicates that can be referenced in other definitions
|
||||||
|
predicate Positive(v: float) = v >= 0
|
||||||
|
predicate StrictlyPositive(v: float) = v > 0
|
||||||
|
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||||
|
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||||
|
|
||||||
|
type Person {
|
||||||
|
name: str
|
||||||
|
|
||||||
|
// Property with an inline constraint
|
||||||
|
age: int? where (0 <= _ < 150)
|
||||||
|
|
||||||
|
// Property referencing a predicate
|
||||||
|
height: float where StrictlyPositive
|
||||||
|
|
||||||
|
home: GeoLocation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom complex type derived from another complex type, with a constraint
|
||||||
|
// on a property
|
||||||
|
// Multiple proposed syntaxes, not yet defined
|
||||||
|
|
||||||
|
// Explicit, but new keyword
|
||||||
|
type EquatorialPerson refines Person where Equatorial(_.home)
|
||||||
|
|
||||||
|
// Explicit with existing keyword, might be confusing if expectations regarding 'is'
|
||||||
|
type EquatorialPerson is Person where Equatorial(_.home)
|
||||||
|
|
||||||
|
// Consistent and Python-friendly but can be confused with structural extension
|
||||||
|
type EquatorialPerson(Person) where Equatorial(_.home)
|
||||||
|
|
||||||
|
// Allow new properties, probably not useful
|
||||||
|
type EquatorialPerson extends Person where Equatorial(_.home)
|
||||||
15
examples/00_syntax_prototype/04_functions.py
Normal file
15
examples/00_syntax_prototype/04_functions.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821]
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def func(
|
||||||
|
col1: Column[float + (0 <= _ <= 1)],
|
||||||
|
col2: Column[float + (0 <= _ <= 1)],
|
||||||
|
) -> Column[float + (0 <= _ <= 2)]:
|
||||||
|
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def func2(a: int, /, b: float, *, c: str):
|
||||||
|
pass
|
||||||
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
a: int = 3
|
||||||
|
b: int = 4
|
||||||
|
|
||||||
|
c = a + b # -> int
|
||||||
|
|
||||||
|
c = "invalid" # -> can't assign str to int variable
|
||||||
|
|
||||||
|
d = True
|
||||||
|
e = d + d
|
||||||
|
|
||||||
|
f: float = a
|
||||||
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
type Meter(float)
|
||||||
|
type Second(float)
|
||||||
|
type MeterPerSecond(float)
|
||||||
|
|
||||||
|
extend Meter {
|
||||||
|
op __add__(Meter) -> Meter
|
||||||
|
op __sub__(Meter) -> Meter
|
||||||
|
op __truediv__(Second) -> MeterPerSecond
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Second {
|
||||||
|
op __add__(Second) -> Second
|
||||||
|
op __sub__(Second) -> Second
|
||||||
|
}
|
||||||
8
examples/01_simple_type_checking/02_simple_types.py
Normal file
8
examples/01_simple_type_checking/02_simple_types.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable [F821]
|
||||||
|
|
||||||
|
midas.using("02_simple_types.midas")
|
||||||
|
|
||||||
|
distance: Meter = cast(Meter, 123.45)
|
||||||
|
time: Second = cast(Second, 6.7)
|
||||||
|
speed = distance / time
|
||||||
16
examples/01_simple_type_checking/03_control_flow.py
Normal file
16
examples/01_simple_type_checking/03_control_flow.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
def minimum(x: int, y: int):
|
||||||
|
if x < y:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return y
|
||||||
|
|
||||||
|
a = 15
|
||||||
|
b = 72
|
||||||
|
c = minimum(a, b)
|
||||||
|
|
||||||
|
def factorial(n: int) -> int:
|
||||||
|
if n <= 1:
|
||||||
|
return 1
|
||||||
|
return n * factorial(n - 1)
|
||||||
|
|
||||||
|
category = "Category 1" if a < 10 else "Category 2"
|
||||||
146
gen/gen.py
Normal file
146
gen/gen.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
HEADER = '''"""
|
||||||
|
This file was generated by a script. Any manual changes might be overwritten.
|
||||||
|
Please modify {defs_path} instead and run {gen_path}
|
||||||
|
"""'''
|
||||||
|
|
||||||
|
SECTION_TEMPLATE = """{banner}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class {base}(ABC):
|
||||||
|
location: Location
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
{visitor_methods}
|
||||||
|
|
||||||
|
|
||||||
|
{classes}"""
|
||||||
|
|
||||||
|
TEMPLATE = """{header}
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
{imports}
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
{sections}
|
||||||
|
"""
|
||||||
|
|
||||||
|
VISITOR_METHOD_TEMPLATE = """
|
||||||
|
@abstractmethod
|
||||||
|
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
CLASS_TEMPLATE = """
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class {cls}({base}):
|
||||||
|
{body}
|
||||||
|
|
||||||
|
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_{func_name}(self)
|
||||||
|
"""
|
||||||
|
|
||||||
|
SECTION_REGEX = re.compile(
|
||||||
|
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
|
||||||
|
re.MULTILINE | re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
IMPORTS_REGEX = re.compile(
|
||||||
|
r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
|
||||||
|
re.MULTILINE | re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def snake_case(text: str) -> str:
|
||||||
|
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||||
|
|
||||||
|
|
||||||
|
def make_visitor_method(cls: str, param: str):
|
||||||
|
method: str = VISITOR_METHOD_TEMPLATE.format(
|
||||||
|
func_name=snake_case(cls), param=param, cls=cls
|
||||||
|
)
|
||||||
|
return method.strip("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def make_class(name: str, cls: str, base: str):
|
||||||
|
body: str = cls.split("\n", 1)[1]
|
||||||
|
func_name: str = snake_case(name)
|
||||||
|
cls_def: str = CLASS_TEMPLATE.format(
|
||||||
|
cls=name,
|
||||||
|
base=base,
|
||||||
|
body=body,
|
||||||
|
func_name=func_name,
|
||||||
|
)
|
||||||
|
return cls_def.strip("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def make_banner(text: str) -> str:
|
||||||
|
middle: str = f"# {text} #"
|
||||||
|
rule: str = "#" * len(middle)
|
||||||
|
return "\n".join((rule, middle, rule))
|
||||||
|
|
||||||
|
|
||||||
|
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||||
|
visitor_methods: list[str] = []
|
||||||
|
classes: list[str] = []
|
||||||
|
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||||
|
for cls in definitions:
|
||||||
|
cls = cls.strip("\n")
|
||||||
|
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||||
|
print(f"Processing {name}")
|
||||||
|
visitor_methods.append(make_visitor_method(name, param))
|
||||||
|
classes.append(make_class(name, cls, base))
|
||||||
|
|
||||||
|
return SECTION_TEMPLATE.format(
|
||||||
|
banner=make_banner(full_name),
|
||||||
|
base=base,
|
||||||
|
visitor_methods="\n\n".join(visitor_methods),
|
||||||
|
classes="\n\n\n".join(classes),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(definitions_path: Path, out_path: Path):
|
||||||
|
root_dir: Path = Path(__file__).parent.parent
|
||||||
|
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||||
|
src: str = definitions_path.read_text()
|
||||||
|
sections: list[str] = []
|
||||||
|
|
||||||
|
imports: str = ""
|
||||||
|
if m := IMPORTS_REGEX.search(src):
|
||||||
|
imports = m.group("body").strip("\n")
|
||||||
|
|
||||||
|
for section_m in SECTION_REGEX.finditer(src):
|
||||||
|
full_name: str = section_m.group("name")
|
||||||
|
base: str = section_m.group("base")
|
||||||
|
param: str = section_m.group("param") or base.lower()
|
||||||
|
body: str = section_m.group("body")
|
||||||
|
sections.append(make_section(full_name, base, param, body))
|
||||||
|
|
||||||
|
result: str = TEMPLATE.format(
|
||||||
|
header=HEADER.format(
|
||||||
|
defs_path=rel_path,
|
||||||
|
gen_path=Path(__file__).relative_to(root_dir),
|
||||||
|
),
|
||||||
|
imports=imports,
|
||||||
|
sections="\n\n\n".join(sections),
|
||||||
|
)
|
||||||
|
out_path.write_text(result)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
root: Path = Path(__file__).parent.parent
|
||||||
|
defs_dir: Path = root / "gen"
|
||||||
|
ast_dir: Path = root / "midas" / "ast"
|
||||||
|
generate(defs_dir / "midas.py", ast_dir / "midas.py")
|
||||||
|
generate(defs_dir / "python.py", ast_dir / "python.py")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
110
gen/midas.py
Normal file
110
gen/midas.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821, F401]
|
||||||
|
|
||||||
|
###> Imports
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Stmt | Statements
|
||||||
|
class SimpleTypeStmt:
|
||||||
|
name: Token
|
||||||
|
template: Optional[TemplateExpr]
|
||||||
|
base: TypeExpr
|
||||||
|
constraint: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class ComplexTypeStmt:
|
||||||
|
name: Token
|
||||||
|
template: Optional[TemplateExpr]
|
||||||
|
properties: list[PropertyStmt]
|
||||||
|
|
||||||
|
|
||||||
|
class PropertyStmt:
|
||||||
|
name: Token
|
||||||
|
type: TypeExpr
|
||||||
|
constraint: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class ExtendStmt:
|
||||||
|
type: TypeExpr
|
||||||
|
operations: list[OpStmt]
|
||||||
|
|
||||||
|
|
||||||
|
class OpStmt:
|
||||||
|
name: Token
|
||||||
|
operand: TypeExpr
|
||||||
|
result: TypeExpr
|
||||||
|
|
||||||
|
|
||||||
|
class PredicateStmt:
|
||||||
|
name: Token
|
||||||
|
subject: Token
|
||||||
|
type: TypeExpr
|
||||||
|
condition: Expr
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Expr | Expressions
|
||||||
|
class SimpleTypeExpr:
|
||||||
|
name: Token
|
||||||
|
optional: bool
|
||||||
|
|
||||||
|
|
||||||
|
class LogicalExpr:
|
||||||
|
left: Expr
|
||||||
|
operator: Token
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryExpr:
|
||||||
|
left: Expr
|
||||||
|
operator: Token
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class UnaryExpr:
|
||||||
|
operator: Token
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class GetExpr:
|
||||||
|
expr: Expr
|
||||||
|
name: Token
|
||||||
|
|
||||||
|
|
||||||
|
class VariableExpr:
|
||||||
|
name: Token
|
||||||
|
|
||||||
|
|
||||||
|
class GroupingExpr:
|
||||||
|
expr: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralExpr:
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
|
||||||
|
class WildcardExpr:
|
||||||
|
token: Token
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateExpr:
|
||||||
|
type: TypeExpr
|
||||||
|
|
||||||
|
|
||||||
|
class TypeExpr:
|
||||||
|
name: Token
|
||||||
|
template: Optional[TemplateExpr]
|
||||||
|
optional: bool
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
148
gen/python.py
Normal file
148
gen/python.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821, F401]
|
||||||
|
|
||||||
|
###> Imports
|
||||||
|
import ast
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> MidasType | Type annotations | node
|
||||||
|
class BaseType:
|
||||||
|
base: str
|
||||||
|
param: Optional[MidasType]
|
||||||
|
|
||||||
|
|
||||||
|
class ConstraintType:
|
||||||
|
type: MidasType
|
||||||
|
constraint: ast.expr
|
||||||
|
|
||||||
|
|
||||||
|
class FrameColumn:
|
||||||
|
name: Optional[str]
|
||||||
|
type: Optional[MidasType]
|
||||||
|
|
||||||
|
|
||||||
|
class FrameType:
|
||||||
|
columns: list[FrameColumn]
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Stmt | Statements
|
||||||
|
class ExpressionStmt:
|
||||||
|
expr: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class Function:
|
||||||
|
name: str
|
||||||
|
posonlyargs: list[Argument]
|
||||||
|
args: list[Argument]
|
||||||
|
sink: Optional[Argument]
|
||||||
|
kwonlyargs: list[Argument]
|
||||||
|
kw_sink: Optional[Argument]
|
||||||
|
returns: Optional[MidasType]
|
||||||
|
body: list[Stmt]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: str
|
||||||
|
type: Optional[MidasType]
|
||||||
|
default: Optional[Expr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_args(self) -> list[Argument]:
|
||||||
|
return self.posonlyargs + self.args + self.kwonlyargs
|
||||||
|
|
||||||
|
|
||||||
|
class TypeAssign:
|
||||||
|
name: str
|
||||||
|
type: MidasType
|
||||||
|
|
||||||
|
|
||||||
|
class AssignStmt:
|
||||||
|
targets: list[Expr]
|
||||||
|
value: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnStmt:
|
||||||
|
value: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class IfStmt:
|
||||||
|
test: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
orelse: list[Stmt]
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Expr | Expressions
|
||||||
|
class BinaryExpr:
|
||||||
|
left: Expr
|
||||||
|
operator: ast.operator
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class CompareExpr:
|
||||||
|
left: Expr
|
||||||
|
operator: ast.cmpop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class UnaryExpr:
|
||||||
|
operator: ast.unaryop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class CallExpr:
|
||||||
|
callee: Expr
|
||||||
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class GetExpr:
|
||||||
|
object: Expr
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralExpr:
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
|
||||||
|
class VariableExpr:
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class LogicalExpr:
|
||||||
|
left: Expr
|
||||||
|
operator: ast.boolop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class SetExpr:
|
||||||
|
object: Expr
|
||||||
|
name: str
|
||||||
|
value: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class CastExpr:
|
||||||
|
type: MidasType
|
||||||
|
expr: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class TernaryExpr:
|
||||||
|
test: Expr
|
||||||
|
if_true: Expr
|
||||||
|
if_false: Expr
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
from lexer.base import Lexer
|
|
||||||
from lexer.keyword import ANNOTATION_KEYWORDS
|
|
||||||
from lexer.token import TokenType
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationLexer(Lexer):
|
|
||||||
def scan_token(self) -> None:
|
|
||||||
char: str = self.advance()
|
|
||||||
match char:
|
|
||||||
case "(":
|
|
||||||
self.add_token(TokenType.LEFT_PAREN)
|
|
||||||
case ")":
|
|
||||||
self.add_token(TokenType.RIGHT_PAREN)
|
|
||||||
case "[":
|
|
||||||
self.add_token(TokenType.LEFT_BRACKET)
|
|
||||||
case "]":
|
|
||||||
self.add_token(TokenType.RIGHT_BRACKET)
|
|
||||||
case "<":
|
|
||||||
self.add_token(
|
|
||||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
|
||||||
)
|
|
||||||
case ">":
|
|
||||||
self.add_token(
|
|
||||||
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
|
|
||||||
)
|
|
||||||
case "=":
|
|
||||||
self.add_token(
|
|
||||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
|
||||||
)
|
|
||||||
case "!":
|
|
||||||
if self.match("="):
|
|
||||||
self.add_token(TokenType.BANG_EQUAL)
|
|
||||||
else:
|
|
||||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
|
||||||
case ":":
|
|
||||||
self.add_token(TokenType.COLON)
|
|
||||||
case ",":
|
|
||||||
self.add_token(TokenType.COMMA)
|
|
||||||
case "_":
|
|
||||||
self.add_token(TokenType.UNDERSCORE)
|
|
||||||
case "+":
|
|
||||||
self.add_token(TokenType.PLUS)
|
|
||||||
case "#":
|
|
||||||
self.scan_comment()
|
|
||||||
case "\n":
|
|
||||||
self.add_token(TokenType.NEWLINE)
|
|
||||||
case " " | "\r" | "\t":
|
|
||||||
# Consume all whitespace characters until EOL or EOF
|
|
||||||
while (
|
|
||||||
self.peek().isspace()
|
|
||||||
and self.peek() != "\n"
|
|
||||||
and not self.is_at_end()
|
|
||||||
):
|
|
||||||
self.advance()
|
|
||||||
self.add_token(TokenType.WHITESPACE)
|
|
||||||
case _:
|
|
||||||
if char.isdigit():
|
|
||||||
self.scan_number()
|
|
||||||
elif char.isalpha():
|
|
||||||
self.scan_identifier()
|
|
||||||
else:
|
|
||||||
self.error("Unexpected character")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def scan_number(self):
|
|
||||||
"""Scan the rest of number and add it as a token
|
|
||||||
|
|
||||||
This method handles both simple integers and floats. Scientific notation
|
|
||||||
and base prefixes (0x, 0b, 0o) are not supported
|
|
||||||
"""
|
|
||||||
while self.peek().isdigit():
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
if self.peek() == "." and self.peek_next().isdigit():
|
|
||||||
self.advance()
|
|
||||||
while self.peek().isdigit():
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
value: float = float(self.source[self.start : self.idx])
|
|
||||||
self.add_token(TokenType.NUMBER, value)
|
|
||||||
|
|
||||||
def scan_identifier(self):
|
|
||||||
"""Scan the rest of an identifier and add it as a token
|
|
||||||
|
|
||||||
An identifier starts with a letter, followed by any number of
|
|
||||||
alphanumerical characters or underscores
|
|
||||||
"""
|
|
||||||
while self.peek().isalnum() or self.peek() == "_":
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
lexeme: str = self.source[self.start : self.idx]
|
|
||||||
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
|
||||||
self.add_token(token_type)
|
|
||||||
|
|
||||||
def scan_comment(self):
|
|
||||||
"""Scan the rest of a comment and add it as a token
|
|
||||||
|
|
||||||
A comment starts with a `#` character and ends at the EOL/EOF
|
|
||||||
"""
|
|
||||||
while self.peek() != "\n" and not self.is_at_end():
|
|
||||||
self.advance()
|
|
||||||
self.add_token(TokenType.COMMENT)
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
from lexer.token import TokenType
|
|
||||||
|
|
||||||
ANNOTATION_KEYWORDS: dict[str, TokenType] = {
|
|
||||||
"True": TokenType.TRUE,
|
|
||||||
"False": TokenType.FALSE,
|
|
||||||
"None": TokenType.NONE,
|
|
||||||
}
|
|
||||||
|
|
||||||
MIDAS_KEYWORDS: dict[str, TokenType] = {
|
|
||||||
"type": TokenType.TYPE,
|
|
||||||
"op": TokenType.OP,
|
|
||||||
"constraint": TokenType.CONSTRAINT,
|
|
||||||
"true": TokenType.TRUE,
|
|
||||||
"false": TokenType.FALSE,
|
|
||||||
"none": TokenType.NONE,
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum, auto
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lexer.position import Position
|
|
||||||
|
|
||||||
|
|
||||||
class TokenType(Enum):
|
|
||||||
# Punctuation
|
|
||||||
LEFT_PAREN = auto()
|
|
||||||
RIGHT_PAREN = auto()
|
|
||||||
LEFT_BRACKET = auto()
|
|
||||||
RIGHT_BRACKET = auto()
|
|
||||||
LEFT_BRACE = auto()
|
|
||||||
RIGHT_BRACE = auto()
|
|
||||||
COLON = auto()
|
|
||||||
COMMA = auto()
|
|
||||||
UNDERSCORE = auto()
|
|
||||||
|
|
||||||
# Operators
|
|
||||||
PLUS = auto()
|
|
||||||
MINUS = auto()
|
|
||||||
STAR = auto()
|
|
||||||
SLASH = auto()
|
|
||||||
GREATER = auto()
|
|
||||||
GREATER_EQUAL = auto()
|
|
||||||
LESS = auto()
|
|
||||||
LESS_EQUAL = auto()
|
|
||||||
EQUAL = auto()
|
|
||||||
EQUAL_EQUAL = auto()
|
|
||||||
BANG_EQUAL = auto()
|
|
||||||
|
|
||||||
# Literals
|
|
||||||
IDENTIFIER = auto()
|
|
||||||
NUMBER = auto()
|
|
||||||
TRUE = auto()
|
|
||||||
FALSE = auto()
|
|
||||||
NONE = auto()
|
|
||||||
|
|
||||||
# Keywords
|
|
||||||
TYPE = auto()
|
|
||||||
OP = auto()
|
|
||||||
CONSTRAINT = auto()
|
|
||||||
|
|
||||||
# Misc
|
|
||||||
COMMENT = auto()
|
|
||||||
WHITESPACE = auto()
|
|
||||||
EOF = auto()
|
|
||||||
NEWLINE = auto()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Token:
|
|
||||||
"""A scanned token"""
|
|
||||||
|
|
||||||
type: TokenType
|
|
||||||
lexeme: str
|
|
||||||
value: Any
|
|
||||||
position: Position
|
|
||||||
37
midas/ast/location.py
Normal file
37
midas/ast/location.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class HasLocation(Protocol):
|
||||||
|
lineno: int
|
||||||
|
col_offset: int
|
||||||
|
end_lineno: Optional[int]
|
||||||
|
end_col_offset: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Location:
|
||||||
|
lineno: int
|
||||||
|
col_offset: int
|
||||||
|
end_lineno: Optional[int]
|
||||||
|
end_col_offset: Optional[int]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_ast(obj: HasLocation) -> Location:
|
||||||
|
return Location(
|
||||||
|
lineno=obj.lineno,
|
||||||
|
col_offset=obj.col_offset,
|
||||||
|
end_lineno=obj.end_lineno,
|
||||||
|
end_col_offset=obj.end_col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def span(start: Location, end: Location) -> Location:
|
||||||
|
return Location(
|
||||||
|
lineno=start.lineno,
|
||||||
|
col_offset=start.col_offset,
|
||||||
|
end_lineno=end.lineno,
|
||||||
|
end_col_offset=end.end_col_offset,
|
||||||
|
)
|
||||||
251
midas/ast/midas.py
Normal file
251
midas/ast/midas.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""
|
||||||
|
This file was generated by a script. Any manual changes might be overwritten.
|
||||||
|
Please modify gen/midas.py instead and run gen/gen.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
##############
|
||||||
|
# Statements #
|
||||||
|
##############
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Stmt(ABC):
|
||||||
|
location: Location
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
@abstractmethod
|
||||||
|
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SimpleTypeStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
template: Optional[TemplateExpr]
|
||||||
|
base: TypeExpr
|
||||||
|
constraint: Optional[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_simple_type_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ComplexTypeStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
template: Optional[TemplateExpr]
|
||||||
|
properties: list[PropertyStmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_complex_type_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PropertyStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
type: TypeExpr
|
||||||
|
constraint: Optional[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_property_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExtendStmt(Stmt):
|
||||||
|
type: TypeExpr
|
||||||
|
operations: list[OpStmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_extend_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class OpStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
operand: TypeExpr
|
||||||
|
result: TypeExpr
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_op_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PredicateStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
subject: Token
|
||||||
|
type: TypeExpr
|
||||||
|
condition: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_predicate_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
###############
|
||||||
|
# Expressions #
|
||||||
|
###############
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Expr(ABC):
|
||||||
|
location: Location
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
@abstractmethod
|
||||||
|
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_grouping_expr(self, expr: GroupingExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SimpleTypeExpr(Expr):
|
||||||
|
name: Token
|
||||||
|
optional: bool
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_simple_type_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LogicalExpr(Expr):
|
||||||
|
left: Expr
|
||||||
|
operator: Token
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_logical_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BinaryExpr(Expr):
|
||||||
|
left: Expr
|
||||||
|
operator: Token
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_binary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class UnaryExpr(Expr):
|
||||||
|
operator: Token
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_unary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GetExpr(Expr):
|
||||||
|
expr: Expr
|
||||||
|
name: Token
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_get_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class VariableExpr(Expr):
|
||||||
|
name: Token
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_variable_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GroupingExpr(Expr):
|
||||||
|
expr: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_grouping_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LiteralExpr(Expr):
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_literal_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WildcardExpr(Expr):
|
||||||
|
token: Token
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_wildcard_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TemplateExpr(Expr):
|
||||||
|
type: TypeExpr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_template_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TypeExpr(Expr):
|
||||||
|
name: Token
|
||||||
|
template: Optional[TemplateExpr]
|
||||||
|
optional: bool
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_type_expr(self)
|
||||||
610
midas/ast/printer.py
Normal file
610
midas/ast/printer.py
Normal file
@@ -0,0 +1,610 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import io
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
import midas.ast.python as p
|
||||||
|
|
||||||
|
|
||||||
|
class _Level(Enum):
|
||||||
|
EMPTY = auto()
|
||||||
|
ACTIVE = auto()
|
||||||
|
LAST = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class Expr(Protocol):
|
||||||
|
def accept(self, printer: AstPrinter) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Expr)
|
||||||
|
|
||||||
|
|
||||||
|
class AstPrinter(Generic[T]):
|
||||||
|
LAST_CHILD = "└── "
|
||||||
|
CHILD = "├── "
|
||||||
|
VERTICAL = "│ "
|
||||||
|
EMPTY = " "
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._levels: list[_Level] = []
|
||||||
|
self._idx: Optional[int] = None
|
||||||
|
self._buf: io.StringIO = io.StringIO()
|
||||||
|
|
||||||
|
def print(self, expr: T):
|
||||||
|
self._buf = io.StringIO()
|
||||||
|
expr.accept(self)
|
||||||
|
return self._buf.getvalue()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
|
||||||
|
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self._levels.pop()
|
||||||
|
|
||||||
|
def _mark_last(self):
|
||||||
|
if self._levels:
|
||||||
|
self._levels[-1] = _Level.LAST
|
||||||
|
|
||||||
|
def _write_line(self, text: str, *, last: bool = False):
|
||||||
|
if last:
|
||||||
|
self._mark_last()
|
||||||
|
indent: str = self._build_indent()
|
||||||
|
if self._idx is not None:
|
||||||
|
text = f"[{self._idx}] {text}"
|
||||||
|
self._idx = None
|
||||||
|
self._buf.write(indent + text + "\n")
|
||||||
|
|
||||||
|
def _build_indent(self) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
for level in self._levels[:-1]:
|
||||||
|
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||||
|
if self._levels:
|
||||||
|
if self._levels[-1] == _Level.LAST:
|
||||||
|
parts.append(self.LAST_CHILD)
|
||||||
|
self._levels[-1] = _Level.EMPTY
|
||||||
|
else:
|
||||||
|
parts.append(self.CHILD)
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
def _write_optional_child(
|
||||||
|
self, label: str, child: Optional[T], *, last: bool = False
|
||||||
|
):
|
||||||
|
if last:
|
||||||
|
self._mark_last()
|
||||||
|
if child is None:
|
||||||
|
self._write_line(f"{label}: None")
|
||||||
|
else:
|
||||||
|
self._write_line(label)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
child.accept(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||||
|
# Statements
|
||||||
|
|
||||||
|
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||||
|
self._write_line("SimpleTypeStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_optional_child("template", stmt.template)
|
||||||
|
self._write_line("base")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.base.accept(self)
|
||||||
|
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||||
|
|
||||||
|
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||||
|
self._write_line("ComplexTypeStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_optional_child("template", stmt.template)
|
||||||
|
self._write_line("properties", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, prop in enumerate(stmt.properties):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.properties) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
prop.accept(self)
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||||
|
self._write_line("PropertyStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.type.accept(self)
|
||||||
|
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
self._write_line("ExtendStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.type.accept(self)
|
||||||
|
self._write_line("operations", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, op in enumerate(stmt.operations):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.operations) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
op.accept(self)
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||||
|
self._write_line("OpStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
|
||||||
|
self._write_line("operand")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.operand.accept(self)
|
||||||
|
|
||||||
|
self._write_line("result", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.result.accept(self)
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
|
self._write_line("PredicateStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.type.accept(self)
|
||||||
|
self._write_line("condition", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.condition.accept(self)
|
||||||
|
|
||||||
|
# Expressions
|
||||||
|
|
||||||
|
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||||
|
self._write_line("SimpleTypeExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||||
|
self._write_line(f"optional: {expr.optional}", last=True)
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||||
|
self._write_line("LogicalExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("left")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.left.accept(self)
|
||||||
|
|
||||||
|
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||||
|
self._write_line("BinaryExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("left")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.left.accept(self)
|
||||||
|
|
||||||
|
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||||
|
self._write_line("UnaryExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr):
|
||||||
|
self._write_line("GetExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("expr")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.expr.accept(self)
|
||||||
|
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||||
|
self._write_line("VariableExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||||
|
self._write_line("GroupingExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("expr", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||||
|
self._write_line("LiteralExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"value: {expr.value}", last=True)
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||||
|
self._write_line("WildcardExpr")
|
||||||
|
|
||||||
|
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||||
|
self._write_line("TemplateExpr")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.type.accept(self)
|
||||||
|
|
||||||
|
def visit_type_expr(self, expr: m.TypeExpr):
|
||||||
|
self._write_line("TypeExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||||
|
self._write_optional_child("template", expr.template)
|
||||||
|
self._write_line(f"optional: {expr.optional}", last=True)
|
||||||
|
|
||||||
|
|
||||||
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||||
|
def __init__(self, indent: int = 4):
|
||||||
|
self.indent: int = indent
|
||||||
|
self.level: int = 0
|
||||||
|
|
||||||
|
def indented(self, text: str) -> str:
|
||||||
|
return " " * (self.level * self.indent) + text
|
||||||
|
|
||||||
|
def print(self, expr: m.Expr | m.Stmt):
|
||||||
|
self.level = 0
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||||
|
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||||
|
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
|
||||||
|
if stmt.constraint is not None:
|
||||||
|
res += " where " + stmt.constraint.accept(self)
|
||||||
|
return self.indented(res)
|
||||||
|
|
||||||
|
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||||
|
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||||
|
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
|
||||||
|
res += " {\n"
|
||||||
|
self.level += 1
|
||||||
|
for prop in stmt.properties:
|
||||||
|
res += prop.accept(self)
|
||||||
|
res += "\n"
|
||||||
|
self.level -= 1
|
||||||
|
res += self.indented("}")
|
||||||
|
return res
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||||
|
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||||
|
if stmt.constraint is not None:
|
||||||
|
res += " where " + stmt.constraint.accept(self)
|
||||||
|
return self.indented(res)
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||||
|
res: str = self.indented(f"extend {stmt.type.accept(self)}")
|
||||||
|
res += " {\n"
|
||||||
|
self.level += 1
|
||||||
|
for op in stmt.operations:
|
||||||
|
res += op.accept(self)
|
||||||
|
self.level -= 1
|
||||||
|
res += "\n" + self.indented("}")
|
||||||
|
return res
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||||
|
operand: str = stmt.operand.accept(self)
|
||||||
|
result: str = stmt.result.accept(self)
|
||||||
|
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
|
name: str = stmt.name.lexeme
|
||||||
|
subject: str = stmt.subject.lexeme
|
||||||
|
type: str = stmt.type.accept(self)
|
||||||
|
condition: str = stmt.condition.accept(self)
|
||||||
|
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||||
|
|
||||||
|
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||||
|
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||||
|
left: str = expr.left.accept(self)
|
||||||
|
operator: str = expr.operator.lexeme
|
||||||
|
right: str = expr.right.accept(self)
|
||||||
|
return f"{left} {operator} {right}"
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||||
|
left: str = expr.left.accept(self)
|
||||||
|
operator: str = expr.operator.lexeme
|
||||||
|
right: str = expr.right.accept(self)
|
||||||
|
return f"{left} {operator} {right}"
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||||
|
operator: str = expr.operator.lexeme
|
||||||
|
right: str = expr.right.accept(self)
|
||||||
|
return f"{operator}{right}"
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr):
|
||||||
|
expr_: str = expr.expr.accept(self)
|
||||||
|
name: str = expr.name.lexeme
|
||||||
|
return f"{expr_}.{name}"
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||||
|
return expr.name.lexeme
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||||
|
expr_: str = expr.expr.accept(self)
|
||||||
|
return f"({expr_})"
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||||
|
return str(expr.value)
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||||
|
return "_"
|
||||||
|
|
||||||
|
def visit_template_expr(self, expr: m.TemplateExpr):
|
||||||
|
return f"[{expr.type.accept(self)}]"
|
||||||
|
|
||||||
|
def visit_type_expr(self, expr: m.TypeExpr):
|
||||||
|
template: str = expr.template.accept(self) if expr.template is not None else ""
|
||||||
|
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
|
||||||
|
|
||||||
|
|
||||||
|
class PythonAstPrinter(
|
||||||
|
AstPrinter,
|
||||||
|
p.MidasType.Visitor[None],
|
||||||
|
p.Stmt.Visitor[None],
|
||||||
|
p.Expr.Visitor[None],
|
||||||
|
):
|
||||||
|
def visit_base_type(self, node: p.BaseType) -> None:
|
||||||
|
self._write_line("BaseType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"base: {node.base}")
|
||||||
|
self._write_optional_child("param", node.param, last=True)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||||
|
self._write_line("ConstraintType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
node.type.accept(self)
|
||||||
|
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||||
|
self._write_line("FrameColumn")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"name: {node.name}")
|
||||||
|
self._write_optional_child("type", node.type, last=True)
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||||
|
self._write_line("FrameType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("columns", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, col in enumerate(node.columns):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(node.columns) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
col.accept(self)
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
stmt.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
self._write_line("Function")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"name: {stmt.name}")
|
||||||
|
|
||||||
|
self._write_line("posonlyargs")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(stmt.posonlyargs):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.posonlyargs) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_argument(arg)
|
||||||
|
|
||||||
|
self._write_line("args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(stmt.args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_argument(arg)
|
||||||
|
|
||||||
|
self._write_line("kwonlyargs")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(stmt.kwonlyargs):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.kwonlyargs) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_argument(arg)
|
||||||
|
|
||||||
|
self._write_optional_child("returns", stmt.returns)
|
||||||
|
self._write_line("body", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
|
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||||
|
self._write_line("FunctionArgument")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"name: {arg.name}")
|
||||||
|
self._write_optional_child("type", arg.type, last=True)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
self._write_line("TypeAssign")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"name: {stmt.name}")
|
||||||
|
self._write_line("type", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.type.accept(self)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
self._write_line("AssignStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("targets")
|
||||||
|
with self._child_level():
|
||||||
|
for i, target in enumerate(stmt.targets):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.targets) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
target.accept(self)
|
||||||
|
self._write_line("value", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.value.accept(self)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
self._write_line("ReturnStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_optional_child("value", stmt.value, last=True)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
self._write_line("IfStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("test")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.test.accept(self)
|
||||||
|
self._write_line("body")
|
||||||
|
with self._child_level():
|
||||||
|
for i, body_stmt in enumerate(stmt.body):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.body) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
body_stmt.accept(self)
|
||||||
|
self._write_line("orelse", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, else_stmt in enumerate(stmt.orelse):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.orelse) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
else_stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
|
self._write_line("BinaryExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("left")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.left.accept(self)
|
||||||
|
|
||||||
|
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||||
|
self._write_line("CompareExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("left")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.left.accept(self)
|
||||||
|
|
||||||
|
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||||
|
self._write_line("UnaryExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||||
|
self._write_line("CallExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("callee")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.callee.accept(self)
|
||||||
|
|
||||||
|
self._write_line("arguments")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(expr.arguments):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.arguments) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
|
self._write_line("keywords", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.keywords) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._write_line(name)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||||
|
self._write_line("GetExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("object")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.object.accept(self)
|
||||||
|
self._write_line(f"name: {expr.name}", last=True)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||||
|
self._write_line("LiteralExpr")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line(f"value: {expr.value}")
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||||
|
self._write_line("VariableExpr")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line(f"name: {expr.name}")
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||||
|
self._write_line("LogicalExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("left")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.left.accept(self)
|
||||||
|
|
||||||
|
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||||
|
|
||||||
|
self._write_line("right", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||||
|
self._write_line("SetExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("object")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.object.accept(self)
|
||||||
|
self._write_line(f"name: {expr.name}")
|
||||||
|
self._write_line("value", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.value.accept(self)
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||||
|
self._write_line("CastExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.type.accept(self)
|
||||||
|
self._write_line("expr", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||||
|
self._write_line("TernaryExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("test")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.test.accept(self)
|
||||||
|
|
||||||
|
self._write_line("if_true")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.if_true.accept(self)
|
||||||
|
|
||||||
|
self._write_line("if_false", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.if_false.accept(self)
|
||||||
327
midas/ast/python.py
Normal file
327
midas/ast/python.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
"""
|
||||||
|
This file was generated by a script. Any manual changes might be overwritten.
|
||||||
|
Please modify gen/python.py instead and run gen/gen.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Type annotations #
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MidasType(ABC):
|
||||||
|
location: Location
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
@abstractmethod
|
||||||
|
def visit_base_type(self, node: BaseType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_constraint_type(self, node: ConstraintType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_frame_column(self, node: FrameColumn) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_frame_type(self, node: FrameType) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BaseType(MidasType):
|
||||||
|
base: str
|
||||||
|
param: Optional[MidasType]
|
||||||
|
|
||||||
|
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_base_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ConstraintType(MidasType):
|
||||||
|
type: MidasType
|
||||||
|
constraint: ast.expr
|
||||||
|
|
||||||
|
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_constraint_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FrameColumn(MidasType):
|
||||||
|
name: Optional[str]
|
||||||
|
type: Optional[MidasType]
|
||||||
|
|
||||||
|
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_frame_column(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FrameType(MidasType):
|
||||||
|
columns: list[FrameColumn]
|
||||||
|
|
||||||
|
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_frame_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
##############
|
||||||
|
# Statements #
|
||||||
|
##############
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Stmt(ABC):
|
||||||
|
location: Location
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
@abstractmethod
|
||||||
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_function(self, stmt: Function) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_type_assign(self, stmt: TypeAssign) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExpressionStmt(Stmt):
|
||||||
|
expr: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_expression_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Function(Stmt):
|
||||||
|
name: str
|
||||||
|
posonlyargs: list[Argument]
|
||||||
|
args: list[Argument]
|
||||||
|
sink: Optional[Argument]
|
||||||
|
kwonlyargs: list[Argument]
|
||||||
|
kw_sink: Optional[Argument]
|
||||||
|
returns: Optional[MidasType]
|
||||||
|
body: list[Stmt]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: str
|
||||||
|
type: Optional[MidasType]
|
||||||
|
default: Optional[Expr]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_args(self) -> list[Argument]:
|
||||||
|
return self.posonlyargs + self.args + self.kwonlyargs
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_function(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TypeAssign(Stmt):
|
||||||
|
name: str
|
||||||
|
type: MidasType
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_type_assign(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AssignStmt(Stmt):
|
||||||
|
targets: list[Expr]
|
||||||
|
value: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_assign_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ReturnStmt(Stmt):
|
||||||
|
value: Optional[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_return_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IfStmt(Stmt):
|
||||||
|
test: Expr
|
||||||
|
body: list[Stmt]
|
||||||
|
orelse: list[Stmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_if_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
###############
|
||||||
|
# Expressions #
|
||||||
|
###############
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Expr(ABC):
|
||||||
|
location: Location
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
@abstractmethod
|
||||||
|
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BinaryExpr(Expr):
|
||||||
|
left: Expr
|
||||||
|
operator: ast.operator
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_binary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CompareExpr(Expr):
|
||||||
|
left: Expr
|
||||||
|
operator: ast.cmpop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_compare_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class UnaryExpr(Expr):
|
||||||
|
operator: ast.unaryop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_unary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CallExpr(Expr):
|
||||||
|
callee: Expr
|
||||||
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_call_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GetExpr(Expr):
|
||||||
|
object: Expr
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_get_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LiteralExpr(Expr):
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_literal_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class VariableExpr(Expr):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_variable_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LogicalExpr(Expr):
|
||||||
|
left: Expr
|
||||||
|
operator: ast.boolop
|
||||||
|
right: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_logical_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SetExpr(Expr):
|
||||||
|
object: Expr
|
||||||
|
name: str
|
||||||
|
value: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_set_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CastExpr(Expr):
|
||||||
|
type: MidasType
|
||||||
|
expr: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_cast_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TernaryExpr(Expr):
|
||||||
|
test: Expr
|
||||||
|
if_true: Expr
|
||||||
|
if_false: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_ternary_expr(self)
|
||||||
546
midas/checker/checker.py
Normal file
546
midas/checker/checker.py
Normal file
@@ -0,0 +1,546 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||||
|
from midas.checker.types import BaseType, Function, SimpleType, Type, UnitType, UnknownType
|
||||||
|
from midas.lexer.midas import MidasLexer
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
from midas.parser.midas import MidasParser
|
||||||
|
from midas.resolver.midas import MidasResolver
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: p.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
class Checker(
|
||||||
|
p.Stmt.Visitor[None],
|
||||||
|
p.Expr.Visitor[Type],
|
||||||
|
p.MidasType.Visitor[Type],
|
||||||
|
):
|
||||||
|
"""A type checker which can use custom type definitions"""
|
||||||
|
|
||||||
|
def __init__(self, locals: dict[p.Expr, int], file_path: Path):
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Checker")
|
||||||
|
self.file_path: Path = file_path
|
||||||
|
self.ctx: MidasResolver = MidasResolver()
|
||||||
|
self.global_env: Environment = Environment()
|
||||||
|
self.env: Environment = self.global_env
|
||||||
|
self.locals: dict[p.Expr, int] = locals
|
||||||
|
self.diagnostics: list[Diagnostic] = []
|
||||||
|
|
||||||
|
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
||||||
|
self.diagnostics.append(
|
||||||
|
Diagnostic(
|
||||||
|
file_path=self.file_path,
|
||||||
|
location=location,
|
||||||
|
type=type,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def error(self, location: Location, message: str):
|
||||||
|
self.diagnostic(
|
||||||
|
type=DiagnosticType.ERROR,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def warning(self, location: Location, message: str):
|
||||||
|
self.diagnostic(
|
||||||
|
type=DiagnosticType.WARNING,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def info(self, location: Location, message: str):
|
||||||
|
self.diagnostic(
|
||||||
|
type=DiagnosticType.INFO,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def evaluate(self, expr: p.Expr) -> Type:
|
||||||
|
"""Evaluate the type of an expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the expression to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type of the given expression
|
||||||
|
"""
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||||
|
"""Evaluate a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block (list[p.Stmt]): the statements to evaluate
|
||||||
|
env (Environment): the environment in which to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether a return statement is present in the block
|
||||||
|
"""
|
||||||
|
previous_env: Environment = self.env
|
||||||
|
self.env = env
|
||||||
|
returned: bool = False
|
||||||
|
for i, stmt in enumerate(block):
|
||||||
|
try:
|
||||||
|
stmt.accept(self)
|
||||||
|
except ReturnException:
|
||||||
|
returned = True
|
||||||
|
if i < len(block) - 1:
|
||||||
|
self.warning(block[i + 1].location, "Unreachable statement")
|
||||||
|
break
|
||||||
|
self.env = previous_env
|
||||||
|
return returned
|
||||||
|
|
||||||
|
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||||
|
"""Type check a sequence of statements and returns diagnostics
|
||||||
|
|
||||||
|
Args:
|
||||||
|
statements (list[p.Stmt]): the statements to evaluate and check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
||||||
|
"""
|
||||||
|
self.diagnostics = []
|
||||||
|
for stmt in statements:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||||
|
return self.diagnostics
|
||||||
|
|
||||||
|
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||||
|
"""Look up a variable in the environment it was declared
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the type of the variable, or None if it was not found
|
||||||
|
"""
|
||||||
|
distance: Optional[int] = self.locals.get(expr)
|
||||||
|
if distance is not None:
|
||||||
|
return self.env.get_at(distance, name)
|
||||||
|
return self.global_env.get(name)
|
||||||
|
|
||||||
|
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
|
||||||
|
"""Parse a Midas import statement
|
||||||
|
|
||||||
|
The statement should be written as `midas.using("path/to/types.midas")`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.CallExpr): the import call expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Path]: the path to the imported file, or None if the expression is malformed
|
||||||
|
"""
|
||||||
|
match expr:
|
||||||
|
case p.CallExpr(
|
||||||
|
callee=p.GetExpr(
|
||||||
|
object=p.VariableExpr(name="midas"),
|
||||||
|
name="using",
|
||||||
|
),
|
||||||
|
arguments=[
|
||||||
|
p.LiteralExpr(value=path),
|
||||||
|
],
|
||||||
|
):
|
||||||
|
return Path(path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def import_midas(self, path: Path) -> None:
|
||||||
|
"""Import Midas definitions from a path
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (Path): the import path
|
||||||
|
"""
|
||||||
|
self.logger.debug(f"Importing type definitions from {path}")
|
||||||
|
path = (self.file_path.parent / path).resolve()
|
||||||
|
lexer: MidasLexer = MidasLexer(path.read_text())
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser: MidasParser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
self.ctx.resolve(stmts)
|
||||||
|
self.logger.debug(f"Midas types: {self.ctx._types}")
|
||||||
|
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
self.evaluate(stmt.expr)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
pos_args: list[Function.Argument] = []
|
||||||
|
args: list[Function.Argument] = []
|
||||||
|
kw_args: list[Function.Argument] = []
|
||||||
|
|
||||||
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
|
if arg.type is not None:
|
||||||
|
return arg.type.accept(self)
|
||||||
|
if arg.default is not None:
|
||||||
|
return arg.default.accept(self)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
for arg in stmt.posonlyargs:
|
||||||
|
pos_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for arg in stmt.args:
|
||||||
|
args.append(
|
||||||
|
Function.Argument(
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for arg in stmt.kwonlyargs:
|
||||||
|
kw_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for arg in pos_args + args + kw_args:
|
||||||
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
|
returns_hint: Optional[Type] = None
|
||||||
|
if stmt.returns is not None:
|
||||||
|
returns_hint = stmt.returns.accept(self)
|
||||||
|
# Early define to handle simple fully-typed recursion
|
||||||
|
inside_function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns_hint,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, inside_function)
|
||||||
|
|
||||||
|
returned: bool = self.evaluate_block(stmt.body, env)
|
||||||
|
inferred_return: Type = UnknownType()
|
||||||
|
if not returned:
|
||||||
|
env.return_types.append(UnitType())
|
||||||
|
return_types: set[Type] = set(env.return_types)
|
||||||
|
if len(return_types) == 1:
|
||||||
|
inferred_return = list(return_types)[0]
|
||||||
|
elif len(return_types) > 1:
|
||||||
|
self.error(
|
||||||
|
stmt.location,
|
||||||
|
f"Mixed return types: {env.return_types}",
|
||||||
|
)
|
||||||
|
|
||||||
|
returns: Type = UnknownType()
|
||||||
|
if returns_hint is not None:
|
||||||
|
assert stmt.returns is not None
|
||||||
|
returns = returns_hint
|
||||||
|
if returns != inferred_return:
|
||||||
|
self.error(
|
||||||
|
stmt.returns.location,
|
||||||
|
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
returns = inferred_return
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
function: Function = Function(
|
||||||
|
name=stmt.name,
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
# TODO check not yet defined locally
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
self.env.define(stmt.name, type)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
value: Type = self.evaluate(stmt.value)
|
||||||
|
for target in stmt.targets:
|
||||||
|
if not isinstance(target, p.VariableExpr):
|
||||||
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
|
self.warning(target.location, f"Unsupported assignment to {target}")
|
||||||
|
continue
|
||||||
|
name: str = target.name
|
||||||
|
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||||
|
|
||||||
|
if var_type is None:
|
||||||
|
self.env.define(name, value)
|
||||||
|
else:
|
||||||
|
# TODO: implement real comparison method
|
||||||
|
if var_type != value:
|
||||||
|
self.error(
|
||||||
|
stmt.location,
|
||||||
|
f"Cannot assign {value} to {name} of type {var_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||||
|
self.env.return_types.append(type)
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
test_type: Type = stmt.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.ctx.get_type("bool"):
|
||||||
|
self.error(
|
||||||
|
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
body_returned: bool = self.evaluate_block(stmt.body, env)
|
||||||
|
else_returned: bool = self.evaluate_block(stmt.orelse, env)
|
||||||
|
self.env.return_types.extend(env.return_types)
|
||||||
|
if body_returned and else_returned:
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.evaluate(expr.left)
|
||||||
|
right: Type = self.evaluate(expr.right)
|
||||||
|
|
||||||
|
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||||
|
if result is None:
|
||||||
|
self.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
|
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||||
|
return UnknownType()
|
||||||
|
left: Type = self.evaluate(expr.left)
|
||||||
|
right: Type = self.evaluate(expr.right)
|
||||||
|
|
||||||
|
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||||
|
if result is None:
|
||||||
|
self.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
|
if path := self.parse_midas_import(expr):
|
||||||
|
self.import_midas(path)
|
||||||
|
return UnknownType()
|
||||||
|
callee: Type = self.evaluate(expr.callee)
|
||||||
|
if not isinstance(callee, Function):
|
||||||
|
self.error(expr.callee.location, "Callee is not a function")
|
||||||
|
return UnknownType()
|
||||||
|
function: Function = callee
|
||||||
|
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||||
|
for arg in mapped:
|
||||||
|
if arg.type != arg.argument.type:
|
||||||
|
self.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||||
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.ctx.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.ctx.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.ctx.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.ctx.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||||
|
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
|
return expr.type.accept(self)
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
|
test_type: Type = expr.test.accept(self)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.ctx.get_type("bool"):
|
||||||
|
self.error(
|
||||||
|
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
true_type: Type = expr.if_true.accept(self)
|
||||||
|
false_type: Type = expr.if_false.accept(self)
|
||||||
|
if true_type != false_type:
|
||||||
|
self.error(expr.location, f"Type mismatch in ternary if branches: true={true_type} != false={false_type}")
|
||||||
|
return UnknownType()
|
||||||
|
return true_type
|
||||||
|
|
||||||
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
|
return self.ctx.get_type(node.base)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self, function: Function, call: p.CallExpr
|
||||||
|
) -> list[MappedArgument]:
|
||||||
|
"""Map call arguments to function parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
call (p.CallExpr): the call expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[MappedArgument]: the list of mapped arguments
|
||||||
|
"""
|
||||||
|
positional: list[tuple[p.Expr, Type]] = [
|
||||||
|
(arg, self.evaluate(arg)) for arg in call.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||||
|
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
|
||||||
|
}
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
self.error(arg[0].location, "Too many positional arguments")
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if name in set_args:
|
||||||
|
self.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
self.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
self.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapped
|
||||||
33
midas/checker/diagnostic.py
Normal file
33
midas/checker/diagnostic.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
|
||||||
|
|
||||||
|
class DiagnosticType(StrEnum):
|
||||||
|
ERROR = "Error"
|
||||||
|
WARNING = "Warning"
|
||||||
|
INFO = "Info"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Diagnostic:
|
||||||
|
file_path: Path
|
||||||
|
location: Location
|
||||||
|
type: DiagnosticType
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||||
|
end_loc: Optional[str] = ""
|
||||||
|
if (
|
||||||
|
self.location.end_lineno is not None
|
||||||
|
and self.location.end_col_offset is not None
|
||||||
|
):
|
||||||
|
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||||
|
loc: str = (
|
||||||
|
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
||||||
|
)
|
||||||
|
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
||||||
142
midas/checker/environment.py
Normal file
142
midas/checker/environment.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.checker.types import Type
|
||||||
|
|
||||||
|
|
||||||
|
class Environment:
|
||||||
|
"""
|
||||||
|
A scoped environment in which variables are defined
|
||||||
|
|
||||||
|
Each environment can inherit from a parent/enclosing environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
||||||
|
self.enclosing: Optional[Environment] = enclosing
|
||||||
|
self.values: dict[str, Type] = {}
|
||||||
|
self.return_types: list[Type] = []
|
||||||
|
|
||||||
|
self._children: list[Environment] = []
|
||||||
|
if enclosing is not None:
|
||||||
|
enclosing._children.append(self)
|
||||||
|
|
||||||
|
def define(self, name: str, value: Type) -> None:
|
||||||
|
"""Define a variable in this environment
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
value (Type): the value
|
||||||
|
"""
|
||||||
|
self.values[name] = value
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[Type]:
|
||||||
|
"""Get a variable in the closest environment which has a definition for it
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the value of the variable, or None if it was not found
|
||||||
|
"""
|
||||||
|
if name in self.values:
|
||||||
|
return self.values[name]
|
||||||
|
if self.enclosing is not None:
|
||||||
|
return self.enclosing.get(name)
|
||||||
|
# raise NameError(f"Undefined variable '{name}'")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def assign(self, name: str, value: Type) -> bool:
|
||||||
|
"""Assign a new value to a variable in the environment it was defined in
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
value (Type): the new value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the variable was assigned in this environment or an ancestor, False otherwise
|
||||||
|
"""
|
||||||
|
if name not in self.values:
|
||||||
|
if self.enclosing is None:
|
||||||
|
return False
|
||||||
|
if self.enclosing.assign(name, value):
|
||||||
|
return True
|
||||||
|
self.values[name] = value
|
||||||
|
return True
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear all definitions in this environment"""
|
||||||
|
self.values = {}
|
||||||
|
|
||||||
|
def get_at(self, distance: int, name: str) -> Optional[Type]:
|
||||||
|
"""Get the value of a variable at a given distance
|
||||||
|
|
||||||
|
A distance of 0 looks up in this environment, 1 in the parent environment, etc.
|
||||||
|
This methods expects `distance` to be valid. An error will be raised if
|
||||||
|
the stack does not extend far enough to reach `distance`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distance (int): the scope distance
|
||||||
|
name (str): the name of the variable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the value at the given distance, or None if it is not defined in that environment
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||||
|
"""
|
||||||
|
return self.ancestor(distance).values.get(name)
|
||||||
|
|
||||||
|
def assign_at(self, distance: int, name: str, value: Type) -> None:
|
||||||
|
"""Assign a new value to a variable at a given distance
|
||||||
|
|
||||||
|
A distance of 0 assigns in this environment, 1 in the parent environment, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distance (int): the scope distance
|
||||||
|
name (str): the name of the variable
|
||||||
|
value (Type): the new value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||||
|
"""
|
||||||
|
self.ancestor(distance).values[name] = value
|
||||||
|
|
||||||
|
def ancestor(self, distance: int) -> Environment:
|
||||||
|
"""Get the ancestor at a given distance
|
||||||
|
|
||||||
|
A distance of 0 references this environment, 1 the parent environment, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distance (int): the scope distance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Environment: the environment
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||||
|
"""
|
||||||
|
env: Environment = self
|
||||||
|
for _ in range(distance):
|
||||||
|
assert env.enclosing is not None
|
||||||
|
env = env.enclosing
|
||||||
|
return env
|
||||||
|
|
||||||
|
def flat_dict(self) -> dict[str, Type]:
|
||||||
|
"""Get the current environment including definitions in its ancestor as a flat dictionary
|
||||||
|
|
||||||
|
This method recursively combines this environment definitions with its ancestor's
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: the combined environment
|
||||||
|
"""
|
||||||
|
if self.enclosing is None:
|
||||||
|
return self.values
|
||||||
|
return self.enclosing.flat_dict() | self.values
|
||||||
|
|
||||||
|
def dump(self) -> dict:
|
||||||
|
return {
|
||||||
|
"values": self.values,
|
||||||
|
"return_types": self.return_types,
|
||||||
|
"children": [child.dump() for child in self._children],
|
||||||
|
}
|
||||||
31
midas/checker/operators.py
Normal file
31
midas/checker/operators.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||||
|
ast.Add: "__add__",
|
||||||
|
ast.Sub: "__sub__",
|
||||||
|
ast.Mult: "__mul__",
|
||||||
|
ast.MatMult: "__matmul__",
|
||||||
|
ast.Div: "__truediv__",
|
||||||
|
ast.Mod: "__mod__",
|
||||||
|
ast.Pow: "__pow__",
|
||||||
|
ast.LShift: "__lshift__",
|
||||||
|
ast.RShift: "__rshift__",
|
||||||
|
ast.BitOr: "__or__",
|
||||||
|
ast.BitXor: "__xor__",
|
||||||
|
ast.BitAnd: "__and__",
|
||||||
|
ast.FloorDiv: "__floordiv__",
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||||
|
ast.Eq: "__eq__",
|
||||||
|
# ast.NotEq: "__noteq__",
|
||||||
|
ast.Lt: "__lt__",
|
||||||
|
ast.LtE: "__le__",
|
||||||
|
ast.Gt: "__gt__",
|
||||||
|
ast.GtE: "__ge__",
|
||||||
|
# ast.Is: "__is__",
|
||||||
|
# ast.IsNot: "__isnot__",
|
||||||
|
# ast.In: "__in__",
|
||||||
|
# ast.NotIn: "__notin__",
|
||||||
|
}
|
||||||
42
midas/checker/types.py
Normal file
42
midas/checker/types.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class BaseType:
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class SimpleType:
|
||||||
|
name: str
|
||||||
|
base: BaseType | SimpleType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class UnknownType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class UnitType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Function:
|
||||||
|
name: str
|
||||||
|
pos_args: list[Argument]
|
||||||
|
args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
name: str
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
|
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||||
0
midas/cli/__init__.py
Normal file
0
midas/cli/__init__.py
Normal file
57
midas/cli/highlight.css
Normal file
57
midas/cli/highlight.css
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
html,
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
font-size: 14pt;
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
#code {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
font-family: monospace;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.line {
|
||||||
|
display: flex;
|
||||||
|
|
||||||
|
&:nth-child(odd) {
|
||||||
|
background-color: rgb(247, 247, 247);
|
||||||
|
}
|
||||||
|
|
||||||
|
.no {
|
||||||
|
width: 4em;
|
||||||
|
text-align: right;
|
||||||
|
padding: 0.2em 0.4em;
|
||||||
|
border-right: solid black 1px;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.txt {
|
||||||
|
flex-grow: 1;
|
||||||
|
padding: 0.2em 0.8em;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
span {
|
||||||
|
--col: transparent;
|
||||||
|
--opacity: 0.1;
|
||||||
|
--border: 0px;
|
||||||
|
background-color: rgba(var(--col), var(--opacity));
|
||||||
|
outline: solid rgb(var(--col)) var(--border);
|
||||||
|
outline-offset: 2px;
|
||||||
|
border-radius: 2px;
|
||||||
|
|
||||||
|
&:hover:not(:has(*:hover)) {
|
||||||
|
--opacity: 0.8;
|
||||||
|
--border: 2px;
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.keyword {
|
||||||
|
color: rgb(211, 72, 9);
|
||||||
|
}
|
||||||
|
}
|
||||||
300
midas/cli/highlighter.py
Normal file
300
midas/cli/highlighter.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.diagnostic import Diagnostic
|
||||||
|
|
||||||
|
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Highlightable(Protocol, Generic[H]):
|
||||||
|
def accept(self, visitor: H): ...
|
||||||
|
|
||||||
|
|
||||||
|
class Locatable(Protocol):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def location(self) -> Optional[Location]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class Highlighter(ABC):
|
||||||
|
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||||
|
EXTRA_CSS_PATH: Optional[Path] = None
|
||||||
|
|
||||||
|
def __init__(self, source: str) -> None:
|
||||||
|
self.source: str = source
|
||||||
|
self.lines: list[str] = self.source.splitlines()
|
||||||
|
self.openings: dict[tuple[int, int], list[str]] = {}
|
||||||
|
self.closings: dict[tuple[int, int], list[str]] = {}
|
||||||
|
|
||||||
|
def format_css(self, path: Path) -> list[str]:
|
||||||
|
css: str = path.read_text()
|
||||||
|
css = "\n".join((" " + line).rstrip() for line in css.splitlines())
|
||||||
|
return [
|
||||||
|
" <style>",
|
||||||
|
css,
|
||||||
|
" </style>",
|
||||||
|
]
|
||||||
|
|
||||||
|
def dump(self, buf: TextIO):
|
||||||
|
base_css: list[str] = self.format_css(self.BASE_CSS_PATH)
|
||||||
|
extra_css: list[str] = (
|
||||||
|
self.format_css(self.EXTRA_CSS_PATH)
|
||||||
|
if self.EXTRA_CSS_PATH is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
lines: list[str] = [
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
'<html lang="en">',
|
||||||
|
"<head>",
|
||||||
|
' <meta charset="UTF-8">',
|
||||||
|
' <meta name="viewport" content="width=device-width, initial-scale=1.0">',
|
||||||
|
" <title>Highlighted file</title>",
|
||||||
|
*base_css,
|
||||||
|
*extra_css,
|
||||||
|
"</head>",
|
||||||
|
"<body>",
|
||||||
|
' <div id="code">',
|
||||||
|
]
|
||||||
|
for l, line in enumerate(self.lines):
|
||||||
|
lineno: int = l + 1
|
||||||
|
line_buf: str = (
|
||||||
|
f'<div class="line" id="l{lineno}"><div class="no">{lineno}</div><div class="txt">'
|
||||||
|
)
|
||||||
|
for c, char in enumerate(line):
|
||||||
|
pos: tuple[int, int] = (lineno, c)
|
||||||
|
closings: list[str] = self.closings.get(pos, [])
|
||||||
|
openings: list[str] = self.openings.get(pos, [])
|
||||||
|
line_buf += "".join(closings + openings)
|
||||||
|
line_buf += char
|
||||||
|
line_buf += "".join(self.closings.get((lineno, len(line)), []))
|
||||||
|
line_buf += "</div></div>"
|
||||||
|
lines.append(" " + line_buf)
|
||||||
|
lines.extend(
|
||||||
|
[
|
||||||
|
" </div>",
|
||||||
|
"</body>",
|
||||||
|
"</html>",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
buf.write("\n".join(lines))
|
||||||
|
|
||||||
|
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
|
||||||
|
if node.location is None:
|
||||||
|
return
|
||||||
|
if node.location.end_lineno is None or node.location.end_col_offset is None:
|
||||||
|
return
|
||||||
|
start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset)
|
||||||
|
end_pos: tuple[int, int] = (
|
||||||
|
node.location.end_lineno,
|
||||||
|
node.location.end_col_offset,
|
||||||
|
)
|
||||||
|
opening: str = f'<span class="{cls}" title="{cls}">'
|
||||||
|
closing: str = "</span>"
|
||||||
|
if message is not None:
|
||||||
|
opening = f'<span class="with-msg">{opening}'
|
||||||
|
closing = f'{closing}<span class="message">{message}</span></span>'
|
||||||
|
|
||||||
|
self.openings.setdefault(start_pos, []).append(opening)
|
||||||
|
self.closings.setdefault(end_pos, []).insert(0, closing)
|
||||||
|
if start_pos[0] != end_pos[0]:
|
||||||
|
for l in range(start_pos[0], end_pos[0]):
|
||||||
|
c: int = len(self.lines[l - 1])
|
||||||
|
self.closings.setdefault((l, c), []).insert(0, closing)
|
||||||
|
self.openings.setdefault((l + 1, 0), []).append(opening)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonHighlighter(
|
||||||
|
Highlighter,
|
||||||
|
p.MidasType.Visitor[None],
|
||||||
|
p.Stmt.Visitor[None],
|
||||||
|
p.Expr.Visitor[None],
|
||||||
|
):
|
||||||
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css"
|
||||||
|
|
||||||
|
def highlight(self, node: Highlightable[PythonHighlighter]):
|
||||||
|
node.accept(self)
|
||||||
|
|
||||||
|
def visit_base_type(self, node: p.BaseType) -> None:
|
||||||
|
self.wrap(node, "base-type")
|
||||||
|
if node.param is not None:
|
||||||
|
self.wrap(node.param, "param")
|
||||||
|
node.param.accept(self)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||||
|
self.wrap(node, "constraint-type")
|
||||||
|
node.type.accept(self)
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||||
|
self.wrap(node, "frame-column")
|
||||||
|
if node.type is not None:
|
||||||
|
node.type.accept(self)
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||||
|
self.wrap(node, "frame-type")
|
||||||
|
for column in node.columns:
|
||||||
|
column.accept(self)
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
stmt.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
self.wrap(stmt, "function")
|
||||||
|
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||||
|
self._highlight_function_argument(arg)
|
||||||
|
for body_stmt in stmt.body:
|
||||||
|
body_stmt.accept(self)
|
||||||
|
|
||||||
|
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||||
|
self.wrap(arg, "argument")
|
||||||
|
if arg.type is not None:
|
||||||
|
arg.type.accept(self)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
stmt.type.accept(self)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
for target in stmt.targets:
|
||||||
|
target.accept(self)
|
||||||
|
stmt.value.accept(self)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
self.wrap(stmt, "return")
|
||||||
|
if stmt.value is not None:
|
||||||
|
stmt.value.accept(self)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
self.wrap(stmt, "if")
|
||||||
|
stmt.test.accept(self)
|
||||||
|
for body_stmt in stmt.body:
|
||||||
|
body_stmt.accept(self)
|
||||||
|
for else_stmt in stmt.orelse:
|
||||||
|
else_stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||||
|
self.wrap(expr, "call")
|
||||||
|
expr.callee.accept(self)
|
||||||
|
for arg in expr.arguments:
|
||||||
|
arg.accept(self)
|
||||||
|
for arg in expr.keywords.values():
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||||
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
||||||
|
|
||||||
|
def highlight(self, node: Highlightable[MidasHighlighter]):
|
||||||
|
node.accept(self)
|
||||||
|
|
||||||
|
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||||
|
self.wrap(stmt, "simple-type")
|
||||||
|
if stmt.template is not None:
|
||||||
|
stmt.template.accept(self)
|
||||||
|
stmt.base.accept(self)
|
||||||
|
if stmt.constraint is not None:
|
||||||
|
self.wrap(stmt.constraint, "constraint")
|
||||||
|
stmt.constraint.accept(self)
|
||||||
|
|
||||||
|
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
|
||||||
|
self.wrap(stmt, "complex-type")
|
||||||
|
if stmt.template is not None:
|
||||||
|
stmt.template.accept(self)
|
||||||
|
for prop in stmt.properties:
|
||||||
|
prop.accept(self)
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
|
||||||
|
self.wrap(stmt, "property")
|
||||||
|
stmt.type.accept(self)
|
||||||
|
if stmt.constraint is not None:
|
||||||
|
self.wrap(stmt.constraint, "constraint")
|
||||||
|
stmt.constraint.accept(self)
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
self.wrap(stmt, "extend")
|
||||||
|
stmt.type.accept(self)
|
||||||
|
for op in stmt.operations:
|
||||||
|
op.accept(self)
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||||
|
self.wrap(stmt, "op")
|
||||||
|
stmt.operand.accept(self)
|
||||||
|
stmt.result.accept(self)
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
|
self.wrap(stmt, "predicate")
|
||||||
|
stmt.type.accept(self)
|
||||||
|
stmt.condition.accept(self)
|
||||||
|
|
||||||
|
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
|
||||||
|
self.wrap(expr, "simple-type-expr")
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||||
|
self.wrap(expr, "logical-expr")
|
||||||
|
expr.left.accept(self)
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||||
|
self.wrap(expr, "binary-expr")
|
||||||
|
expr.left.accept(self)
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||||
|
self.wrap(expr, "unary-expr")
|
||||||
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||||
|
self.wrap(expr, "get-expr")
|
||||||
|
expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||||
|
self.wrap(expr, "variable")
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||||
|
expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||||
|
self.wrap(expr, "template")
|
||||||
|
expr.type.accept(self)
|
||||||
|
|
||||||
|
def visit_type_expr(self, expr: m.TypeExpr) -> None:
|
||||||
|
self.wrap(expr, "type")
|
||||||
|
if expr.template is not None:
|
||||||
|
expr.template.accept(self)
|
||||||
|
|
||||||
|
|
||||||
|
class DiagnosticsHighlighter(Highlighter):
|
||||||
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||||
|
|
||||||
|
def highlight(self, diagnostics: list[Diagnostic]):
|
||||||
|
for diagnostic in diagnostics:
|
||||||
|
self.wrap(diagnostic, str(diagnostic.type).lower(), diagnostic.message)
|
||||||
39
midas/cli/hl_diagnostic.css
Normal file
39
midas/cli/hl_diagnostic.css
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
span {
|
||||||
|
--opacity: 0.4;
|
||||||
|
|
||||||
|
&.error {
|
||||||
|
--col: 255, 0, 0;
|
||||||
|
}
|
||||||
|
&.warning {
|
||||||
|
--col: 250, 160, 0;
|
||||||
|
}
|
||||||
|
&.info {
|
||||||
|
--col: 150, 190, 250;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.with-msg {
|
||||||
|
position: relative;
|
||||||
|
|
||||||
|
.message {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
&:hover:not(:has(.with-msg:hover)) {
|
||||||
|
.message {
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.message {
|
||||||
|
position: absolute;
|
||||||
|
top: calc(100% + 0.2em);
|
||||||
|
left: -.2em;
|
||||||
|
background-color: black;
|
||||||
|
color: white;
|
||||||
|
padding: 0.2em 0.4em;
|
||||||
|
border-radius: .2em;
|
||||||
|
z-index: 10;
|
||||||
|
width: 300%;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
55
midas/cli/hl_midas.css
Normal file
55
midas/cli/hl_midas.css
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
span {
|
||||||
|
&.comment {
|
||||||
|
--col: 200, 200, 200;
|
||||||
|
color: rgb(110, 110, 110);
|
||||||
|
font-style: italic;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.simple-type {
|
||||||
|
--col: 108, 233, 108;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.complex-type {
|
||||||
|
--col: 233, 206, 108;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.constraint {
|
||||||
|
--col: 233, 108, 108;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.property {
|
||||||
|
--col: 233, 108, 176;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.extend {
|
||||||
|
--col: 108, 197, 233;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.op {
|
||||||
|
--col: 108, 148, 233;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.predicate {
|
||||||
|
--col: 193, 108, 233;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.simple-type-expr {
|
||||||
|
--col: 150, 150, 150;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.logical-expr,
|
||||||
|
&.binary-expr,
|
||||||
|
&.unary-expr,
|
||||||
|
&.get-expr {
|
||||||
|
--col: 123, 215, 193;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.template {
|
||||||
|
--col: 163, 117, 71;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.type {
|
||||||
|
--col: 200, 200, 200;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
}
|
||||||
29
midas/cli/hl_python.css
Normal file
29
midas/cli/hl_python.css
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
span {
|
||||||
|
&.base-type {
|
||||||
|
--col: 108, 233, 108;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.param {
|
||||||
|
--col: 103, 192, 224;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.constraint-type {
|
||||||
|
--col: 174, 200, 195;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.frame-column {
|
||||||
|
--col: 216, 231, 81;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.frame-type {
|
||||||
|
--col: 231, 46, 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.function {
|
||||||
|
--col: 215, 103, 224;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.argument {
|
||||||
|
--col: 103, 192, 224;
|
||||||
|
}
|
||||||
|
}
|
||||||
180
midas/cli/main.py
Normal file
180
midas/cli/main.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, TextIO, get_args
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
||||||
|
from midas.checker.checker import Checker
|
||||||
|
from midas.checker.diagnostic import Diagnostic
|
||||||
|
from midas.checker.types import Type
|
||||||
|
from midas.cli.highlighter import (
|
||||||
|
DiagnosticsHighlighter,
|
||||||
|
Highlighter,
|
||||||
|
MidasHighlighter,
|
||||||
|
PythonHighlighter,
|
||||||
|
)
|
||||||
|
from midas.lexer.midas import MidasLexer
|
||||||
|
from midas.lexer.token import Token, TokenType
|
||||||
|
from midas.parser.midas import MidasParser
|
||||||
|
from midas.parser.python import PythonParser
|
||||||
|
from midas.resolver.resolver import Resolver
|
||||||
|
from midas.utils import UniversalJSONDumper
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def midas():
|
||||||
|
click.echo("Welcome to Midas!")
|
||||||
|
|
||||||
|
|
||||||
|
@midas.command()
|
||||||
|
@click.option("-l", "--highlight", type=click.File("w"))
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
def compile(highlight: Optional[TextIO], file: TextIO):
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
source: str = file.read()
|
||||||
|
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
resolver = Resolver()
|
||||||
|
resolver.resolve(*stmts)
|
||||||
|
checker = Checker(resolver.locals, file_path=Path(file.name).resolve())
|
||||||
|
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||||
|
for diagnostic in diagnostics:
|
||||||
|
print(diagnostic)
|
||||||
|
|
||||||
|
print(
|
||||||
|
json.dumps(
|
||||||
|
UniversalJSONDumper.dump(
|
||||||
|
checker.global_env,
|
||||||
|
[("Environment", "_children")],
|
||||||
|
lambda obj: isinstance(obj, get_args(Type)),
|
||||||
|
),
|
||||||
|
indent=4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if highlight is not None:
|
||||||
|
highlighter = DiagnosticsHighlighter(source)
|
||||||
|
highlighter.highlight(diagnostics)
|
||||||
|
highlighter.dump(highlight)
|
||||||
|
|
||||||
|
|
||||||
|
@midas.group()
|
||||||
|
def utils():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def dump_python_ast(tree: ast.Module) -> str:
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
printer = PythonAstPrinter()
|
||||||
|
dump: str = ""
|
||||||
|
for stmt in stmts:
|
||||||
|
dump += printer.print(stmt)
|
||||||
|
dump += "\n"
|
||||||
|
return dump
|
||||||
|
|
||||||
|
|
||||||
|
def dump_midas_ast(source: str, filename: str) -> str:
|
||||||
|
lexer = MidasLexer(source, file=filename)
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
if len(parser.errors) != 0:
|
||||||
|
for err in parser.errors:
|
||||||
|
print(err.get_report())
|
||||||
|
raise RuntimeError("A parsing error occurred")
|
||||||
|
printer = MidasAstPrinter()
|
||||||
|
dump: str = ""
|
||||||
|
for stmt in stmts:
|
||||||
|
dump += printer.print(stmt)
|
||||||
|
dump += "\n"
|
||||||
|
return dump
|
||||||
|
|
||||||
|
|
||||||
|
@utils.command()
|
||||||
|
@click.option("-o", "--output", type=click.File("w"))
|
||||||
|
@click.option("-p", "--parse", is_flag=True)
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
|
||||||
|
source: str = file.read()
|
||||||
|
|
||||||
|
dump: str
|
||||||
|
if file.name.endswith(".py"):
|
||||||
|
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||||
|
if parse:
|
||||||
|
dump = dump_python_ast(tree)
|
||||||
|
else:
|
||||||
|
dump = ast.dump(tree, indent=4)
|
||||||
|
elif file.name.endswith(".midas"):
|
||||||
|
dump = dump_midas_ast(source, file.name)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported file type")
|
||||||
|
|
||||||
|
if output is None:
|
||||||
|
click.echo(dump)
|
||||||
|
else:
|
||||||
|
output.write(dump)
|
||||||
|
|
||||||
|
|
||||||
|
def highlight_python(source: str, path: str) -> Highlighter:
|
||||||
|
tree: ast.Module = ast.parse(source, filename=path)
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
highlighter = PythonHighlighter(source)
|
||||||
|
for stmt in stmts:
|
||||||
|
highlighter.highlight(stmt)
|
||||||
|
return highlighter
|
||||||
|
|
||||||
|
|
||||||
|
def highlight_midas(source: str, path: str) -> Highlighter:
|
||||||
|
lexer = MidasLexer(source, file=path)
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
highlighter = MidasHighlighter(source)
|
||||||
|
for err in parser.errors:
|
||||||
|
print(err.get_report())
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LocatableToken:
|
||||||
|
token: Token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def location(self) -> Location:
|
||||||
|
return self.token.get_location()
|
||||||
|
|
||||||
|
for stmt in stmts:
|
||||||
|
highlighter.highlight(stmt)
|
||||||
|
for token in tokens:
|
||||||
|
if token.type == TokenType.COMMENT:
|
||||||
|
highlighter.wrap(LocatableToken(token), "comment")
|
||||||
|
elif token.is_keyword:
|
||||||
|
highlighter.wrap(LocatableToken(token), "keyword")
|
||||||
|
return highlighter
|
||||||
|
|
||||||
|
|
||||||
|
@utils.command()
|
||||||
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
def highlight(output: TextIO, file: TextIO):
|
||||||
|
source: str = file.read()
|
||||||
|
highlighter: Highlighter
|
||||||
|
|
||||||
|
if file.name.endswith(".py"):
|
||||||
|
highlighter = highlight_python(source, file.name)
|
||||||
|
elif file.name.endswith(".midas"):
|
||||||
|
highlighter = highlight_midas(source, file.name)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported file type")
|
||||||
|
highlighter.dump(output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
midas()
|
||||||
0
midas/lexer/__init__.py
Normal file
0
midas/lexer/__init__.py
Normal file
@@ -1,8 +1,15 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from lexer.position import Position
|
from midas.lexer.position import Position
|
||||||
from lexer.token import Token, TokenType
|
from midas.lexer.token import Token, TokenType
|
||||||
|
|
||||||
|
|
||||||
|
class MidasSyntaxError(Exception):
|
||||||
|
def __init__(self, pos: Position, message: str):
|
||||||
|
super().__init__(f"[ERROR] Error at {pos}: {message}")
|
||||||
|
self.pos: Position = pos
|
||||||
|
self.message: str = message
|
||||||
|
|
||||||
|
|
||||||
class Lexer(ABC):
|
class Lexer(ABC):
|
||||||
@@ -38,9 +45,9 @@ class Lexer(ABC):
|
|||||||
msg (str): the error message
|
msg (str): the error message
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SyntaxError
|
MidasSyntaxError
|
||||||
"""
|
"""
|
||||||
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}")
|
raise MidasSyntaxError(self.start_pos, msg)
|
||||||
|
|
||||||
def process(self) -> list[Token]:
|
def process(self) -> list[Token]:
|
||||||
"""Scan tokens out of the source text
|
"""Scan tokens out of the source text
|
||||||
@@ -49,7 +56,7 @@ class Lexer(ABC):
|
|||||||
list[Token]: all the tokens that could be scanned
|
list[Token]: all the tokens that could be scanned
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SyntaxError: if a syntax error is found
|
MidasSyntaxError: if a syntax error is found
|
||||||
"""
|
"""
|
||||||
self.scan_tokens()
|
self.scan_tokens()
|
||||||
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))
|
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from lexer.base import Lexer
|
from midas.lexer.base import Lexer
|
||||||
from lexer.keyword import MIDAS_KEYWORDS
|
from midas.lexer.token import KEYWORDS, TokenType
|
||||||
from lexer.token import TokenType
|
|
||||||
|
|
||||||
|
|
||||||
class MidasLexer(Lexer):
|
class MidasLexer(Lexer):
|
||||||
@@ -31,30 +30,32 @@ class MidasLexer(Lexer):
|
|||||||
self.add_token(
|
self.add_token(
|
||||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
||||||
)
|
)
|
||||||
case "!":
|
case "!" if self.match("="):
|
||||||
if self.match("="):
|
self.add_token(TokenType.BANG_EQUAL)
|
||||||
self.add_token(TokenType.BANG_EQUAL)
|
|
||||||
else:
|
|
||||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
|
||||||
case ":":
|
case ":":
|
||||||
self.add_token(TokenType.COLON)
|
self.add_token(TokenType.COLON)
|
||||||
case ",":
|
case ".":
|
||||||
self.add_token(TokenType.COMMA)
|
self.add_token(TokenType.DOT)
|
||||||
case "_":
|
case "&":
|
||||||
|
self.add_token(TokenType.AND)
|
||||||
|
case "?":
|
||||||
|
self.add_token(TokenType.QMARK)
|
||||||
|
# case ",":
|
||||||
|
# self.add_token(TokenType.COMMA)
|
||||||
|
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
|
||||||
self.add_token(TokenType.UNDERSCORE)
|
self.add_token(TokenType.UNDERSCORE)
|
||||||
case "+":
|
case "-" if self.match(">"):
|
||||||
self.add_token(TokenType.PLUS)
|
self.add_token(TokenType.ARROW)
|
||||||
|
# case "+":
|
||||||
|
# self.add_token(TokenType.PLUS)
|
||||||
case "-":
|
case "-":
|
||||||
self.add_token(TokenType.MINUS)
|
self.add_token(TokenType.MINUS)
|
||||||
case "*":
|
# case "*":
|
||||||
self.add_token(TokenType.STAR)
|
# self.add_token(TokenType.STAR)
|
||||||
case "/":
|
case "/" if self.match("/"):
|
||||||
if self.match("/"):
|
self.scan_comment()
|
||||||
self.scan_comment()
|
case "/" if self.match("*"):
|
||||||
elif self.match("*"):
|
self.scan_comment_multiline()
|
||||||
self.scan_comment_multiline()
|
|
||||||
else:
|
|
||||||
self.add_token(TokenType.SLASH)
|
|
||||||
case "\n":
|
case "\n":
|
||||||
self.add_token(TokenType.NEWLINE)
|
self.add_token(TokenType.NEWLINE)
|
||||||
case " " | "\r" | "\t":
|
case " " | "\r" | "\t":
|
||||||
@@ -69,7 +70,7 @@ class MidasLexer(Lexer):
|
|||||||
case _:
|
case _:
|
||||||
if char.isdigit():
|
if char.isdigit():
|
||||||
self.scan_number()
|
self.scan_number()
|
||||||
elif char.isalpha():
|
elif self.is_identifier_char(char, start=True):
|
||||||
self.scan_identifier()
|
self.scan_identifier()
|
||||||
else:
|
else:
|
||||||
self.error("Unexpected character")
|
self.error("Unexpected character")
|
||||||
@@ -98,11 +99,11 @@ class MidasLexer(Lexer):
|
|||||||
An identifier starts with a letter, followed by any number of
|
An identifier starts with a letter, followed by any number of
|
||||||
alphanumerical characters or underscores
|
alphanumerical characters or underscores
|
||||||
"""
|
"""
|
||||||
while self.peek().isalnum() or self.peek() == "_":
|
while self.is_identifier_char(self.peek(), start=False):
|
||||||
self.advance()
|
self.advance()
|
||||||
|
|
||||||
lexeme: str = self.source[self.start : self.idx]
|
lexeme: str = self.source[self.start : self.idx]
|
||||||
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||||
self.add_token(token_type)
|
self.add_token(token_type)
|
||||||
|
|
||||||
def scan_comment(self):
|
def scan_comment(self):
|
||||||
@@ -129,3 +130,12 @@ class MidasLexer(Lexer):
|
|||||||
if not self.is_at_end():
|
if not self.is_at_end():
|
||||||
self.advance()
|
self.advance()
|
||||||
self.add_token(TokenType.COMMENT)
|
self.add_token(TokenType.COMMENT)
|
||||||
|
|
||||||
|
def is_identifier_char(self, char: str, *, start: bool) -> bool:
|
||||||
|
if char == "_":
|
||||||
|
return True
|
||||||
|
if char.isalpha():
|
||||||
|
return True
|
||||||
|
if not start and char.isdigit():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Position:
|
class Position:
|
||||||
"""A simple structure to store the position of a token"""
|
"""A simple structure to store the position of a token"""
|
||||||
|
|
||||||
file: Optional[str]
|
file: Optional[str]
|
||||||
line: int
|
line: int
|
||||||
column: int
|
column: int
|
||||||
104
midas/lexer/token.py
Normal file
104
midas/lexer/token.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.lexer.position import Position
|
||||||
|
|
||||||
|
|
||||||
|
class TokenType(Enum):
|
||||||
|
# Punctuation
|
||||||
|
LEFT_PAREN = auto()
|
||||||
|
RIGHT_PAREN = auto()
|
||||||
|
LEFT_BRACKET = auto()
|
||||||
|
RIGHT_BRACKET = auto()
|
||||||
|
LEFT_BRACE = auto()
|
||||||
|
RIGHT_BRACE = auto()
|
||||||
|
COLON = auto()
|
||||||
|
# COMMA = auto()
|
||||||
|
UNDERSCORE = auto()
|
||||||
|
ARROW = auto()
|
||||||
|
AND = auto()
|
||||||
|
QMARK = auto()
|
||||||
|
DOT = auto()
|
||||||
|
|
||||||
|
# Operators
|
||||||
|
# PLUS = auto()
|
||||||
|
MINUS = auto()
|
||||||
|
# STAR = auto()
|
||||||
|
# SLASH = auto()
|
||||||
|
GREATER = auto()
|
||||||
|
GREATER_EQUAL = auto()
|
||||||
|
LESS = auto()
|
||||||
|
LESS_EQUAL = auto()
|
||||||
|
EQUAL = auto()
|
||||||
|
EQUAL_EQUAL = auto()
|
||||||
|
BANG_EQUAL = auto()
|
||||||
|
|
||||||
|
# Literals
|
||||||
|
IDENTIFIER = auto()
|
||||||
|
NUMBER = auto()
|
||||||
|
TRUE = auto()
|
||||||
|
FALSE = auto()
|
||||||
|
NONE = auto()
|
||||||
|
|
||||||
|
# Keywords
|
||||||
|
TYPE = auto()
|
||||||
|
OP = auto()
|
||||||
|
PREDICATE = auto()
|
||||||
|
EXTEND = auto()
|
||||||
|
WHERE = auto()
|
||||||
|
|
||||||
|
# Misc
|
||||||
|
COMMENT = auto()
|
||||||
|
WHITESPACE = auto()
|
||||||
|
EOF = auto()
|
||||||
|
NEWLINE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
KEYWORDS: dict[str, TokenType] = {
|
||||||
|
"type": TokenType.TYPE,
|
||||||
|
"op": TokenType.OP,
|
||||||
|
"predicate": TokenType.PREDICATE,
|
||||||
|
"extend": TokenType.EXTEND,
|
||||||
|
"where": TokenType.WHERE,
|
||||||
|
"true": TokenType.TRUE,
|
||||||
|
"false": TokenType.FALSE,
|
||||||
|
"none": TokenType.NONE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Token:
|
||||||
|
"""A scanned token"""
|
||||||
|
|
||||||
|
type: TokenType
|
||||||
|
lexeme: str
|
||||||
|
value: Any
|
||||||
|
position: Position
|
||||||
|
|
||||||
|
def get_location(self) -> Location:
|
||||||
|
lineno: int = self.position.line
|
||||||
|
col_offset: int = self.position.column - 1
|
||||||
|
end_lineno = lineno
|
||||||
|
end_col_offset = col_offset
|
||||||
|
for c in self.lexeme:
|
||||||
|
end_col_offset += 1
|
||||||
|
if c == "\n":
|
||||||
|
end_lineno += 1
|
||||||
|
end_col_offset = 0
|
||||||
|
return Location(
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
end_lineno=end_lineno,
|
||||||
|
end_col_offset=end_col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def location_to(self, to: Token) -> Location:
|
||||||
|
return Location.span(self.get_location(), to.get_location())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_keyword(self) -> bool:
|
||||||
|
return self.lexeme in KEYWORDS
|
||||||
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from lexer.token import Token, TokenType
|
from midas.lexer.token import Token, TokenType
|
||||||
from parser.errors import ParsingError
|
from midas.parser.errors import ParsingError
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
419
midas/parser/midas.py
Normal file
419
midas/parser/midas.py
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.midas import (
|
||||||
|
BinaryExpr,
|
||||||
|
ComplexTypeStmt,
|
||||||
|
Expr,
|
||||||
|
ExtendStmt,
|
||||||
|
GetExpr,
|
||||||
|
GroupingExpr,
|
||||||
|
LiteralExpr,
|
||||||
|
LogicalExpr,
|
||||||
|
OpStmt,
|
||||||
|
PredicateStmt,
|
||||||
|
PropertyStmt,
|
||||||
|
SimpleTypeExpr,
|
||||||
|
SimpleTypeStmt,
|
||||||
|
Stmt,
|
||||||
|
TemplateExpr,
|
||||||
|
TypeExpr,
|
||||||
|
UnaryExpr,
|
||||||
|
VariableExpr,
|
||||||
|
WildcardExpr,
|
||||||
|
)
|
||||||
|
from midas.lexer.token import Token, TokenType
|
||||||
|
from midas.parser.base import Parser
|
||||||
|
from midas.parser.errors import ParsingError
|
||||||
|
|
||||||
|
|
||||||
|
class MidasParser(Parser):
|
||||||
|
"""A simple parser for midas type definitions"""
|
||||||
|
|
||||||
|
SYNC_BOUNDARY: set[TokenType] = {
|
||||||
|
TokenType.TYPE,
|
||||||
|
TokenType.OP,
|
||||||
|
TokenType.EXTEND,
|
||||||
|
TokenType.PREDICATE,
|
||||||
|
}
|
||||||
|
|
||||||
|
def parse(self) -> list[Stmt]:
|
||||||
|
statements: list[Stmt] = []
|
||||||
|
while not self.is_at_end():
|
||||||
|
stmt: Optional[Stmt] = self.declaration()
|
||||||
|
if stmt is None:
|
||||||
|
print("Early stop")
|
||||||
|
break
|
||||||
|
statements.append(stmt)
|
||||||
|
return statements
|
||||||
|
|
||||||
|
def synchronize(self):
|
||||||
|
"""Skip tokens until a synchronization boundary is found
|
||||||
|
|
||||||
|
This method allows gracefully recovering from a parse error
|
||||||
|
to a safe place and continue parsing
|
||||||
|
"""
|
||||||
|
self.advance()
|
||||||
|
while not self.is_at_end():
|
||||||
|
if self.previous().type == TokenType.NEWLINE:
|
||||||
|
return
|
||||||
|
if self.peek().type in self.SYNC_BOUNDARY:
|
||||||
|
return
|
||||||
|
self.advance()
|
||||||
|
|
||||||
|
def declaration(self) -> Optional[Stmt]:
|
||||||
|
"""Try and parse a declaration
|
||||||
|
|
||||||
|
Any parsing error is caught and None is returned
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.match(TokenType.TYPE):
|
||||||
|
return self.type_declaration()
|
||||||
|
if self.match(TokenType.EXTEND):
|
||||||
|
return self.extend_declaration()
|
||||||
|
if self.match(TokenType.PREDICATE):
|
||||||
|
return self.predicate_declaration()
|
||||||
|
raise self.error(self.peek(), "Unexpected token")
|
||||||
|
except ParsingError:
|
||||||
|
self.synchronize()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
|
||||||
|
"""Parse a type declaration
|
||||||
|
|
||||||
|
A type declaration can either be a simple type alias or a new complex type.
|
||||||
|
In either case, it can have an optional template expression after its name, wrapped in brackets.
|
||||||
|
A simple type alias is derived from a base type expression, and can have a optional constraint expression preceded by the `where` keyword.
|
||||||
|
A full simple type alias is thus written:
|
||||||
|
```
|
||||||
|
type Name[Template](TypeExpr) where Condition
|
||||||
|
```
|
||||||
|
|
||||||
|
A new complex type has a set of properties which are named, have a type and an optional constraint expression (also preceded by the `where` keyword).
|
||||||
|
A full complex type definition is thus written:
|
||||||
|
```
|
||||||
|
type Name[Template] {
|
||||||
|
prop1: TypeExpr1 where Condition1
|
||||||
|
prop2: TypeExpr2 where Condition2
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TypeStmt: the parsed type declaration statement
|
||||||
|
"""
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||||
|
template: Optional[TemplateExpr] = None
|
||||||
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
|
template = self.template_expr()
|
||||||
|
|
||||||
|
if self.match(TokenType.LEFT_PAREN):
|
||||||
|
base: TypeExpr = self.type_expr()
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
|
||||||
|
constraint: Optional[Expr] = None
|
||||||
|
if self.match(TokenType.WHERE):
|
||||||
|
constraint = self.constraint()
|
||||||
|
return SimpleTypeStmt(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
template=template,
|
||||||
|
base=base,
|
||||||
|
constraint=constraint,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
properties: list[PropertyStmt] = self.type_properties()
|
||||||
|
return ComplexTypeStmt(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
template=template,
|
||||||
|
properties=properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
def template_expr(self) -> TemplateExpr:
|
||||||
|
"""Parse a generic template expression
|
||||||
|
|
||||||
|
A template is written `[TypeExpr]`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TemplateExpr: the parsed template expression
|
||||||
|
"""
|
||||||
|
left: Token = self.consume(
|
||||||
|
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
|
||||||
|
)
|
||||||
|
type: TypeExpr = self.type_expr()
|
||||||
|
right: Token = self.consume(
|
||||||
|
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
|
||||||
|
)
|
||||||
|
return TemplateExpr(location=left.location_to(right), type=type)
|
||||||
|
|
||||||
|
def type_expr(self) -> TypeExpr:
|
||||||
|
"""Parse a type expression
|
||||||
|
|
||||||
|
A type is an identifier, optionally followed by a template expression.
|
||||||
|
It can also optionally be followed by a '?' to indicate a nullable type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TypeExpr: the parsed type expression
|
||||||
|
"""
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||||
|
template: Optional[TemplateExpr] = None
|
||||||
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
|
template = self.template_expr()
|
||||||
|
optional: bool = self.match(TokenType.QMARK)
|
||||||
|
return TypeExpr(
|
||||||
|
location=name.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
template=template,
|
||||||
|
optional=optional,
|
||||||
|
)
|
||||||
|
|
||||||
|
def simple_type_expr(self) -> SimpleTypeExpr:
|
||||||
|
"""Parse a simple type expression
|
||||||
|
|
||||||
|
A simple type is just an identifier optionally followed by a '?'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SimpleTypeExpr: the parsed simple type expression
|
||||||
|
"""
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||||
|
optional: bool = self.match(TokenType.QMARK)
|
||||||
|
return SimpleTypeExpr(
|
||||||
|
location=name.location_to(self.previous()), name=name, optional=optional
|
||||||
|
)
|
||||||
|
|
||||||
|
def constraint(self) -> Expr:
|
||||||
|
"""Parse a constraint
|
||||||
|
|
||||||
|
A constraint is basically a logical predicate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed constraint expression
|
||||||
|
"""
|
||||||
|
return self.and_()
|
||||||
|
|
||||||
|
def and_(self) -> Expr:
|
||||||
|
"""Parse a logical AND expression or a simpler expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed expression
|
||||||
|
"""
|
||||||
|
expr: Expr = self.equality()
|
||||||
|
while self.match(TokenType.AND):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.equality()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = LogicalExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def equality(self) -> Expr:
|
||||||
|
"""Parse a logical equality expression or a simpler expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed expression
|
||||||
|
"""
|
||||||
|
expr: Expr = self.comparison()
|
||||||
|
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.comparison()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = BinaryExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def comparison(self) -> Expr:
|
||||||
|
"""Parse a logical comparison expression or a simpler expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed expression
|
||||||
|
"""
|
||||||
|
expr: Expr = self.unary()
|
||||||
|
while self.match(
|
||||||
|
TokenType.LESS,
|
||||||
|
TokenType.LESS_EQUAL,
|
||||||
|
TokenType.GREATER,
|
||||||
|
TokenType.GREATER_EQUAL,
|
||||||
|
):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.unary()
|
||||||
|
location: Location = Location.span(expr.location, right.location)
|
||||||
|
expr = BinaryExpr(
|
||||||
|
location=location, left=expr, operator=operator, right=right
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def unary(self) -> Expr:
|
||||||
|
"""Parse a unary expression or a simpler expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed expression
|
||||||
|
"""
|
||||||
|
if self.match(TokenType.MINUS):
|
||||||
|
operator: Token = self.previous()
|
||||||
|
right: Expr = self.unary()
|
||||||
|
location: Location = Location.span(operator.get_location(), right.location)
|
||||||
|
return UnaryExpr(location=location, operator=operator, right=right)
|
||||||
|
return self.reference()
|
||||||
|
|
||||||
|
def reference(self) -> Expr:
|
||||||
|
"""Parse an attribute access expression or a simpler expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed expression
|
||||||
|
"""
|
||||||
|
expr: Expr = self.primary()
|
||||||
|
while self.match(TokenType.DOT):
|
||||||
|
name: Token = self.consume(
|
||||||
|
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||||
|
)
|
||||||
|
location: Location = Location.span(expr.location, name.get_location())
|
||||||
|
expr = GetExpr(location=location, expr=expr, name=name)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def primary(self) -> Expr:
|
||||||
|
"""Parse a primary expression
|
||||||
|
|
||||||
|
This includes literals (booleans, numbers, etc.), wildcards, identifiers and grouped expressions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expr: the parsed expression
|
||||||
|
"""
|
||||||
|
token: Token = self.peek()
|
||||||
|
if self.match(TokenType.FALSE):
|
||||||
|
return LiteralExpr(location=token.get_location(), value=False)
|
||||||
|
if self.match(TokenType.TRUE):
|
||||||
|
return LiteralExpr(location=token.get_location(), value=True)
|
||||||
|
if self.match(TokenType.NONE):
|
||||||
|
return LiteralExpr(location=token.get_location(), value=None)
|
||||||
|
|
||||||
|
if self.match(TokenType.NUMBER):
|
||||||
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
|
if self.match(TokenType.IDENTIFIER):
|
||||||
|
return VariableExpr(location=token.get_location(), name=token)
|
||||||
|
|
||||||
|
if self.match(TokenType.UNDERSCORE):
|
||||||
|
return WildcardExpr(location=token.get_location(), token=token)
|
||||||
|
|
||||||
|
if self.match(TokenType.LEFT_PAREN):
|
||||||
|
expr: Expr = self.constraint()
|
||||||
|
right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||||
|
return GroupingExpr(location=token.location_to(right), expr=expr)
|
||||||
|
|
||||||
|
raise self.error(self.peek(), "Expected expression")
|
||||||
|
|
||||||
|
def type_properties(self) -> list[PropertyStmt]:
|
||||||
|
"""Parse a type definition body
|
||||||
|
|
||||||
|
A type definition body is a set of whitespace-separated
|
||||||
|
property statements enclosed in curly braces
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[PropertyStmt]: the parsed type properties
|
||||||
|
"""
|
||||||
|
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
||||||
|
properties: list[PropertyStmt] = []
|
||||||
|
names: set[str] = set()
|
||||||
|
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||||
|
prop: PropertyStmt = self.property_stmt()
|
||||||
|
if prop.name.lexeme in names:
|
||||||
|
raise self.error(prop.name, "Duplicate property")
|
||||||
|
names.add(prop.name.lexeme)
|
||||||
|
properties.append(prop)
|
||||||
|
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||||
|
return properties
|
||||||
|
|
||||||
|
def property_stmt(self) -> PropertyStmt:
|
||||||
|
"""Parse a property statement
|
||||||
|
|
||||||
|
A type property statement is written `name: Type` or `name: Type where Condition`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PropertyStmt: the parsed property statement
|
||||||
|
"""
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||||
|
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||||
|
type: TypeExpr = self.type_expr()
|
||||||
|
constraint: Optional[Expr] = None
|
||||||
|
if self.match(TokenType.WHERE):
|
||||||
|
constraint = self.constraint()
|
||||||
|
return PropertyStmt(
|
||||||
|
location=name.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
constraint=constraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
def extend_declaration(self) -> ExtendStmt:
|
||||||
|
"""Parse an extension definition
|
||||||
|
|
||||||
|
An extension is written `extend Type { operations }`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtendStmt: the parsed extension statement
|
||||||
|
"""
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
type: TypeExpr = self.type_expr()
|
||||||
|
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||||
|
operations: list[OpStmt] = []
|
||||||
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||||
|
operations.append(self.op_declaration())
|
||||||
|
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||||
|
location: Location = keyword.location_to(self.previous())
|
||||||
|
return ExtendStmt(location=location, type=type, operations=operations)
|
||||||
|
|
||||||
|
def op_declaration(self) -> OpStmt:
|
||||||
|
"""Parse an operation definition
|
||||||
|
|
||||||
|
An operation is written `op name(Type) -> Type`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpStmt: the parsed operation statement
|
||||||
|
"""
|
||||||
|
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
|
||||||
|
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
||||||
|
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
||||||
|
operand: TypeExpr = self.type_expr()
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: TypeExpr = self.type_expr()
|
||||||
|
|
||||||
|
return OpStmt(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
operand=operand,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def predicate_declaration(self) -> PredicateStmt:
|
||||||
|
"""Parse a predicate declaration
|
||||||
|
|
||||||
|
A predicate is written `predicate Name(subject: Type) = constraint_expression`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
|
"""
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
||||||
|
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||||
|
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||||
|
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||||
|
type: TypeExpr = self.type_expr()
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||||
|
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||||
|
condition: Expr = self.constraint()
|
||||||
|
return PredicateStmt(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
subject=subject,
|
||||||
|
type=type,
|
||||||
|
condition=condition,
|
||||||
|
)
|
||||||
492
midas/parser/python.py
Normal file
492
midas/parser/python.py
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.python import (
|
||||||
|
AssignStmt,
|
||||||
|
BaseType,
|
||||||
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
|
CastExpr,
|
||||||
|
CompareExpr,
|
||||||
|
ConstraintType,
|
||||||
|
Expr,
|
||||||
|
ExpressionStmt,
|
||||||
|
FrameColumn,
|
||||||
|
FrameType,
|
||||||
|
Function,
|
||||||
|
GetExpr,
|
||||||
|
IfStmt,
|
||||||
|
LiteralExpr,
|
||||||
|
LogicalExpr,
|
||||||
|
MidasType,
|
||||||
|
ReturnStmt,
|
||||||
|
Stmt,
|
||||||
|
TernaryExpr,
|
||||||
|
TypeAssign,
|
||||||
|
UnaryExpr,
|
||||||
|
VariableExpr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSyntaxError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedSyntaxError(Exception):
|
||||||
|
def __init__(self, expr: ast.expr) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonParser:
|
||||||
|
CAST_FUNCTION = "cast"
|
||||||
|
|
||||||
|
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||||
|
statements: list[Stmt] = []
|
||||||
|
for stmt in node.body:
|
||||||
|
try:
|
||||||
|
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
|
||||||
|
if isinstance(parsed, Stmt):
|
||||||
|
statements.append(parsed)
|
||||||
|
elif parsed is not None:
|
||||||
|
statements.extend(parsed)
|
||||||
|
except UnsupportedSyntaxError as e:
|
||||||
|
print(f"{e}, skipping")
|
||||||
|
continue
|
||||||
|
return statements
|
||||||
|
|
||||||
|
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
||||||
|
location: Location = Location.from_ast(node)
|
||||||
|
match node:
|
||||||
|
case ast.AnnAssign():
|
||||||
|
return self.parse_annotation_assign(node)
|
||||||
|
|
||||||
|
case ast.Assign():
|
||||||
|
return self.parse_assign(node)
|
||||||
|
|
||||||
|
case ast.AugAssign():
|
||||||
|
return self.parse_aug_assign(node)
|
||||||
|
|
||||||
|
case ast.FunctionDef():
|
||||||
|
return self.parse_function(node)
|
||||||
|
|
||||||
|
case ast.Expr(value=expr):
|
||||||
|
return ExpressionStmt(
|
||||||
|
location=location,
|
||||||
|
expr=self.parse_expr(expr),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Return(value=value):
|
||||||
|
return ReturnStmt(
|
||||||
|
location=location,
|
||||||
|
value=self.parse_expr(value) if value is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.If():
|
||||||
|
return self.parse_if(node)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
||||||
|
statements: list[Stmt] = []
|
||||||
|
loc: Location = Location.from_ast(node)
|
||||||
|
match node:
|
||||||
|
case ast.AnnAssign(
|
||||||
|
target=ast.Name(id=target),
|
||||||
|
annotation=annotation,
|
||||||
|
value=value,
|
||||||
|
simple=1,
|
||||||
|
):
|
||||||
|
type = self._parse_type(annotation)
|
||||||
|
statements.append(
|
||||||
|
TypeAssign(
|
||||||
|
location=loc,
|
||||||
|
name=target,
|
||||||
|
type=type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if value is not None:
|
||||||
|
statements.append(
|
||||||
|
AssignStmt(
|
||||||
|
location=loc,
|
||||||
|
targets=[
|
||||||
|
VariableExpr(
|
||||||
|
location=Location.from_ast(node.target), name=target
|
||||||
|
),
|
||||||
|
],
|
||||||
|
value=self.parse_expr(value),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
print(f"Unsupported annotation: {ast.unparse(node)}")
|
||||||
|
return statements
|
||||||
|
|
||||||
|
def parse_assign(self, node: ast.Assign) -> AssignStmt:
|
||||||
|
targets: list[Expr] = []
|
||||||
|
for target in node.targets:
|
||||||
|
targets.append(self.parse_expr(target))
|
||||||
|
value: Expr = self.parse_expr(node.value)
|
||||||
|
return AssignStmt(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
targets=targets,
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
|
||||||
|
location: Location = Location.from_ast(node)
|
||||||
|
target: Expr = self.parse_expr(node.target)
|
||||||
|
value: Expr = self.parse_expr(node.value)
|
||||||
|
return AssignStmt(
|
||||||
|
location=location,
|
||||||
|
targets=[target],
|
||||||
|
value=BinaryExpr(
|
||||||
|
location=location,
|
||||||
|
left=target,
|
||||||
|
operator=node.op,
|
||||||
|
right=value,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_if(self, node: ast.If) -> IfStmt:
|
||||||
|
body: list[Stmt] = []
|
||||||
|
for stmt in node.body:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
body.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
body.extend(stmts)
|
||||||
|
|
||||||
|
orelse: list[Stmt] = []
|
||||||
|
for stmt in node.orelse:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
orelse.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
orelse.extend(stmts)
|
||||||
|
|
||||||
|
return IfStmt(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
test=self.parse_expr(node.test),
|
||||||
|
body=body,
|
||||||
|
orelse=orelse,
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||||
|
loc: Location = Location.from_ast(node)
|
||||||
|
match node:
|
||||||
|
case ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=posonlyargs,
|
||||||
|
args=args,
|
||||||
|
vararg=sink,
|
||||||
|
kwonlyargs=kwonlyargs,
|
||||||
|
kwarg=kw_sink,
|
||||||
|
defaults=defaults,
|
||||||
|
kw_defaults=kw_defaults,
|
||||||
|
),
|
||||||
|
returns=returns,
|
||||||
|
body=raw_body,
|
||||||
|
):
|
||||||
|
|
||||||
|
def parse_args(
|
||||||
|
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||||
|
) -> list[Function.Argument]:
|
||||||
|
return [
|
||||||
|
self._parse_function_argument(arg, default)
|
||||||
|
for arg, default in zip(args_list, defaults)
|
||||||
|
]
|
||||||
|
|
||||||
|
body: list[Stmt] = []
|
||||||
|
for stmt in raw_body:
|
||||||
|
stmts = self.parse_stmt(stmt)
|
||||||
|
if isinstance(stmts, Stmt):
|
||||||
|
body.append(stmts)
|
||||||
|
elif stmts is not None:
|
||||||
|
body.extend(stmts)
|
||||||
|
|
||||||
|
parsed_defaults: list[Optional[Expr]] = [
|
||||||
|
self.parse_expr(default) for default in defaults
|
||||||
|
]
|
||||||
|
n_posargs: int = len(posonlyargs)
|
||||||
|
n_args: int = len(args)
|
||||||
|
n_all_posargs = n_posargs + n_args
|
||||||
|
parsed_defaults = [
|
||||||
|
None,
|
||||||
|
] * (n_all_posargs - len(defaults)) + parsed_defaults
|
||||||
|
|
||||||
|
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
|
||||||
|
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
|
||||||
|
kwargs_defaults: list[Optional[Expr]] = [
|
||||||
|
self.parse_expr(default) if default is not None else None
|
||||||
|
for default in kw_defaults
|
||||||
|
]
|
||||||
|
|
||||||
|
return Function(
|
||||||
|
location=loc,
|
||||||
|
name=name,
|
||||||
|
posonlyargs=parse_args(posonlyargs, posargs_defaults),
|
||||||
|
args=parse_args(args, args_defaults),
|
||||||
|
sink=(
|
||||||
|
self._parse_function_argument(sink, None)
|
||||||
|
if sink is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
|
||||||
|
kw_sink=(
|
||||||
|
self._parse_function_argument(kw_sink, None)
|
||||||
|
if kw_sink is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
returns=self._parse_type(returns) if returns is not None else None,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||||
|
|
||||||
|
def _parse_function_argument(
|
||||||
|
self, arg: ast.arg, default: Optional[Expr]
|
||||||
|
) -> Function.Argument:
|
||||||
|
loc: Location = Location.from_ast(arg)
|
||||||
|
name: str = arg.arg
|
||||||
|
type: Optional[MidasType] = None
|
||||||
|
if arg.annotation is not None:
|
||||||
|
type = self._parse_type(arg.annotation)
|
||||||
|
return Function.Argument(
|
||||||
|
location=loc,
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
default=default,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_type(self, type_expr: ast.expr) -> MidasType:
|
||||||
|
loc: Location = Location.from_ast(type_expr)
|
||||||
|
match type_expr:
|
||||||
|
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||||
|
return self._parse_frame_type(schema)
|
||||||
|
|
||||||
|
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||||
|
return BaseType(
|
||||||
|
location=loc,
|
||||||
|
base=name,
|
||||||
|
param=self._parse_type(param),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Name(id=name):
|
||||||
|
return BaseType(
|
||||||
|
location=loc,
|
||||||
|
base=name,
|
||||||
|
param=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||||
|
left = self._parse_type(left_expr)
|
||||||
|
match left:
|
||||||
|
case None:
|
||||||
|
raise InvalidSyntaxError()
|
||||||
|
|
||||||
|
# If chained constraints, separate base type and rebuild constraint
|
||||||
|
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||||
|
constraint = ast.BinOp(
|
||||||
|
left=left_constraint,
|
||||||
|
op=ast.Add(),
|
||||||
|
right=right_expr,
|
||||||
|
)
|
||||||
|
ast.copy_location(constraint, type_expr)
|
||||||
|
return ConstraintType(
|
||||||
|
location=loc,
|
||||||
|
type=left_type,
|
||||||
|
constraint=constraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return ConstraintType(
|
||||||
|
location=loc,
|
||||||
|
type=left,
|
||||||
|
constraint=right_expr,
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise UnsupportedSyntaxError(type_expr)
|
||||||
|
|
||||||
|
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
|
||||||
|
loc: Location = Location.from_ast(schema)
|
||||||
|
columns: list[FrameColumn] = []
|
||||||
|
|
||||||
|
match schema:
|
||||||
|
case ast.Tuple(elts=cols):
|
||||||
|
for col in cols:
|
||||||
|
columns.append(self._parse_frame_column(col))
|
||||||
|
|
||||||
|
case ast.Slice() | ast.Name():
|
||||||
|
columns.append(self._parse_frame_column(schema))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise UnsupportedSyntaxError(schema)
|
||||||
|
|
||||||
|
return FrameType(location=loc, columns=columns)
|
||||||
|
|
||||||
|
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
|
||||||
|
loc: Location = Location.from_ast(column)
|
||||||
|
match column:
|
||||||
|
case ast.Name():
|
||||||
|
return FrameColumn(
|
||||||
|
location=loc,
|
||||||
|
name=None,
|
||||||
|
type=self._parse_type(column),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
|
||||||
|
if name == "_":
|
||||||
|
name = None
|
||||||
|
|
||||||
|
type: Optional[MidasType] = None
|
||||||
|
match type_expr:
|
||||||
|
case None:
|
||||||
|
raise InvalidSyntaxError("Missing column type")
|
||||||
|
case ast.Name(id="_"):
|
||||||
|
type = None
|
||||||
|
case ast.expr():
|
||||||
|
type = self._parse_type(type_expr)
|
||||||
|
case _:
|
||||||
|
raise UnsupportedSyntaxError(type_expr)
|
||||||
|
return FrameColumn(location=loc, name=name, type=type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise UnsupportedSyntaxError(column)
|
||||||
|
|
||||||
|
def parse_expr(self, node: ast.expr) -> Expr:
|
||||||
|
location: Location = Location.from_ast(node)
|
||||||
|
match node:
|
||||||
|
case ast.BoolOp():
|
||||||
|
return self.parse_bool_op(node)
|
||||||
|
|
||||||
|
case ast.BinOp(left=left, op=op, right=right):
|
||||||
|
return BinaryExpr(
|
||||||
|
location=location,
|
||||||
|
left=self.parse_expr(left),
|
||||||
|
operator=op,
|
||||||
|
right=self.parse_expr(right),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.UnaryOp(op=op, operand=right):
|
||||||
|
return UnaryExpr(
|
||||||
|
location=location,
|
||||||
|
operator=op,
|
||||||
|
right=self.parse_expr(right),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Compare():
|
||||||
|
return self.parse_compare(node)
|
||||||
|
|
||||||
|
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||||
|
return self.parse_cast(node)
|
||||||
|
|
||||||
|
case ast.Call():
|
||||||
|
return self.parse_call(node)
|
||||||
|
|
||||||
|
case ast.IfExp():
|
||||||
|
return self.parse_ternary(node)
|
||||||
|
|
||||||
|
case ast.Constant(value=value):
|
||||||
|
return LiteralExpr(location=location, value=value)
|
||||||
|
|
||||||
|
case ast.Attribute(value=object, attr=name):
|
||||||
|
return GetExpr(
|
||||||
|
location=location,
|
||||||
|
object=self.parse_expr(object),
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Name(id=name):
|
||||||
|
return VariableExpr(location=location, name=name)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise UnsupportedSyntaxError(node)
|
||||||
|
|
||||||
|
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
||||||
|
op: ast.boolop = node.op
|
||||||
|
rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
|
||||||
|
expr: LogicalExpr = LogicalExpr(
|
||||||
|
location=Location.span(
|
||||||
|
rights[0].location,
|
||||||
|
rights[1].location,
|
||||||
|
),
|
||||||
|
left=rights[0],
|
||||||
|
operator=op,
|
||||||
|
right=rights[1],
|
||||||
|
)
|
||||||
|
for right in rights[2:]:
|
||||||
|
expr = LogicalExpr(
|
||||||
|
location=Location.span(expr.location, right.location),
|
||||||
|
left=expr,
|
||||||
|
operator=op,
|
||||||
|
right=right,
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def parse_compare(self, node: ast.Compare) -> Expr:
|
||||||
|
ops: list[ast.cmpop] = node.ops
|
||||||
|
left: Expr = self.parse_expr(node.left)
|
||||||
|
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
|
||||||
|
expr: Expr = CompareExpr(
|
||||||
|
location=Location.span(
|
||||||
|
left.location,
|
||||||
|
rights[0].location,
|
||||||
|
),
|
||||||
|
left=left,
|
||||||
|
operator=ops[0],
|
||||||
|
right=rights[0],
|
||||||
|
)
|
||||||
|
for i, right in enumerate(rights[1:]):
|
||||||
|
comparison = CompareExpr(
|
||||||
|
location=Location.span(rights[i].location, right.location),
|
||||||
|
left=rights[i],
|
||||||
|
operator=ops[i],
|
||||||
|
right=right,
|
||||||
|
)
|
||||||
|
expr = LogicalExpr(
|
||||||
|
location=Location.span(expr.location, comparison.location),
|
||||||
|
left=expr,
|
||||||
|
operator=ast.And(),
|
||||||
|
right=comparison,
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||||
|
match node:
|
||||||
|
case ast.Call(args=[type, expr], keywords=[]):
|
||||||
|
return CastExpr(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
type=self._parse_type(type),
|
||||||
|
expr=self.parse_expr(expr),
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||||
|
return CallExpr(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
callee=self.parse_expr(node.func),
|
||||||
|
arguments=[self.parse_expr(arg) for arg in node.args],
|
||||||
|
keywords={
|
||||||
|
arg.arg: self.parse_expr(arg.value)
|
||||||
|
for arg in node.keywords
|
||||||
|
if arg.arg is not None # Should always be True, type checker happy
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
|
||||||
|
return TernaryExpr(
|
||||||
|
location=Location.from_ast(node),
|
||||||
|
test=self.parse_expr(node.test),
|
||||||
|
if_true=self.parse_expr(node.body),
|
||||||
|
if_false=self.parse_expr(node.orelse),
|
||||||
|
)
|
||||||
71
midas/resolver/builtin.py
Normal file
71
midas/resolver/builtin.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from midas.checker.types import BaseType, Type, UnitType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.resolver.midas import MidasResolver
|
||||||
|
|
||||||
|
|
||||||
|
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||||
|
ctx.define_operation(
|
||||||
|
left=t1,
|
||||||
|
operator=operator,
|
||||||
|
right=t2,
|
||||||
|
result=t3,
|
||||||
|
)
|
||||||
|
|
||||||
|
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||||
|
ctx.define_operation(
|
||||||
|
left=type,
|
||||||
|
operator=op,
|
||||||
|
right=type,
|
||||||
|
result=type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def define_builtins(ctx: MidasResolver):
|
||||||
|
"""Define builtin types and operations"""
|
||||||
|
unit = ctx.define_type("None", UnitType())
|
||||||
|
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||||
|
int = ctx.define_type("int", BaseType(name="int"))
|
||||||
|
float = ctx.define_type("float", BaseType(name="float"))
|
||||||
|
str = ctx.define_type("str", BaseType(name="str"))
|
||||||
|
|
||||||
|
basic_op(ctx, int, "__add__") # int + int = int
|
||||||
|
basic_op(ctx, int, "__sub__") # int - int = int
|
||||||
|
basic_op(ctx, int, "__mul__") # int * int = int
|
||||||
|
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||||
|
basic_op(ctx, int, "__mod__") # int % int = int
|
||||||
|
basic_op(ctx, int, "__and__") # int & int = int
|
||||||
|
basic_op(ctx, int, "__or__") # int | int = int
|
||||||
|
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||||
|
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||||
|
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||||
|
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||||
|
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||||
|
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||||
|
basic_op(ctx, float, "__add__") # float + float = float
|
||||||
|
basic_op(ctx, float, "__sub__") # float - float = float
|
||||||
|
basic_op(ctx, float, "__mul__") # float * float = float
|
||||||
|
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||||
|
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||||
|
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||||
|
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||||
|
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||||
|
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||||
|
basic_op(ctx, str, "__add__") # str + str = str
|
||||||
|
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||||
|
|
||||||
|
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||||
|
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||||
|
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||||
|
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||||
|
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||||
|
|
||||||
|
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||||
|
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||||
|
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||||
|
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||||
|
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||||
153
midas/resolver/midas.py
Normal file
153
midas/resolver/midas.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.types import BaseType, SimpleType, Type
|
||||||
|
from midas.resolver.builtin import define_builtins
|
||||||
|
|
||||||
|
|
||||||
|
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
|
||||||
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._types: dict[str, Type] = {}
|
||||||
|
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
||||||
|
|
||||||
|
define_builtins(self)
|
||||||
|
|
||||||
|
def get_type(self, name: str) -> Type:
|
||||||
|
"""Get a type from its name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: if the type is not defined
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type
|
||||||
|
"""
|
||||||
|
type: Optional[Type] = self._types.get(name)
|
||||||
|
if type is None:
|
||||||
|
raise NameError(f"Undefined type {name}")
|
||||||
|
return type
|
||||||
|
|
||||||
|
def get_operation_result(
|
||||||
|
self, left: Type, operator: str, right: Type
|
||||||
|
) -> Optional[Type]:
|
||||||
|
"""Get the resulting type of an operation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
left (Type): the type of the left operand
|
||||||
|
operator (str): the operation name
|
||||||
|
right (Type): the type of the right operand
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the result type, or None if no matching operation was found
|
||||||
|
"""
|
||||||
|
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||||
|
result: Optional[Type] = self._operations.get(operation)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def define_type(self, name: str, type: Type) -> Type:
|
||||||
|
"""Define a type in the registry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
type (Type): the type to define
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if a type is already defined with that name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the defined type
|
||||||
|
"""
|
||||||
|
if name in self._types:
|
||||||
|
raise ValueError(f"Type {name} already defined")
|
||||||
|
self._types[name] = type
|
||||||
|
return type
|
||||||
|
|
||||||
|
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||||
|
"""Define an operation in the registry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
left (Type): the type of the left operand
|
||||||
|
operator (str): the operation name
|
||||||
|
right (Type): the type of the right operand
|
||||||
|
result (Type): the result type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if an operation is already defined with these operands and name
|
||||||
|
"""
|
||||||
|
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||||
|
if operation in self._operations:
|
||||||
|
raise ValueError(
|
||||||
|
f"Operation {operator} already defined between {left} and {right}"
|
||||||
|
)
|
||||||
|
self._operations[operation] = result
|
||||||
|
|
||||||
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
|
"""Process a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stmts (list[m.Stmt]): the statements
|
||||||
|
"""
|
||||||
|
for stmt in stmts:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||||
|
# TODO generics, optional, constraint
|
||||||
|
base: Type = self.get_type(stmt.base.name.lexeme)
|
||||||
|
match base:
|
||||||
|
case BaseType() | SimpleType():
|
||||||
|
type = SimpleType(
|
||||||
|
name=stmt.name.lexeme,
|
||||||
|
base=base,
|
||||||
|
)
|
||||||
|
self.define_type(type.name, type)
|
||||||
|
case _:
|
||||||
|
raise TypeError(f"Invalid base {base} for simple type")
|
||||||
|
|
||||||
|
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
base: Type = stmt.type.accept(self)
|
||||||
|
for op in stmt.operations:
|
||||||
|
right: Type = op.operand.accept(self)
|
||||||
|
result: Type = op.result.accept(self)
|
||||||
|
self.define_operation(
|
||||||
|
left=base,
|
||||||
|
operator=op.name.lexeme,
|
||||||
|
right=right,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type:
|
||||||
|
return self.get_type(expr.name.lexeme)
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||||
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ...
|
||||||
|
|
||||||
|
def visit_type_expr(self, expr: m.TypeExpr) -> Type:
|
||||||
|
return self.get_type(expr.name.lexeme)
|
||||||
187
midas/resolver/resolver.py
Normal file
187
midas/resolver/resolver.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
import midas.ast.python as p
|
||||||
|
|
||||||
|
|
||||||
|
class ResolverError(Exception): ...
|
||||||
|
|
||||||
|
|
||||||
|
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||||
|
"""A variable assignment and reference resolver
|
||||||
|
|
||||||
|
This class keeps track of which scope a variable is defined in and which
|
||||||
|
scope is referred to when a variable is referenced
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.locals: dict[p.Expr, int] = {}
|
||||||
|
self.scopes: list[dict[str, bool]] = []
|
||||||
|
|
||||||
|
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||||
|
"""Resolve the given statements or expressions"""
|
||||||
|
|
||||||
|
for obj in objects:
|
||||||
|
obj.accept(self)
|
||||||
|
|
||||||
|
def begin_scope(self):
|
||||||
|
"""Begin a new scope inside the current one"""
|
||||||
|
self.scopes.append({})
|
||||||
|
|
||||||
|
def end_scope(self):
|
||||||
|
"""Close the current scope"""
|
||||||
|
self.scopes.pop()
|
||||||
|
|
||||||
|
def declare(self, name: str) -> None:
|
||||||
|
"""Declare a variable in the current scope
|
||||||
|
|
||||||
|
This method must be called *before* evaluating the variable initializer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResolverError: if the variable has already been declared in the current scope
|
||||||
|
"""
|
||||||
|
if len(self.scopes) == 0:
|
||||||
|
return
|
||||||
|
scope: dict[str, bool] = self.scopes[-1]
|
||||||
|
if name in scope:
|
||||||
|
raise ResolverError(
|
||||||
|
f"A variable with the name {name} is already declared in this scope"
|
||||||
|
)
|
||||||
|
scope[name] = False
|
||||||
|
|
||||||
|
def define(self, name: str) -> None:
|
||||||
|
"""Define a variable in the current scope
|
||||||
|
|
||||||
|
This method must be called *after* evaluating the variable initializer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
"""
|
||||||
|
if len(self.scopes) == 0:
|
||||||
|
return
|
||||||
|
self.scopes[-1][name] = True
|
||||||
|
|
||||||
|
def resolve_local(self, expr: p.Expr, name: str) -> None:
|
||||||
|
"""Resolve a variable reference and store the scope distance
|
||||||
|
|
||||||
|
This method associates to the variable expression a number representing
|
||||||
|
the "distance" of the variable declaration, i.e. the number of scope
|
||||||
|
levels to go "up" to find the closest declaration for that variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the variable expression
|
||||||
|
name (str): the name of the variable
|
||||||
|
"""
|
||||||
|
for i, scope in enumerate(reversed(self.scopes)):
|
||||||
|
if name in scope:
|
||||||
|
self.locals[expr] = i
|
||||||
|
return
|
||||||
|
|
||||||
|
def resolve_function(self, function: p.Function) -> None:
|
||||||
|
"""Resolve a function definition
|
||||||
|
|
||||||
|
This method creates a new scope for the function, resolves all the
|
||||||
|
parameter declarations and then the body.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (p.Function): the function to resolve
|
||||||
|
"""
|
||||||
|
self.begin_scope()
|
||||||
|
for param in function.all_args:
|
||||||
|
self.declare(param.name)
|
||||||
|
self.define(param.name)
|
||||||
|
self.resolve(*function.body)
|
||||||
|
self.end_scope()
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
stmt.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
# Declare before resolving body to allow recursion
|
||||||
|
self.declare(stmt.name)
|
||||||
|
self.define(stmt.name)
|
||||||
|
self.resolve_function(stmt)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
self.declare(stmt.name)
|
||||||
|
# NOTE: resolve type here?
|
||||||
|
self.define(stmt.name)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
self.resolve(stmt.value)
|
||||||
|
for target in stmt.targets:
|
||||||
|
match target:
|
||||||
|
case p.VariableExpr(name=name):
|
||||||
|
self.resolve_local(target, name)
|
||||||
|
# TODO: declare if not found
|
||||||
|
case _:
|
||||||
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
if stmt.value is not None:
|
||||||
|
self.resolve(stmt.value)
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not resolved in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
self.resolve(stmt.test)
|
||||||
|
|
||||||
|
# Body
|
||||||
|
self.begin_scope()
|
||||||
|
self.resolve(*stmt.body)
|
||||||
|
self.end_scope()
|
||||||
|
|
||||||
|
# Else
|
||||||
|
self.begin_scope()
|
||||||
|
self.resolve(*stmt.orelse)
|
||||||
|
self.end_scope()
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||||
|
self.resolve(expr.left)
|
||||||
|
self.resolve(expr.right)
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||||
|
self.resolve(expr.left)
|
||||||
|
self.resolve(expr.right)
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||||
|
self.resolve(expr.right)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||||
|
self.resolve(expr.callee)
|
||||||
|
for arg in expr.arguments:
|
||||||
|
self.resolve(arg)
|
||||||
|
for arg in expr.keywords.values():
|
||||||
|
self.resolve(arg)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||||
|
self.resolve(expr.object)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||||
|
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False:
|
||||||
|
raise ResolverError(
|
||||||
|
f"Cannot use local variable '{expr.name}' in its own initializer"
|
||||||
|
) # aka. UnboundLocalError
|
||||||
|
self.resolve_local(expr, expr.name)
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||||
|
self.resolve(expr.left)
|
||||||
|
self.resolve(expr.right)
|
||||||
|
|
||||||
|
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||||
|
self.resolve(expr.value)
|
||||||
|
self.resolve(expr.object)
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||||
|
self.resolve(expr.expr)
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||||
|
self.resolve(expr.test)
|
||||||
|
self.resolve(expr.if_true)
|
||||||
|
self.resolve(expr.if_false)
|
||||||
54
midas/utils.py
Normal file
54
midas/utils.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
AllowRepeat = Callable[[object], bool]
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalJSONDumper:
|
||||||
|
@classmethod
|
||||||
|
def dump(
|
||||||
|
cls,
|
||||||
|
obj: Any,
|
||||||
|
include_keys: Optional[list[str | tuple[str, str]]] = None,
|
||||||
|
allow_repeat: Optional[AllowRepeat] = None,
|
||||||
|
) -> Any:
|
||||||
|
if include_keys is None:
|
||||||
|
include_keys = []
|
||||||
|
return cls._dump(obj, include_keys, allow_repeat, [])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _dump(
|
||||||
|
cls,
|
||||||
|
obj: Any,
|
||||||
|
include_keys: list[str | tuple[str, str]],
|
||||||
|
allow_repeat: Optional[AllowRepeat],
|
||||||
|
visited: list[Any],
|
||||||
|
) -> Any:
|
||||||
|
if obj in visited:
|
||||||
|
return None
|
||||||
|
match obj:
|
||||||
|
case str() | int() | float() | None:
|
||||||
|
return obj
|
||||||
|
case list() | set() | tuple():
|
||||||
|
return [
|
||||||
|
cls._dump(child, include_keys, allow_repeat, visited)
|
||||||
|
for child in obj
|
||||||
|
]
|
||||||
|
case dict():
|
||||||
|
return {
|
||||||
|
str(k): cls._dump(v, include_keys, allow_repeat, visited)
|
||||||
|
for k, v in obj.items()
|
||||||
|
}
|
||||||
|
case object():
|
||||||
|
if allow_repeat is None or not allow_repeat(obj):
|
||||||
|
visited.append(obj)
|
||||||
|
return {
|
||||||
|
"_type": obj.__class__.__name__,
|
||||||
|
} | {
|
||||||
|
k: cls._dump(v, include_keys, allow_repeat, visited)
|
||||||
|
for k, v in obj.__dict__.items()
|
||||||
|
if not k.startswith("_")
|
||||||
|
or k in include_keys
|
||||||
|
or (obj.__class__.__name__, k) in include_keys
|
||||||
|
}
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported value: {obj}")
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from core.ast.annotations import (
|
|
||||||
AnnotationStmt,
|
|
||||||
ConstraintExpr,
|
|
||||||
Expr,
|
|
||||||
LiteralExpr,
|
|
||||||
SchemaElementExpr,
|
|
||||||
SchemaExpr,
|
|
||||||
Stmt,
|
|
||||||
TypeExpr,
|
|
||||||
WildcardExpr,
|
|
||||||
)
|
|
||||||
from lexer.token import Token, TokenType
|
|
||||||
from parser.base import Parser
|
|
||||||
from parser.errors import ParsingError
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationParser(Parser):
|
|
||||||
"""A simple parser for custom type annotations"""
|
|
||||||
|
|
||||||
SYNC_BOUNDARY: set[TokenType] = set()
|
|
||||||
|
|
||||||
def parse(self) -> Optional[Stmt]:
|
|
||||||
stmt: Optional[Stmt] = None
|
|
||||||
try:
|
|
||||||
stmt = self.annotation()
|
|
||||||
except ParsingError:
|
|
||||||
self.synchronize()
|
|
||||||
if not self.is_at_end():
|
|
||||||
self.error(self.peek(), "Extra tokens")
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
def synchronize(self):
|
|
||||||
"""Skip tokens until a synchronization boundary is found
|
|
||||||
|
|
||||||
This method allows gracefully recovering from a parse error
|
|
||||||
to a safe place and continue parsing
|
|
||||||
"""
|
|
||||||
self.advance()
|
|
||||||
while not self.is_at_end():
|
|
||||||
if self.peek().type in self.SYNC_BOUNDARY:
|
|
||||||
return
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
def annotation(self) -> AnnotationStmt:
|
|
||||||
"""Parse an annotation
|
|
||||||
|
|
||||||
An annotation is written as `Type` or `Type[Schema]`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AnnotationStmt: the parsed annotation statement
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
|
|
||||||
schema: Optional[SchemaExpr] = None
|
|
||||||
if self.match(TokenType.LEFT_BRACKET):
|
|
||||||
schema = self.schema()
|
|
||||||
return AnnotationStmt(name=name, schema=schema)
|
|
||||||
|
|
||||||
def type_expr(self) -> TypeExpr:
|
|
||||||
"""Parse a type expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TypeExpr: the parsed type expression
|
|
||||||
"""
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
|
||||||
constraints: list[ConstraintExpr] = []
|
|
||||||
|
|
||||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
|
||||||
constraints.append(self.constraint_expr())
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
|
||||||
|
|
||||||
return TypeExpr(name=name, constraints=constraints)
|
|
||||||
|
|
||||||
def constraint_expr(self) -> ConstraintExpr:
|
|
||||||
"""Parse a type constraint
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConstraintExpr: the parsed type constraint expression
|
|
||||||
"""
|
|
||||||
|
|
||||||
left: Expr = self.constraint_value()
|
|
||||||
op: Token = self.constraint_operator()
|
|
||||||
right: Expr = self.constraint_value()
|
|
||||||
return ConstraintExpr(left=left, op=op, right=right)
|
|
||||||
|
|
||||||
def constraint_value(self) -> Expr:
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
|
||||||
return WildcardExpr(self.previous())
|
|
||||||
return self.literal()
|
|
||||||
|
|
||||||
def literal(self) -> LiteralExpr:
|
|
||||||
if self.match(TokenType.FALSE):
|
|
||||||
return LiteralExpr(False)
|
|
||||||
if self.match(TokenType.TRUE):
|
|
||||||
return LiteralExpr(True)
|
|
||||||
if self.match(TokenType.NONE):
|
|
||||||
return LiteralExpr(None)
|
|
||||||
|
|
||||||
if self.match(TokenType.NUMBER):
|
|
||||||
return LiteralExpr(self.previous().value)
|
|
||||||
|
|
||||||
raise self.error(self.peek(), "Expected literal")
|
|
||||||
|
|
||||||
def constraint_operator(self) -> Token:
|
|
||||||
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
|
|
||||||
return self.previous()
|
|
||||||
raise self.error(self.peek(), "Expected constraint operator")
|
|
||||||
|
|
||||||
def schema(self) -> SchemaExpr:
|
|
||||||
"""Parse a schema definition
|
|
||||||
|
|
||||||
A comma separated list of schema elements
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SchemaExpr: the parsed schema expression
|
|
||||||
"""
|
|
||||||
left: Token = self.previous()
|
|
||||||
elements: list[Expr] = []
|
|
||||||
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
|
||||||
elements.append(self.schema_element())
|
|
||||||
if not self.check(TokenType.RIGHT_BRACKET):
|
|
||||||
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
|
|
||||||
|
|
||||||
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
|
|
||||||
return SchemaExpr(left=left, elements=elements, right=right)
|
|
||||||
|
|
||||||
def schema_element(self) -> SchemaElementExpr:
|
|
||||||
"""Parse a schema element
|
|
||||||
|
|
||||||
An anonymous element (`_`), a type, an untyped named column (`name: _`),
|
|
||||||
or a named column (`name: Type`)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SchemaElementExpr: the parsed schema element expression
|
|
||||||
"""
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
|
||||||
return SchemaElementExpr(name=None, type=None)
|
|
||||||
|
|
||||||
if not self.check(TokenType.IDENTIFIER):
|
|
||||||
raise self.error(self.peek(), "Expected schema element")
|
|
||||||
|
|
||||||
name: Optional[Token] = None
|
|
||||||
type: Optional[TypeExpr] = None
|
|
||||||
if self.check_next(TokenType.COLON):
|
|
||||||
name = self.advance()
|
|
||||||
self.advance()
|
|
||||||
if not self.match(TokenType.UNDERSCORE):
|
|
||||||
type = self.type_expr()
|
|
||||||
return SchemaElementExpr(name=name, type=type)
|
|
||||||
217
parser/midas.py
217
parser/midas.py
@@ -1,217 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from core.ast.midas import (
|
|
||||||
ConstraintExpr,
|
|
||||||
ConstraintStmt,
|
|
||||||
Expr,
|
|
||||||
LiteralExpr,
|
|
||||||
OpStmt,
|
|
||||||
PropertyStmt,
|
|
||||||
Stmt,
|
|
||||||
TypeBodyExpr,
|
|
||||||
TypeExpr,
|
|
||||||
TypeStmt,
|
|
||||||
WildcardExpr,
|
|
||||||
)
|
|
||||||
from lexer.token import Token, TokenType
|
|
||||||
from parser.base import Parser
|
|
||||||
from parser.errors import ParsingError
|
|
||||||
|
|
||||||
|
|
||||||
class MidasParser(Parser):
|
|
||||||
"""A simple parser for midas type definitions"""
|
|
||||||
|
|
||||||
SYNC_BOUNDARY: set[TokenType] = {TokenType.TYPE, TokenType.OP, TokenType.CONSTRAINT}
|
|
||||||
|
|
||||||
def parse(self) -> list[Stmt]:
|
|
||||||
statements: list[Stmt] = []
|
|
||||||
while not self.is_at_end():
|
|
||||||
stmt: Optional[Stmt] = self.declaration()
|
|
||||||
if stmt is None:
|
|
||||||
print("Early stop")
|
|
||||||
break
|
|
||||||
statements.append(stmt)
|
|
||||||
return statements
|
|
||||||
|
|
||||||
def synchronize(self):
|
|
||||||
"""Skip tokens until a synchronization boundary is found
|
|
||||||
|
|
||||||
This method allows gracefully recovering from a parse error
|
|
||||||
to a safe place and continue parsing
|
|
||||||
"""
|
|
||||||
self.advance()
|
|
||||||
while not self.is_at_end():
|
|
||||||
if self.previous().type == TokenType.NEWLINE:
|
|
||||||
return
|
|
||||||
if self.peek().type in self.SYNC_BOUNDARY:
|
|
||||||
return
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
def declaration(self) -> Optional[Stmt]:
|
|
||||||
"""Try and parse a declaration
|
|
||||||
|
|
||||||
Any parsing error is caught and None is returned
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if self.match(TokenType.TYPE):
|
|
||||||
return self.type_declaration()
|
|
||||||
if self.match(TokenType.OP):
|
|
||||||
return self.op_declaration()
|
|
||||||
if self.match(TokenType.CONSTRAINT):
|
|
||||||
return self.constraint_declaration()
|
|
||||||
raise self.error(self.peek(), "Unexpected token")
|
|
||||||
except ParsingError:
|
|
||||||
self.synchronize()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def type_declaration(self) -> TypeStmt:
|
|
||||||
"""Parse a type declaration
|
|
||||||
|
|
||||||
A type declaration is written `type Name<TypeExpr, ...>` optionally followed by a brace-wrapped body
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TypeStmt: the parsed type declaration statement
|
|
||||||
"""
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
|
||||||
self.consume(TokenType.LESS, "Expected '<' after type name")
|
|
||||||
bases: list[TypeExpr] = []
|
|
||||||
while not self.check(TokenType.GREATER) and not self.is_at_end():
|
|
||||||
bases.append(self.type_expr())
|
|
||||||
if not self.check(TokenType.GREATER):
|
|
||||||
self.consume(TokenType.COMMA, "Expected ',' between type bases")
|
|
||||||
self.consume(TokenType.GREATER, "Expected '>' after base type")
|
|
||||||
|
|
||||||
body: Optional[TypeBodyExpr] = None
|
|
||||||
|
|
||||||
if self.check(TokenType.LEFT_BRACE):
|
|
||||||
body = self.type_body_expr()
|
|
||||||
return TypeStmt(name=name, bases=bases, body=body)
|
|
||||||
|
|
||||||
def type_expr(self) -> TypeExpr:
|
|
||||||
"""Parse a type expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TypeExpr: the parsed type expression
|
|
||||||
"""
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
|
||||||
constraints: list[ConstraintExpr] = []
|
|
||||||
|
|
||||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
|
||||||
constraints.append(self.constraint_expr())
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
|
||||||
|
|
||||||
return TypeExpr(name=name, constraints=constraints)
|
|
||||||
|
|
||||||
def constraint_expr(self) -> ConstraintExpr:
|
|
||||||
"""Parse a type constraint
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConstraintExpr: the parsed type constraint expression
|
|
||||||
"""
|
|
||||||
|
|
||||||
left: Expr = self.constraint_value()
|
|
||||||
op: Token = self.constraint_operator()
|
|
||||||
right: Expr = self.constraint_value()
|
|
||||||
return ConstraintExpr(left=left, op=op, right=right)
|
|
||||||
|
|
||||||
def constraint_value(self) -> Expr:
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
|
||||||
return WildcardExpr(self.previous())
|
|
||||||
return self.literal()
|
|
||||||
|
|
||||||
def literal(self) -> LiteralExpr:
|
|
||||||
if self.match(TokenType.FALSE):
|
|
||||||
return LiteralExpr(False)
|
|
||||||
if self.match(TokenType.TRUE):
|
|
||||||
return LiteralExpr(True)
|
|
||||||
if self.match(TokenType.NONE):
|
|
||||||
return LiteralExpr(None)
|
|
||||||
|
|
||||||
if self.match(TokenType.NUMBER):
|
|
||||||
return LiteralExpr(self.previous().value)
|
|
||||||
|
|
||||||
raise self.error(self.peek(), "Expected literal")
|
|
||||||
|
|
||||||
def constraint_operator(self) -> Token:
|
|
||||||
if self.match(
|
|
||||||
TokenType.LESS,
|
|
||||||
TokenType.LESS_EQUAL,
|
|
||||||
TokenType.GREATER,
|
|
||||||
TokenType.GREATER_EQUAL,
|
|
||||||
TokenType.EQUAL_EQUAL,
|
|
||||||
TokenType.BANG_EQUAL,
|
|
||||||
):
|
|
||||||
return self.previous()
|
|
||||||
raise self.error(self.peek(), "Expected constraint operator")
|
|
||||||
|
|
||||||
def type_body_expr(self) -> TypeBodyExpr:
|
|
||||||
"""Parse a type definition body
|
|
||||||
|
|
||||||
A type definition body is a set of whitespace-separated
|
|
||||||
property statements enclosed in curly braces
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TypeBodyExpr: the parsed type body expression
|
|
||||||
"""
|
|
||||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
|
||||||
properties: list[PropertyStmt] = []
|
|
||||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
|
||||||
properties.append(self.property_stmt())
|
|
||||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
|
||||||
return TypeBodyExpr(properties=properties)
|
|
||||||
|
|
||||||
def property_stmt(self) -> PropertyStmt:
|
|
||||||
"""Parse a property statement
|
|
||||||
|
|
||||||
A type property statement is written `name: Type`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PropertyStmt: the parsed property statement
|
|
||||||
"""
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
|
||||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
|
||||||
type: TypeExpr = self.type_expr()
|
|
||||||
return PropertyStmt(name=name, type=type)
|
|
||||||
|
|
||||||
def op_declaration(self) -> OpStmt:
|
|
||||||
"""Parse an operation definition
|
|
||||||
|
|
||||||
An operation is written `op <Type1> operator <Type2> = <Type3>` where `operator` can be any single token
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OpStmt: the parsed operation statement
|
|
||||||
"""
|
|
||||||
self.consume(TokenType.LESS, "Expected '<' before first type")
|
|
||||||
left: TypeExpr = self.type_expr()
|
|
||||||
self.consume(TokenType.GREATER, "Expected '>' after first type")
|
|
||||||
|
|
||||||
op: Token = self.advance()
|
|
||||||
|
|
||||||
self.consume(TokenType.LESS, "Expected '<' before second type")
|
|
||||||
right: TypeExpr = self.type_expr()
|
|
||||||
self.consume(TokenType.GREATER, "Expected '>' after second type")
|
|
||||||
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' after second type")
|
|
||||||
|
|
||||||
self.consume(TokenType.LESS, "Expected '<' before result type")
|
|
||||||
result: TypeExpr = self.type_expr()
|
|
||||||
self.consume(TokenType.GREATER, "Expected '>' after result type")
|
|
||||||
|
|
||||||
return OpStmt(left=left, op=op, right=right, result=result)
|
|
||||||
|
|
||||||
def constraint_declaration(self) -> ConstraintStmt:
|
|
||||||
"""Parse a type constraint declaration
|
|
||||||
|
|
||||||
A constraint is written `constraint Name = constraint_expression`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConstraintStmt: the parsed constraint declaration statement
|
|
||||||
"""
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected constraint name")
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' after constraint name")
|
|
||||||
constraint: ConstraintExpr = self.constraint_expr()
|
|
||||||
return ConstraintStmt(name=name, constraint=constraint)
|
|
||||||
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[project]
|
||||||
|
name = "midas"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "A static-first type checking framework for Python data-frames"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
authors = [
|
||||||
|
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
||||||
|
]
|
||||||
|
classifiers = ["Programming Language :: Python :: 3"]
|
||||||
|
dependencies = ["click>=8.4.1"]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://git.kbk28.ch/HEL/midas"
|
||||||
|
Repository = "https://git.kbk28.ch/HEL/midas"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
midas = "midas.cli.main:midas"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ['hatchling']
|
||||||
|
build-backend = 'hatchling.build'
|
||||||
@@ -1,26 +1,35 @@
|
|||||||
identifier ::= '[a-zA-Z][a-zA-Z_]*'
|
// W3C EBNF syntax definition for Midas
|
||||||
|
Identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
||||||
|
|
||||||
integer ::= '\d+'
|
Integer ::= '\d+'
|
||||||
number ::= integer ["." integer]
|
Number ::= "-"? Integer ("." Integer)?
|
||||||
boolean ::= "False" | "True"
|
Boolean ::= "False" | "True"
|
||||||
none ::= "None"
|
None ::= "None"
|
||||||
|
|
||||||
value ::= number | boolean | none
|
Value ::= Number | Boolean | None
|
||||||
lambda-value ::= "_" | value
|
|
||||||
lambda-operator ::= ">" | "<" | ">=" | "<=" | "==" | "!="
|
|
||||||
lambda ::= lambda-value lambda-operator lambda-value
|
|
||||||
|
|
||||||
constraint ::= identifier | "(" lambda ")"
|
ComparisonOp ::= ">" | "<" | ">=" | "<="
|
||||||
base-type ::= identifier
|
EqualityOp ::= "==" | "!="
|
||||||
type ::= base-type { "+" constraint }
|
|
||||||
|
|
||||||
type-property ::= 'identifier' ":" 'type'
|
Grouping ::= "(" Constraint ")"
|
||||||
type-body ::= "{" { 'type-property' } "}"
|
Primary ::= "_" | Value | Identifier | Grouping
|
||||||
|
Reference ::= Primary ("." Identifier)*
|
||||||
|
Unary ::= "-"? Unary | Reference
|
||||||
|
Comparison ::= Unary (ComparisonOp Unary)*
|
||||||
|
Equality ::= Comparison (EqualityOp Comparison)*
|
||||||
|
Constraint ::= Equality ("&" Equality)*
|
||||||
|
|
||||||
operation-type ::= "<" 'type' ">"
|
SimpleType ::= Identifier "?"?
|
||||||
|
Template ::= "[" Type "]"
|
||||||
|
Type ::= Identifier Template? "?"?
|
||||||
|
|
||||||
type-statement ::= "type" 'identifier' "<" 'type' {"," 'type'} ">" ['type-body']
|
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
|
||||||
operation-statement ::= "op" 'operation-type' 'operator' 'operation-type' "=" 'operation-type'
|
ComplexTypeBody ::= "{" TypeProperty* "}"
|
||||||
constraint-statement ::= "constraint" 'identifier' "=" 'lambda'
|
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
||||||
|
ExtendBody ::= "{" OpDefinition* "}"
|
||||||
|
|
||||||
statement ::= type-statement | operation-statement | constraint-statement
|
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
|
||||||
|
ExtendStatement ::= "extend" Type ExtendBody
|
||||||
|
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
||||||
|
|
||||||
|
Statement ::= TypeStatement | ExtendStatement | PredicateStatement
|
||||||
|
|||||||
172
syntax/midas.typ
172
syntax/midas.typ
@@ -1,4 +1,11 @@
|
|||||||
#import "@preview/fervojo:0.1.1": render
|
#import "@preview/fervojo:0.1.1": default-css, render
|
||||||
|
|
||||||
|
#let extra-css = ```css
|
||||||
|
svg.railroad .terminal rect {
|
||||||
|
fill: #F7DCD4;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
#let css = default-css() + bytes(extra-css.text)
|
||||||
|
|
||||||
#let value = ```
|
#let value = ```
|
||||||
{[`value` <
|
{[`value` <
|
||||||
@@ -8,90 +15,157 @@
|
|||||||
>]}
|
>]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let constraint = ```
|
#let grouping = ```
|
||||||
{[`constraint` <"_", 'value'> <">", "<", ">=", "<=", "==", "!="> <"_", 'value'>]}
|
{[`grouping` "(" 'constraint' ")"]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let type-with-constraints = ```
|
#let primary = ```
|
||||||
{[`type-with-constraints` 'identifier' <!, ["+" "(" 'constraint' ")"] * !>]}
|
{[`primary` <"_", 'value', 'identifier', 'grouping'>]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let reference = ```
|
||||||
|
{[`reference` 'primary' <!, ["." 'identifier']*!>]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let unary = ```
|
||||||
|
{[`unary` <[<!, "-"> 'unary'], 'reference'>]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let comparison = ```
|
||||||
|
{[`comparison` 'unary'*<">", "<", ">=", "<=">]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let equality = ```
|
||||||
|
{[`equality` 'comparison'*<"==", "!=">]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let constraint = ```
|
||||||
|
{[`constraint` 'equality'*"&"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let simple-type = ```
|
||||||
|
{[`simple-type` 'identifier' <!, "?">]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let template = ```
|
||||||
|
{[`template` "[" 'type' "]"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let type = ```
|
||||||
|
{[`type` 'identifier' <!, 'template'> <!, "?">]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let type-property = ```
|
#let type-property = ```
|
||||||
{[`type-property` 'identifier' ":" 'type-with-constraints']}
|
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let type-body = ```
|
#let type-body = ```
|
||||||
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let operation-type = ```
|
|
||||||
{[`operation-type` "<" 'type-with-constraints' ">"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
#let type-statement = ```
|
#let type-statement = ```
|
||||||
{[`type-statement` "type" 'identifier' "<" 'type-with-constraints'*"," ">" <!, 'type-body'>]}
|
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let operation-statement = ```
|
#let op-definition = ```
|
||||||
{[`operation-statement` "op" 'operation-type' "operator" 'operation-type' "=" 'operation-type']}
|
{[`op-definition` "op" 'identifier' "(" 'type' ")" "->" 'type']}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let constraint-statement = ```
|
#let extend-statement = ```
|
||||||
{[`constraint-statement` "constraint" 'identifier' "=" 'constraint']}
|
{[`extend-statement` "extend" 'type' "{" <!, 'op-definition'*!> "}"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let predicate-statement = ```
|
||||||
|
{[`predicate-statement` "predicate" 'identifier' "(" 'identifier' ":" 'type' ")" "=" 'constraint']}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let statement = ```
|
#let statement = ```
|
||||||
{[`statement` <'type-statement', 'operation-statement', 'constraint-statement'>]}
|
{[`statement` <'type-statement', 'extend-statement', 'predicate-statement'>]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let rules = (
|
#let rules = (
|
||||||
value,
|
value: value,
|
||||||
constraint,
|
grouping: grouping,
|
||||||
type-with-constraints,
|
primary: primary,
|
||||||
type-property,
|
reference: reference,
|
||||||
type-body,
|
unary: unary,
|
||||||
operation-type,
|
comparison: comparison,
|
||||||
type-statement,
|
equality: equality,
|
||||||
operation-statement,
|
constraint: constraint,
|
||||||
constraint-statement,
|
simple-type: simple-type,
|
||||||
statement,
|
template: template,
|
||||||
|
type: type,
|
||||||
|
type-property: type-property,
|
||||||
|
type-body: type-body,
|
||||||
|
type-statement: type-statement,
|
||||||
|
op-definition: op-definition,
|
||||||
|
extend-statement: extend-statement,
|
||||||
|
predicate-statement: predicate-statement,
|
||||||
|
statement: statement,
|
||||||
|
)
|
||||||
|
|
||||||
|
#let inline = (
|
||||||
|
"grouping",
|
||||||
|
"value",
|
||||||
|
"template",
|
||||||
|
"simple-type",
|
||||||
|
"type-property",
|
||||||
|
"type-body",
|
||||||
|
"op-definition",
|
||||||
|
"type-statement",
|
||||||
|
"extend-statement",
|
||||||
|
"predicate-statement",
|
||||||
)
|
)
|
||||||
|
|
||||||
#set text(font: "Source Sans 3")
|
#set text(font: "Source Sans 3")
|
||||||
|
|
||||||
= Midas type definition syntax
|
#title[Midas type definition syntax]
|
||||||
|
|
||||||
#for rule in rules {
|
= Outline
|
||||||
render(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
#box(
|
||||||
#let by-name = (
|
columns(
|
||||||
value: value,
|
2,
|
||||||
constraint: constraint,
|
outline(title: none),
|
||||||
type-with-constraints: type-with-constraints,
|
),
|
||||||
type-property: type-property,
|
height: 9cm,
|
||||||
type-body: type-body,
|
stroke: 1pt,
|
||||||
operation-type: operation-type,
|
inset: 1em,
|
||||||
type-statement: type-statement,
|
|
||||||
operation-statement: operation-statement,
|
|
||||||
constraint-statement: constraint-statement,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
= Statements and expressions
|
||||||
|
|
||||||
|
#for (name, rule) in rules.pairs().rev() {
|
||||||
|
[== #name]
|
||||||
|
render(rule, css: css)
|
||||||
|
}
|
||||||
|
|
||||||
#let substitute(base-rule) = {
|
#let substitute(base-rule) = {
|
||||||
let new-rule = base-rule
|
let new-rule = base-rule
|
||||||
for (key, rule) in by-name.pairs() {
|
for name in inline {
|
||||||
new-rule = new-rule.replace("'" + key + "'", rule.text.slice(1, -1))
|
let rule = rules.at(name)
|
||||||
|
let replacement = rule.text.slice(1, -1).replace(regex("\[`.*?`"), "[")
|
||||||
|
replacement = "[" + replacement + "#`" + name + "`]"
|
||||||
|
new-rule = new-rule.replace(
|
||||||
|
"'" + name + "'",
|
||||||
|
replacement,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if new-rule != base-rule {
|
if new-rule != base-rule {
|
||||||
new-rule = substitute(new-rule)
|
new-rule = substitute(new-rule)
|
||||||
}
|
}
|
||||||
return new-rule.replace(regex("`.*?`"), "")
|
return new-rule
|
||||||
}
|
}
|
||||||
|
|
||||||
#let combined = raw(substitute(statement.text))
|
|
||||||
|
|
||||||
|
|
||||||
#set page(flipped: true)
|
#set page(flipped: true)
|
||||||
#render(combined)
|
|
||||||
*/
|
= Combined rules
|
||||||
|
|
||||||
|
#for (name, rule) in rules.pairs() {
|
||||||
|
if not name in inline {
|
||||||
|
[== #name]
|
||||||
|
let combined = substitute(rule.text)
|
||||||
|
render(raw(combined), css: css)
|
||||||
|
//raw(block: true, combined)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
35
test.py
35
test.py
@@ -1,40 +1,21 @@
|
|||||||
import importlib
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from core.ast.printer import AnnotationAstPrinter, MidasAstPrinter
|
from midas.ast.printer import MidasAstPrinter
|
||||||
from lexer.annotations import AnnotationLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from lexer.midas import MidasLexer
|
from midas.lexer.token import Token
|
||||||
from lexer.token import Token
|
from midas.parser.midas import MidasParser
|
||||||
from parser.annotations import AnnotationParser
|
|
||||||
from parser.midas import MidasParser
|
|
||||||
|
|
||||||
|
|
||||||
def test_annotation():
|
|
||||||
# Frame annotation
|
|
||||||
mod = importlib.import_module("examples.00_syntax_prototype.01_simple_types")
|
|
||||||
|
|
||||||
annotation: str = mod.__annotations__["df"]
|
|
||||||
lexer: AnnotationLexer = AnnotationLexer(annotation, "01_simple_types.py")
|
|
||||||
tokens: list[Token] = lexer.process()
|
|
||||||
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
|
||||||
|
|
||||||
parser = AnnotationParser(tokens)
|
|
||||||
parsed = parser.parse()
|
|
||||||
print(parsed)
|
|
||||||
for err in parser.errors:
|
|
||||||
print(err.get_report())
|
|
||||||
printer = AnnotationAstPrinter()
|
|
||||||
if parsed is not None:
|
|
||||||
print(printer.print(parsed))
|
|
||||||
|
|
||||||
|
|
||||||
def test_midas():
|
def test_midas():
|
||||||
# Midas type definitions
|
# Midas type definitions
|
||||||
path: Path = Path("examples") / "00_syntax_prototype" / "02_custom_types.midas"
|
path: Path = Path("examples") / "00_syntax_prototype" / "03_custom_types_v2.midas"
|
||||||
definitions: str = path.read_text()
|
definitions: str = path.read_text()
|
||||||
midas_lexer: MidasLexer = MidasLexer(definitions, path.name)
|
midas_lexer: MidasLexer = MidasLexer(definitions, path.name)
|
||||||
tokens: list[Token] = midas_lexer.process()
|
tokens: list[Token] = midas_lexer.process()
|
||||||
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
||||||
|
with open("tokens.json", "w") as f:
|
||||||
|
json.dump([f"{t.type.name}('{t.lexeme}')" for t in tokens], f, indent=4)
|
||||||
|
|
||||||
parser = MidasParser(tokens)
|
parser = MidasParser(tokens)
|
||||||
parsed = parser.parse()
|
parsed = parser.parse()
|
||||||
|
|||||||
143
tests/base.py
Normal file
143
tests/base.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import difflib
|
||||||
|
import sys
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class CaseResult(Protocol):
|
||||||
|
def dumps(self) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
|
class Tester(ABC):
|
||||||
|
"""A test runner to check for regressions in the lexer and parser"""
|
||||||
|
|
||||||
|
CASES_DIR: Path = Path(__file__).parent / "cases"
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def namespace(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def base_dir(self) -> Path:
|
||||||
|
return self.CASES_DIR / self.namespace
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _list_tests(self) -> list[Path]: ...
|
||||||
|
|
||||||
|
def run_all_tests(self) -> bool:
|
||||||
|
paths: list[Path] = self._list_tests()
|
||||||
|
return self.run_tests(paths)
|
||||||
|
|
||||||
|
def run_tests(self, tests: list[Path]) -> bool:
|
||||||
|
rule: str = "-" * 80
|
||||||
|
n: int = len(tests)
|
||||||
|
successes: int = 0
|
||||||
|
failures: int = 0
|
||||||
|
|
||||||
|
print(rule)
|
||||||
|
for i, test in enumerate(tests):
|
||||||
|
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
||||||
|
success: bool = self._run_test(test)
|
||||||
|
if success:
|
||||||
|
successes += 1
|
||||||
|
else:
|
||||||
|
failures += 1
|
||||||
|
|
||||||
|
print(rule)
|
||||||
|
print(f"Success: {successes}/{n}")
|
||||||
|
print(f"Failed: {failures}/{n}")
|
||||||
|
print(rule)
|
||||||
|
return failures == 0
|
||||||
|
|
||||||
|
def _run_test(self, path: Path) -> bool:
|
||||||
|
result_path: Path = self._result_path(path)
|
||||||
|
if not result_path.exists():
|
||||||
|
print("Missing snapshot. Please run the update command first")
|
||||||
|
return False
|
||||||
|
result: CaseResult = self._exec_case(path)
|
||||||
|
expected: str = result_path.read_text()
|
||||||
|
actual: str = result.dumps()
|
||||||
|
|
||||||
|
if expected == actual:
|
||||||
|
return True
|
||||||
|
|
||||||
|
diff = difflib.unified_diff(
|
||||||
|
expected.splitlines(keepends=True),
|
||||||
|
actual.splitlines(keepends=True),
|
||||||
|
fromfile="Snapshot",
|
||||||
|
tofile="Result",
|
||||||
|
)
|
||||||
|
self._print_diff(diff)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||||
|
|
||||||
|
def update_all_tests(self):
|
||||||
|
paths: list[Path] = self._list_tests()
|
||||||
|
return self.update_tests(paths)
|
||||||
|
|
||||||
|
def update_tests(self, tests: list[Path]):
|
||||||
|
updated: int = 0
|
||||||
|
for test in tests:
|
||||||
|
if self._update_test(test):
|
||||||
|
updated += 1
|
||||||
|
print(f"Updated {updated}/{len(tests)} tests")
|
||||||
|
|
||||||
|
def _update_test(self, path: Path) -> bool:
|
||||||
|
result: CaseResult = self._exec_case(path)
|
||||||
|
result_path: Path = self._result_path(path)
|
||||||
|
current: str = result_path.read_text() if result_path.exists() else ""
|
||||||
|
new: str = result.dumps()
|
||||||
|
if current == new:
|
||||||
|
return False
|
||||||
|
result_path.write_text(new)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _result_path(self, test_path: Path) -> Path:
|
||||||
|
return test_path.parent / (test_path.name + ".ref.json")
|
||||||
|
|
||||||
|
def _print_diff(self, diff: Iterator[str]):
|
||||||
|
for line in diff:
|
||||||
|
if line.startswith("+") and not line.startswith("+++"):
|
||||||
|
print(f"\033[92m{line}\033[0m", end="")
|
||||||
|
elif line.startswith("-") and not line.startswith("---"):
|
||||||
|
print(f"\033[91m{line}\033[0m", end="")
|
||||||
|
else:
|
||||||
|
print(line, end="")
|
||||||
|
print()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def main(cls):
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
subparsers = parser.add_subparsers(dest="subcommand")
|
||||||
|
|
||||||
|
update = subparsers.add_parser("update")
|
||||||
|
update.add_argument("-a", "--all", action="store_true")
|
||||||
|
update.add_argument("FILE", type=Path, nargs="*")
|
||||||
|
|
||||||
|
run = subparsers.add_parser("run")
|
||||||
|
run.add_argument("-a", "--all", action="store_true")
|
||||||
|
run.add_argument("FILE", type=Path, nargs="*")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
tester: Tester = cls()
|
||||||
|
|
||||||
|
match args.subcommand:
|
||||||
|
case "update":
|
||||||
|
if args.all:
|
||||||
|
tester.update_all_tests()
|
||||||
|
else:
|
||||||
|
tester.update_tests(args.FILE)
|
||||||
|
case "run":
|
||||||
|
success: bool
|
||||||
|
if args.all:
|
||||||
|
success = tester.run_all_tests()
|
||||||
|
else:
|
||||||
|
success = tester.run_tests(args.FILE)
|
||||||
|
if not success:
|
||||||
|
sys.exit(1)
|
||||||
14
tests/cases/checker/01_simple_types.py
Normal file
14
tests/cases/checker/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821]
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
df: Frame[
|
||||||
|
verified: bool,
|
||||||
|
birth_year: int,
|
||||||
|
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||||
|
name: str,
|
||||||
|
date: datetime,
|
||||||
|
float,
|
||||||
|
unknown: _,
|
||||||
|
_
|
||||||
|
]
|
||||||
3
tests/cases/checker/01_simple_types.py.ref.json
Normal file
3
tests/cases/checker/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": []
|
||||||
|
}
|
||||||
11
tests/cases/checker/02_simple_operations.py
Normal file
11
tests/cases/checker/02_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
a: int = 3
|
||||||
|
b: int = 4
|
||||||
|
|
||||||
|
c = a + b
|
||||||
|
|
||||||
|
c = "invalid"
|
||||||
|
|
||||||
|
d = True
|
||||||
|
e = d + d
|
||||||
|
|
||||||
|
f: float = a
|
||||||
46
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
46
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
6,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
6,
|
||||||
|
13
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
9,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
9,
|
||||||
|
9
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
11,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
11,
|
||||||
|
12
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
18
tests/cases/checker/03_functions.py
Normal file
18
tests/cases/checker/03_functions.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
def foo(a: int, /, b: float, *, c: str):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
r1 = foo()
|
||||||
|
r2 = foo(1)
|
||||||
|
r3 = foo(1, 2.0)
|
||||||
|
r4 = foo(1, b=2.0)
|
||||||
|
r5 = foo(1, 2.0, "test")
|
||||||
|
r6 = foo(1, 2.0, b=3.0)
|
||||||
|
r7 = foo(a=1)
|
||||||
|
r8 = foo(g="test")
|
||||||
|
|
||||||
|
r9a = foo(1, 2.0, c="test")
|
||||||
|
r9b = foo(1, b=2.0, c="test")
|
||||||
|
r9c = foo(1, c="test", b=2.0)
|
||||||
|
|
||||||
|
r10 = foo("a", 3, c=False)
|
||||||
270
tests/cases/checker/03_functions.py.ref.json
Normal file
270
tests/cases/checker/03_functions.py.ref.json
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
5,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
5,
|
||||||
|
10
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
5,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
5,
|
||||||
|
10
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
6,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
6,
|
||||||
|
11
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required positional argument: 'b'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
6,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
6,
|
||||||
|
11
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
7,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
7,
|
||||||
|
16
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
8,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
8,
|
||||||
|
18
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
9,
|
||||||
|
17
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
9,
|
||||||
|
23
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Too many positional arguments"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
9,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
9,
|
||||||
|
24
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
10,
|
||||||
|
19
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
10,
|
||||||
|
22
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Multiple values for argument 'b'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
10,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
10,
|
||||||
|
23
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
11,
|
||||||
|
11
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
11,
|
||||||
|
12
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unknown keyword argument 'a'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
11,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
11,
|
||||||
|
13
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
11,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
11,
|
||||||
|
13
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
12,
|
||||||
|
11
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
12,
|
||||||
|
17
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unknown keyword argument 'g'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
12,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
12,
|
||||||
|
18
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
12,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
12,
|
||||||
|
18
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Missing required keyword argument: 'c'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
18,
|
||||||
|
10
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
18,
|
||||||
|
13
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
18,
|
||||||
|
15
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
18,
|
||||||
|
16
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Wrong type for argument 'b', expected BaseType(name='float'), got BaseType(name='int')"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
18,
|
||||||
|
20
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
18,
|
||||||
|
25
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
14
tests/cases/checker/04_custom_types.midas
Normal file
14
tests/cases/checker/04_custom_types.midas
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
type Meter(float)
|
||||||
|
type Second(float)
|
||||||
|
type MeterPerSecond(float)
|
||||||
|
|
||||||
|
extend Meter {
|
||||||
|
op __add__(Meter) -> Meter
|
||||||
|
op __sub__(Meter) -> Meter
|
||||||
|
op __truediv__(Second) -> MeterPerSecond
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Second {
|
||||||
|
op __add__(Second) -> Second
|
||||||
|
op __sub__(Second) -> Second
|
||||||
|
}
|
||||||
8
tests/cases/checker/04_custom_types.py
Normal file
8
tests/cases/checker/04_custom_types.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable [F821]
|
||||||
|
|
||||||
|
midas.using("04_custom_types.midas")
|
||||||
|
|
||||||
|
distance: Meter = cast(Meter, 123.45)
|
||||||
|
time: Second = cast(Second, 6.7)
|
||||||
|
speed = distance / time
|
||||||
3
tests/cases/checker/04_custom_types.py.ref.json
Normal file
3
tests/cases/checker/04_custom_types.py.ref.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": []
|
||||||
|
}
|
||||||
25
tests/cases/checker/05_control_flow.py
Normal file
25
tests/cases/checker/05_control_flow.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
def valid(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
def with_if(a: int, b: int) -> int:
|
||||||
|
if a < b:
|
||||||
|
return b - a
|
||||||
|
else:
|
||||||
|
return a - b
|
||||||
|
|
||||||
|
def unreachable1():
|
||||||
|
return
|
||||||
|
a = 0
|
||||||
|
|
||||||
|
def unreachable2(a: int) -> int:
|
||||||
|
if a > 10:
|
||||||
|
return a - 10
|
||||||
|
else:
|
||||||
|
return a
|
||||||
|
b = 0
|
||||||
|
|
||||||
|
def mixed(a: int, b: int):
|
||||||
|
if a < b:
|
||||||
|
return b - a
|
||||||
|
else:
|
||||||
|
return "oops"
|
||||||
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
12,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
12,
|
||||||
|
9
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unreachable statement"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
19,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
19,
|
||||||
|
9
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unreachable statement"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
21,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
25,
|
||||||
|
21
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
57
tests/cases/midas-parser/01_simple_types.midas
Normal file
57
tests/cases/midas-parser/01_simple_types.midas
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
// Simple custom type derived from float
|
||||||
|
type Custom(float)
|
||||||
|
|
||||||
|
// Simple custom types with constraints
|
||||||
|
type Latitude(float) where (-90 <= _ <= 90)
|
||||||
|
type Longitude(float) where (-180 <= _ <= 180)
|
||||||
|
|
||||||
|
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||||
|
type Difference[T](T)
|
||||||
|
|
||||||
|
// Complex custom type, containing two values accessible through properties
|
||||||
|
type GeoLocation {
|
||||||
|
lat: Latitude
|
||||||
|
lon: Longitude
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define operations on our custom type
|
||||||
|
extend GeoLocation {
|
||||||
|
// This type is compatible with the `-` operation with another GeoLocation
|
||||||
|
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||||
|
// in a Difference of GeoLocations
|
||||||
|
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||||
|
}
|
||||||
|
|
||||||
|
// For complex generics, you need to specify how the genericity the properties
|
||||||
|
// are handled
|
||||||
|
type Difference[GeoLocation] {
|
||||||
|
lat: Difference[Latitude]
|
||||||
|
lon: Difference[Longitude]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple operation defined on our custom types
|
||||||
|
extend Latitude {
|
||||||
|
op __sub__(Latitude) -> Difference[Latitude]
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Longitude {
|
||||||
|
op __sub__(Longitude) -> Difference[Longitude]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Predefined custom predicates that can be referenced in other definitions
|
||||||
|
predicate Positive(v: float) = v >= 0
|
||||||
|
predicate StrictlyPositive(v: float) = v > 0
|
||||||
|
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||||
|
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||||
|
|
||||||
|
type Person {
|
||||||
|
name: str
|
||||||
|
|
||||||
|
// Property with an inline constraint
|
||||||
|
age: int? where (0 <= _ < 150)
|
||||||
|
|
||||||
|
// Property referencing a predicate
|
||||||
|
height: float where StrictlyPositive
|
||||||
|
|
||||||
|
home: GeoLocation
|
||||||
|
}
|
||||||
2659
tests/cases/midas-parser/01_simple_types.midas.ref.json
Normal file
2659
tests/cases/midas-parser/01_simple_types.midas.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
14
tests/cases/python-parser/01_simple_types.py
Normal file
14
tests/cases/python-parser/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821]
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
df: Frame[
|
||||||
|
verified: bool,
|
||||||
|
birth_year: int,
|
||||||
|
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||||
|
name: str,
|
||||||
|
date: datetime,
|
||||||
|
float,
|
||||||
|
unknown: _,
|
||||||
|
_
|
||||||
|
]
|
||||||
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
{
|
||||||
|
"stmts": [
|
||||||
|
{
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": "df",
|
||||||
|
"type": {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "verified",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "bool",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "birth_year",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "int",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "height",
|
||||||
|
"type": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "(_ > 0) + (_ < 250)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "name",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "str",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "date",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "datetime",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": null,
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "unknown",
|
||||||
|
"type": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": null,
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "_",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
29
tests/cases/python-parser/02_custom_types.py
Normal file
29
tests/cases/python-parser/02_custom_types.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821]
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import midas
|
||||||
|
|
||||||
|
midas.using("02_custom_types.midas")
|
||||||
|
|
||||||
|
df: Frame[
|
||||||
|
location: GeoLocation
|
||||||
|
]
|
||||||
|
|
||||||
|
lat: Column[GeoLocation] = df["location"].lat
|
||||||
|
lon: Column[GeoLocation] = df["location"].lon
|
||||||
|
|
||||||
|
lat + lon
|
||||||
|
|
||||||
|
lat1: Latitude = lat[0]
|
||||||
|
lat2: Latitude = lat[1]
|
||||||
|
lat_diff: Difference[Latitude] = lat2 - lat1
|
||||||
|
|
||||||
|
df2: Frame[
|
||||||
|
age: int + (_ >= 0),
|
||||||
|
height: float + (_ >= 0),
|
||||||
|
]
|
||||||
|
df2_bis: Frame[
|
||||||
|
age: int + Positive,
|
||||||
|
height: float + Positive,
|
||||||
|
]
|
||||||
162
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
162
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
{
|
||||||
|
"stmts": [
|
||||||
|
{
|
||||||
|
"_type": "ExpressionStmt",
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "GetExpr",
|
||||||
|
"object": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "midas"
|
||||||
|
},
|
||||||
|
"name": "using"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "02_custom_types.midas"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": "df",
|
||||||
|
"type": {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "location",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "GeoLocation",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "ExpressionStmt",
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "lat"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "lon"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": "lat_diff",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Difference",
|
||||||
|
"param": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Latitude",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "AssignStmt",
|
||||||
|
"targets": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "lat_diff"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"value": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "lat2"
|
||||||
|
},
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "lat1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": "df2",
|
||||||
|
"type": {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "age",
|
||||||
|
"type": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "int",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "_ >= 0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "height",
|
||||||
|
"type": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "_ >= 0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": "df2_bis",
|
||||||
|
"type": {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "age",
|
||||||
|
"type": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "int",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "Positive"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": "height",
|
||||||
|
"type": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "Positive"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
15
tests/cases/python-parser/03_functions.py
Normal file
15
tests/cases/python-parser/03_functions.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable[F821]
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def func(
|
||||||
|
col1: Column[float + (0 <= _ <= 1)],
|
||||||
|
col2: Column[float + (0 <= _ <= 1)],
|
||||||
|
) -> Column[float + (0 <= _ <= 2)]:
|
||||||
|
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def func2(a: int, /, b: float, *, c: str):
|
||||||
|
pass
|
||||||
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
{
|
||||||
|
"stmts": [
|
||||||
|
{
|
||||||
|
"_type": "Function",
|
||||||
|
"name": "func",
|
||||||
|
"posonlyargs": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "col1",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Column",
|
||||||
|
"param": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "0 <= _ <= 1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"default": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "col2",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Column",
|
||||||
|
"param": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "0 <= _ <= 1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"default": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sink": null,
|
||||||
|
"kwonlyargs": [],
|
||||||
|
"kw_sink": null,
|
||||||
|
"returns": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Column",
|
||||||
|
"param": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "0 <= _ <= 2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"body": [
|
||||||
|
{
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": "result",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Column",
|
||||||
|
"param": {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"constraint": "0 <= _ <= 2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "AssignStmt",
|
||||||
|
"targets": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"value": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "col1"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "col2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "ReturnStmt",
|
||||||
|
"value": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "result"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "Function",
|
||||||
|
"name": "func2",
|
||||||
|
"posonlyargs": [
|
||||||
|
{
|
||||||
|
"name": "a",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "int",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"default": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "b",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "float",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"default": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sink": null,
|
||||||
|
"kwonlyargs": [
|
||||||
|
{
|
||||||
|
"name": "c",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "str",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"default": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_sink": null,
|
||||||
|
"returns": null,
|
||||||
|
"body": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
67
tests/checker.py
Normal file
67
tests/checker.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.checker.checker import Checker
|
||||||
|
from midas.checker.diagnostic import Diagnostic
|
||||||
|
from midas.parser.python import PythonParser
|
||||||
|
from midas.resolver.resolver import Resolver
|
||||||
|
from tests.base import Tester
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CaseResult:
|
||||||
|
diagnostics: list[dict] = field(default_factory=list)
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
return json.dumps(asdict(self), indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckerTester(Tester):
|
||||||
|
@property
|
||||||
|
def namespace(self) -> str:
|
||||||
|
return "checker"
|
||||||
|
|
||||||
|
def _list_tests(self) -> list[Path]:
|
||||||
|
return list(self.base_dir.rglob("*.py"))
|
||||||
|
|
||||||
|
def _exec_case(self, path: Path) -> CaseResult:
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||||
|
if not path.is_file():
|
||||||
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
|
source: str = path.read_text()
|
||||||
|
tree: ast.Module = ast.parse(source, filename=path)
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
resolver = Resolver()
|
||||||
|
resolver.resolve(*stmts)
|
||||||
|
result: CaseResult = CaseResult()
|
||||||
|
checker = Checker(resolver.locals, file_path=path)
|
||||||
|
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||||
|
for diagnostic in diagnostics:
|
||||||
|
result.diagnostics.append(
|
||||||
|
{
|
||||||
|
"type": str(diagnostic.type),
|
||||||
|
"location": {
|
||||||
|
"start": (
|
||||||
|
diagnostic.location.lineno,
|
||||||
|
diagnostic.location.col_offset,
|
||||||
|
),
|
||||||
|
"end": (
|
||||||
|
diagnostic.location.end_lineno,
|
||||||
|
diagnostic.location.end_col_offset,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"message": diagnostic.message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
CheckerTester.main()
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lexer.annotations import AnnotationLexer
|
|
||||||
from lexer.token import Token, TokenType
|
|
||||||
|
|
||||||
|
|
||||||
def scan(source: str) -> list[Token]:
|
|
||||||
return AnnotationLexer(source).process()
|
|
||||||
|
|
||||||
|
|
||||||
def assert_n_tokens(tokens: list[Token], n: int):
|
|
||||||
assert len(tokens) == n + 1
|
|
||||||
assert tokens[-1].type == TokenType.EOF
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("(", TokenType.LEFT_PAREN),
|
|
||||||
(")", TokenType.RIGHT_PAREN),
|
|
||||||
("[", TokenType.LEFT_BRACKET),
|
|
||||||
("]", TokenType.RIGHT_BRACKET),
|
|
||||||
(":", TokenType.COLON),
|
|
||||||
(",", TokenType.COMMA),
|
|
||||||
("_", TokenType.UNDERSCORE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_punctuation(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("+", TokenType.PLUS),
|
|
||||||
(">", TokenType.GREATER),
|
|
||||||
(">=", TokenType.GREATER_EQUAL),
|
|
||||||
("<", TokenType.LESS),
|
|
||||||
("<=", TokenType.LESS_EQUAL),
|
|
||||||
("=", TokenType.EQUAL),
|
|
||||||
("==", TokenType.EQUAL_EQUAL),
|
|
||||||
("!=", TokenType.BANG_EQUAL),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_operators(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("a", TokenType.IDENTIFIER),
|
|
||||||
("foo", TokenType.IDENTIFIER),
|
|
||||||
("foo1", TokenType.IDENTIFIER),
|
|
||||||
("foo_", TokenType.IDENTIFIER),
|
|
||||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
|
||||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
|
||||||
("True", TokenType.TRUE),
|
|
||||||
("False", TokenType.FALSE),
|
|
||||||
("None", TokenType.NONE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("#", TokenType.COMMENT),
|
|
||||||
("# This is a comment", TokenType.COMMENT),
|
|
||||||
(" ", TokenType.WHITESPACE),
|
|
||||||
("\t", TokenType.WHITESPACE),
|
|
||||||
("\r", TokenType.WHITESPACE),
|
|
||||||
(" \t \t", TokenType.WHITESPACE),
|
|
||||||
("\n", TokenType.NEWLINE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_misc(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected_type,expected_value",
|
|
||||||
[
|
|
||||||
("0", TokenType.NUMBER, 0),
|
|
||||||
("0.0", TokenType.NUMBER, 0),
|
|
||||||
("1234.56", TokenType.NUMBER, 1234.56),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected_type
|
|
||||||
assert tokens[0].value == expected_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_bang_error():
|
|
||||||
with pytest.raises(SyntaxError):
|
|
||||||
scan("!")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src",
|
|
||||||
[
|
|
||||||
"-",
|
|
||||||
"*",
|
|
||||||
"/",
|
|
||||||
"{",
|
|
||||||
"}",
|
|
||||||
"@",
|
|
||||||
'"',
|
|
||||||
"'",
|
|
||||||
".",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_unexpected_character(src: str):
|
|
||||||
with pytest.raises(SyntaxError):
|
|
||||||
scan(src)
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lexer.midas import MidasLexer
|
|
||||||
from lexer.token import Token, TokenType
|
|
||||||
|
|
||||||
|
|
||||||
def scan(source: str) -> list[Token]:
|
|
||||||
return MidasLexer(source).process()
|
|
||||||
|
|
||||||
|
|
||||||
def assert_n_tokens(tokens: list[Token], n: int):
|
|
||||||
assert len(tokens) == n + 1
|
|
||||||
assert tokens[-1].type == TokenType.EOF
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("(", TokenType.LEFT_PAREN),
|
|
||||||
(")", TokenType.RIGHT_PAREN),
|
|
||||||
("[", TokenType.LEFT_BRACKET),
|
|
||||||
("]", TokenType.RIGHT_BRACKET),
|
|
||||||
("{", TokenType.LEFT_BRACE),
|
|
||||||
("}", TokenType.RIGHT_BRACE),
|
|
||||||
(":", TokenType.COLON),
|
|
||||||
(",", TokenType.COMMA),
|
|
||||||
("_", TokenType.UNDERSCORE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_punctuation(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("+", TokenType.PLUS),
|
|
||||||
("-", TokenType.MINUS),
|
|
||||||
("*", TokenType.STAR),
|
|
||||||
("/", TokenType.SLASH),
|
|
||||||
(">", TokenType.GREATER),
|
|
||||||
(">=", TokenType.GREATER_EQUAL),
|
|
||||||
("<", TokenType.LESS),
|
|
||||||
("<=", TokenType.LESS_EQUAL),
|
|
||||||
("=", TokenType.EQUAL),
|
|
||||||
("==", TokenType.EQUAL_EQUAL),
|
|
||||||
("!=", TokenType.BANG_EQUAL),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_operators(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("a", TokenType.IDENTIFIER),
|
|
||||||
("foo", TokenType.IDENTIFIER),
|
|
||||||
("foo1", TokenType.IDENTIFIER),
|
|
||||||
("foo_", TokenType.IDENTIFIER),
|
|
||||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
|
||||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
|
||||||
("true", TokenType.TRUE),
|
|
||||||
("false", TokenType.FALSE),
|
|
||||||
("none", TokenType.NONE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("// This is a comment", TokenType.COMMENT),
|
|
||||||
("/* This is a comment */", TokenType.COMMENT),
|
|
||||||
(" ", TokenType.WHITESPACE),
|
|
||||||
("\t", TokenType.WHITESPACE),
|
|
||||||
("\r", TokenType.WHITESPACE),
|
|
||||||
(" \t \t", TokenType.WHITESPACE),
|
|
||||||
("\n", TokenType.NEWLINE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_misc(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected_type,expected_value",
|
|
||||||
[
|
|
||||||
("0", TokenType.NUMBER, 0),
|
|
||||||
("0.0", TokenType.NUMBER, 0),
|
|
||||||
("1234.56", TokenType.NUMBER, 1234.56),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected_type
|
|
||||||
assert tokens[0].value == expected_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_bang_error():
|
|
||||||
with pytest.raises(SyntaxError):
|
|
||||||
scan("!")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src",
|
|
||||||
[
|
|
||||||
"@",
|
|
||||||
'"',
|
|
||||||
"'",
|
|
||||||
".",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_unexpected_character(src: str):
|
|
||||||
with pytest.raises(SyntaxError):
|
|
||||||
scan(src)
|
|
||||||
82
tests/midas.py
Normal file
82
tests/midas.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.midas import Stmt
|
||||||
|
from midas.lexer.base import MidasSyntaxError
|
||||||
|
from midas.lexer.midas import MidasLexer
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
from midas.parser.midas import MidasParser
|
||||||
|
from tests.base import Tester
|
||||||
|
from tests.serializer.midas import MidasAstJsonSerializer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CaseResult:
|
||||||
|
tokens: Optional[list[dict]] = None
|
||||||
|
stmts: Optional[list[dict]] = None
|
||||||
|
errors: list[dict] = field(default_factory=list)
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
return json.dumps(asdict(self), indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
class MidasTester(Tester):
|
||||||
|
@property
|
||||||
|
def namespace(self) -> str:
|
||||||
|
return "midas-parser"
|
||||||
|
|
||||||
|
def _list_tests(self) -> list[Path]:
|
||||||
|
return list(self.base_dir.rglob("*.midas"))
|
||||||
|
|
||||||
|
def _exec_case(self, path: Path) -> CaseResult:
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||||
|
if not path.is_file():
|
||||||
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
|
result: CaseResult = CaseResult()
|
||||||
|
content: str = path.read_text()
|
||||||
|
lexer: MidasLexer = MidasLexer(content)
|
||||||
|
tokens: list[Token] = []
|
||||||
|
try:
|
||||||
|
tokens = lexer.process()
|
||||||
|
result.tokens = [
|
||||||
|
{
|
||||||
|
"type": token.type.name,
|
||||||
|
"lexeme": token.lexeme,
|
||||||
|
"line": token.position.line,
|
||||||
|
"column": token.position.column,
|
||||||
|
}
|
||||||
|
for token in tokens
|
||||||
|
]
|
||||||
|
except MidasSyntaxError as e:
|
||||||
|
result.errors.append(
|
||||||
|
{
|
||||||
|
"type": "SyntaxError",
|
||||||
|
"line": e.pos.line,
|
||||||
|
"column": e.pos.column,
|
||||||
|
"message": e.message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
parser: MidasParser = MidasParser(tokens)
|
||||||
|
stmts: list[Stmt] = parser.parse()
|
||||||
|
result.stmts = MidasAstJsonSerializer().serialize(stmts)
|
||||||
|
result.errors.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"line": e.token.position.line,
|
||||||
|
"column": e.token.position.column,
|
||||||
|
"message": e.message,
|
||||||
|
}
|
||||||
|
for e in parser.errors
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
MidasTester.main()
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.ast.annotations import (
|
|
||||||
AnnotationStmt,
|
|
||||||
ConstraintExpr,
|
|
||||||
Expr,
|
|
||||||
LiteralExpr,
|
|
||||||
SchemaElementExpr,
|
|
||||||
SchemaExpr,
|
|
||||||
Stmt,
|
|
||||||
TypeExpr,
|
|
||||||
WildcardExpr,
|
|
||||||
)
|
|
||||||
from lexer.annotations import AnnotationLexer
|
|
||||||
from lexer.position import Position
|
|
||||||
from lexer.token import Token
|
|
||||||
from parser.annotations import AnnotationParser
|
|
||||||
|
|
||||||
|
|
||||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
|
||||||
def serialize(self, stmt: Stmt):
|
|
||||||
return stmt.accept(self)
|
|
||||||
|
|
||||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str:
|
|
||||||
schema: str = ""
|
|
||||||
if stmt.schema is not None:
|
|
||||||
schema = " " + stmt.schema.accept(self)
|
|
||||||
return f"(annotation {stmt.name.lexeme}{schema})"
|
|
||||||
|
|
||||||
def visit_schema_expr(self, expr: SchemaExpr) -> str:
|
|
||||||
elements: list[str] = [elmt.accept(self) for elmt in expr.elements]
|
|
||||||
return f"(schema {' '.join(elements)})"
|
|
||||||
|
|
||||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
|
|
||||||
name: str = expr.name.lexeme if expr.name is not None else "_"
|
|
||||||
type: str = expr.type.accept(self) if expr.type is not None else "_"
|
|
||||||
return f"({name} {type})"
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
|
||||||
res: str = f"({expr.name.lexeme}"
|
|
||||||
for constraint in expr.constraints:
|
|
||||||
res += " " + constraint.accept(self)
|
|
||||||
res += ")"
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
|
||||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
|
||||||
return "(_)"
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
|
||||||
return f"({expr.value})"
|
|
||||||
|
|
||||||
|
|
||||||
def parse(source: str) -> Optional[Stmt]:
|
|
||||||
tokens: list[Token] = AnnotationLexer(source).process()
|
|
||||||
return AnnotationParser(tokens).parse()
|
|
||||||
|
|
||||||
|
|
||||||
def must_parse(source: str) -> Stmt:
|
|
||||||
stmt: Optional[Stmt] = parse(source)
|
|
||||||
assert stmt is not None
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def ast_str(source: str) -> str:
|
|
||||||
stmt: Stmt = must_parse(source)
|
|
||||||
return AstSerializer().serialize(stmt)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("Type", "(annotation Type)"),
|
|
||||||
("Type[]", "(annotation Type (schema ))"),
|
|
||||||
(
|
|
||||||
"""
|
|
||||||
Frame[
|
|
||||||
verified: bool,
|
|
||||||
birth_year: int,
|
|
||||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
|
||||||
name: str,
|
|
||||||
date: datetime,
|
|
||||||
float, # unnamed
|
|
||||||
unknown: _, # untyped
|
|
||||||
_ # unnamed and untyped
|
|
||||||
]
|
|
||||||
""",
|
|
||||||
"(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_expressions(src: str, expected: str):
|
|
||||||
assert ast_str(src) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,pos,should_fail",
|
|
||||||
[
|
|
||||||
("", (1, 1), True),
|
|
||||||
("42", (1, 1), True),
|
|
||||||
("True", (1, 1), True),
|
|
||||||
("Type[", (1, 6), True),
|
|
||||||
("Type[] Type2", (1, 8), False),
|
|
||||||
("Type[bool:]", (1, 11), True),
|
|
||||||
("Type[3]", (1, 6), True),
|
|
||||||
("Type[bool float]", (1, 11), True),
|
|
||||||
("Type[bool (_ < 2)]", (1, 11), True),
|
|
||||||
("Type[bool + _ < 2)]", (1, 13), True),
|
|
||||||
("Type[bool + (_ < 2]", (1, 19), True),
|
|
||||||
("Type[bool + (< 2)]", (1, 14), True),
|
|
||||||
("Type[bool + (_ + 2)]", (1, 16), True),
|
|
||||||
("Type[bool + (Foo + Bar)]", (1, 14), True),
|
|
||||||
# ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF
|
|
||||||
("Type[bool, Type[]]", (1, 16), True),
|
|
||||||
("Type[foo: 3]", (1, 11), True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool):
|
|
||||||
tokens: list[Token] = AnnotationLexer(src).process()
|
|
||||||
parser: AnnotationParser = AnnotationParser(tokens)
|
|
||||||
stmt: Optional[Stmt] = parser.parse()
|
|
||||||
if should_fail:
|
|
||||||
assert stmt is None
|
|
||||||
assert len(parser.errors) != 0
|
|
||||||
error_pos: Position = parser.errors[0].token.position
|
|
||||||
assert (error_pos.line, error_pos.column) == pos
|
|
||||||
@@ -1,202 +0,0 @@
|
|||||||
import textwrap
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.ast.midas import (
|
|
||||||
ConstraintExpr,
|
|
||||||
ConstraintStmt,
|
|
||||||
Expr,
|
|
||||||
LiteralExpr,
|
|
||||||
OpStmt,
|
|
||||||
PropertyStmt,
|
|
||||||
Stmt,
|
|
||||||
TypeBodyExpr,
|
|
||||||
TypeExpr,
|
|
||||||
TypeStmt,
|
|
||||||
WildcardExpr,
|
|
||||||
)
|
|
||||||
from lexer.midas import MidasLexer
|
|
||||||
from lexer.position import Position
|
|
||||||
from lexer.token import Token
|
|
||||||
from parser.midas import MidasParser
|
|
||||||
|
|
||||||
|
|
||||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
|
||||||
def serialize(self, stmt: Stmt):
|
|
||||||
return stmt.accept(self)
|
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: TypeStmt) -> str:
|
|
||||||
res: str = f"(type_def {stmt.name.lexeme}"
|
|
||||||
for base in stmt.bases:
|
|
||||||
res += " " + base.accept(self)
|
|
||||||
if stmt.body is not None:
|
|
||||||
res += " " + stmt.body.accept(self)
|
|
||||||
res += ")"
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
|
||||||
res: str = f"({expr.name.lexeme}"
|
|
||||||
for constraint in expr.constraints:
|
|
||||||
res += " " + constraint.accept(self)
|
|
||||||
res += ")"
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
|
||||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
|
||||||
return "(_)"
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
|
||||||
return f"({expr.value})"
|
|
||||||
|
|
||||||
def visit_type_body_expr(self, expr: TypeBodyExpr) -> str:
|
|
||||||
res: str = "(body"
|
|
||||||
for prop in expr.properties:
|
|
||||||
res += " " + prop.accept(self)
|
|
||||||
res += ")"
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: PropertyStmt) -> str:
|
|
||||||
return f"(property {stmt.name.lexeme} {stmt.type.accept(self)})"
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: OpStmt) -> str:
|
|
||||||
left: str = stmt.left.accept(self)
|
|
||||||
right: str = stmt.right.accept(self)
|
|
||||||
result: str = stmt.result.accept(self)
|
|
||||||
return f"(op_def {left} {stmt.op.lexeme} {right} {result})"
|
|
||||||
|
|
||||||
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> str:
|
|
||||||
return f"(constraint_def {stmt.name.lexeme} {stmt.constraint.accept(self)})"
|
|
||||||
|
|
||||||
|
|
||||||
def parse(source: str) -> list[Stmt]:
|
|
||||||
tokens: list[Token] = MidasLexer(source).process()
|
|
||||||
return MidasParser(tokens).parse()
|
|
||||||
|
|
||||||
|
|
||||||
def ast_str(source: str) -> list[str]:
|
|
||||||
stmts: list[Stmt] = parse(source)
|
|
||||||
return [AstSerializer().serialize(stmt) for stmt in stmts]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("type Foo<>", "(type_def Foo)"),
|
|
||||||
("type Foo<Bar>", "(type_def Foo (Bar))"),
|
|
||||||
("type Foo<Bar, Baz>", "(type_def Foo (Bar) (Baz))"),
|
|
||||||
(
|
|
||||||
"type Foo<Bar + (_ < 2), Baz>",
|
|
||||||
"(type_def Foo (Bar (constraint (_) < (2.0))) (Baz))",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"""
|
|
||||||
type Foo<> {
|
|
||||||
foo: Bar
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"(type_def Foo (body (property foo (Bar))))",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"""
|
|
||||||
type Foo<> {
|
|
||||||
foo: Bar + (_ != none)
|
|
||||||
foo2: Bar2 + (0 <= _) + (_ <= 100)
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"(type_def Foo (body (property foo (Bar (constraint (_) != (None)))) (property foo2 (Bar2 (constraint (0.0) <= (_)) (constraint (_) <= (100.0))))))",
|
|
||||||
),
|
|
||||||
("op <A> + <B> = <C>", "(op_def (A) + (B) (C))"),
|
|
||||||
(
|
|
||||||
"op <A + (_ < 100)> + <B + (_ < 100)> = <C + (_ < 200)>",
|
|
||||||
"(op_def (A (constraint (_) < (100.0))) + (B (constraint (_) < (100.0))) (C (constraint (_) < (200.0))))",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"constraint Positive = _ >= 0",
|
|
||||||
"(constraint_def Positive (constraint (_) >= (0.0)))",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_expressions(src: str, expected: str | list[str]):
|
|
||||||
if isinstance(expected, str):
|
|
||||||
expected = [expected]
|
|
||||||
assert ast_str(src) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,pos",
|
|
||||||
[
|
|
||||||
###
|
|
||||||
# Misc
|
|
||||||
###
|
|
||||||
("42", (1, 1)),
|
|
||||||
("true", (1, 1)),
|
|
||||||
("foo", (1, 1)),
|
|
||||||
###
|
|
||||||
# Type statements
|
|
||||||
###
|
|
||||||
("type", (1, 5)),
|
|
||||||
("type true", (1, 6)),
|
|
||||||
("type Foo", (1, 9)),
|
|
||||||
("type Foo<1>", (1, 10)),
|
|
||||||
# ("type Foo<float,>", (1, 16)), # trailing comma is accepted, TODO: update parser or EBNF
|
|
||||||
("type Foo<float, 1>", (1, 17)),
|
|
||||||
("type Foo<float", (1, 15)),
|
|
||||||
("type Foo<float> { 3 }", (1, 19)),
|
|
||||||
(
|
|
||||||
"""
|
|
||||||
type Foo<float> {
|
|
||||||
foo
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
(4, 1),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"""
|
|
||||||
type Foo<float> {
|
|
||||||
foo: 3
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
(3, 10),
|
|
||||||
),
|
|
||||||
###
|
|
||||||
# Operation statements
|
|
||||||
###
|
|
||||||
("op", (1, 3)),
|
|
||||||
("op float", (1, 4)),
|
|
||||||
("op <", (1, 5)),
|
|
||||||
("op <float", (1, 10)),
|
|
||||||
("op <float>", (1, 11)),
|
|
||||||
("op <float> +", (1, 13)),
|
|
||||||
("op <float> + float", (1, 14)),
|
|
||||||
("op <float> + <", (1, 15)),
|
|
||||||
("op <float> + <float", (1, 20)),
|
|
||||||
("op <float> + <float>", (1, 21)),
|
|
||||||
("op <float> + <float> =", (1, 23)),
|
|
||||||
("op <float> + <float> = float", (1, 24)),
|
|
||||||
("op <float> + <float> = <", (1, 25)),
|
|
||||||
("op <float> + <float> = <float", (1, 30)),
|
|
||||||
("op <float + 3> + <float> = <float>", (1, 13)),
|
|
||||||
("op <float> + <float + 3> = <float>", (1, 23)),
|
|
||||||
("op <float> + <float> = <float + 3>", (1, 33)),
|
|
||||||
###
|
|
||||||
# Constraint statements
|
|
||||||
###
|
|
||||||
("constraint", (1, 11)),
|
|
||||||
("constraint 3", (1, 12)),
|
|
||||||
("constraint Foo", (1, 15)),
|
|
||||||
("constraint Foo =", (1, 17)),
|
|
||||||
("constraint Foo = 3", (1, 19)),
|
|
||||||
("constraint Foo = 3 <", (1, 21)),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parsing_error(src: str, pos: tuple[int, int]):
|
|
||||||
src = textwrap.dedent(src)
|
|
||||||
tokens: list[Token] = MidasLexer(src).process()
|
|
||||||
parser: MidasParser = MidasParser(tokens)
|
|
||||||
stmt: list[Stmt] = parser.parse()
|
|
||||||
assert len(stmt) == 0
|
|
||||||
assert len(parser.errors) != 0
|
|
||||||
error_pos: Position = parser.errors[0].token.position
|
|
||||||
assert (error_pos.line, error_pos.column) == pos
|
|
||||||
46
tests/python.py
Normal file
46
tests/python.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.python import Stmt
|
||||||
|
from midas.parser.python import PythonParser
|
||||||
|
from tests.base import Tester
|
||||||
|
from tests.serializer.python import PythonAstJsonSerializer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CaseResult:
|
||||||
|
stmts: Optional[list[dict]] = None
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
return json.dumps(asdict(self), indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonTester(Tester):
|
||||||
|
@property
|
||||||
|
def namespace(self) -> str:
|
||||||
|
return "python-parser"
|
||||||
|
|
||||||
|
def _list_tests(self) -> list[Path]:
|
||||||
|
return list(self.base_dir.rglob("*.py"))
|
||||||
|
|
||||||
|
def _exec_case(self, path: Path) -> CaseResult:
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||||
|
if not path.is_file():
|
||||||
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
|
result: CaseResult = CaseResult()
|
||||||
|
content: str = path.read_text()
|
||||||
|
tree: ast.Module = ast.parse(content)
|
||||||
|
|
||||||
|
parser: PythonParser = PythonParser()
|
||||||
|
stmts: list[Stmt] = parser.parse_module(tree)
|
||||||
|
result.stmts = PythonAstJsonSerializer().serialize(stmts)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
PythonTester.main()
|
||||||
159
tests/serializer/midas.py
Normal file
159
tests/serializer/midas.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from midas.ast.midas import (
|
||||||
|
BinaryExpr,
|
||||||
|
ComplexTypeStmt,
|
||||||
|
Expr,
|
||||||
|
ExtendStmt,
|
||||||
|
GetExpr,
|
||||||
|
GroupingExpr,
|
||||||
|
LiteralExpr,
|
||||||
|
LogicalExpr,
|
||||||
|
OpStmt,
|
||||||
|
PredicateStmt,
|
||||||
|
PropertyStmt,
|
||||||
|
SimpleTypeExpr,
|
||||||
|
SimpleTypeStmt,
|
||||||
|
Stmt,
|
||||||
|
TemplateExpr,
|
||||||
|
TypeExpr,
|
||||||
|
UnaryExpr,
|
||||||
|
VariableExpr,
|
||||||
|
WildcardExpr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||||
|
"""An AST serializer which produces a JSON-compatible structure"""
|
||||||
|
|
||||||
|
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||||
|
return [stmt.accept(self) for stmt in stmts]
|
||||||
|
|
||||||
|
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
|
||||||
|
if element is None:
|
||||||
|
return None
|
||||||
|
return element.accept(self)
|
||||||
|
|
||||||
|
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
|
||||||
|
return [element.accept(self) for element in elements]
|
||||||
|
|
||||||
|
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "SimpleTypeStmt",
|
||||||
|
"name": stmt.name.lexeme,
|
||||||
|
"template": self._serialize_optional(stmt.template),
|
||||||
|
"base": stmt.base.accept(self),
|
||||||
|
"constraint": self._serialize_optional(stmt.constraint),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ComplexTypeStmt",
|
||||||
|
"name": stmt.name.lexeme,
|
||||||
|
"template": self._serialize_optional(stmt.template),
|
||||||
|
"properties": self._serialize_list(stmt.properties),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "PropertyStmt",
|
||||||
|
"name": stmt.name.lexeme,
|
||||||
|
"type": stmt.type.accept(self),
|
||||||
|
"constraint": self._serialize_optional(stmt.constraint),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ExtendStmt",
|
||||||
|
"type": stmt.type.accept(self),
|
||||||
|
"operations": self._serialize_list(stmt.operations),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_op_stmt(self, stmt: OpStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "OpStmt",
|
||||||
|
"name": stmt.name.lexeme,
|
||||||
|
"operand": stmt.operand.accept(self),
|
||||||
|
"result": stmt.result.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "PredicateStmt",
|
||||||
|
"name": stmt.name.lexeme,
|
||||||
|
"subject": stmt.subject.lexeme,
|
||||||
|
"type": stmt.type.accept(self),
|
||||||
|
"condition": stmt.condition.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "SimpleTypeExpr",
|
||||||
|
"name": expr.name.lexeme,
|
||||||
|
"optional": expr.optional,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "LogicalExpr",
|
||||||
|
"left": expr.left.accept(self),
|
||||||
|
"operator": expr.operator.lexeme,
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": expr.left.accept(self),
|
||||||
|
"operator": expr.operator.lexeme,
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "UnaryExpr",
|
||||||
|
"operator": expr.operator.lexeme,
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "GetExpr",
|
||||||
|
"expr": expr.expr.accept(self),
|
||||||
|
"name": expr.name.lexeme,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": expr.name.lexeme,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: GroupingExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "GroupingExpr",
|
||||||
|
"expr": expr.expr.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": expr.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
||||||
|
return {"_type": "WildcardExpr"}
|
||||||
|
|
||||||
|
def visit_template_expr(self, expr: TemplateExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "TemplateExpr",
|
||||||
|
"type": expr.type.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_type_expr(self, expr: TypeExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "TypeExpr",
|
||||||
|
"name": expr.name.lexeme,
|
||||||
|
"template": self._serialize_optional(expr.template),
|
||||||
|
"optional": expr.optional,
|
||||||
|
}
|
||||||
247
tests/serializer/python.py
Normal file
247
tests/serializer/python.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Optional, Sequence, Type
|
||||||
|
|
||||||
|
from midas.ast.python import (
|
||||||
|
AssignStmt,
|
||||||
|
BaseType,
|
||||||
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
|
CastExpr,
|
||||||
|
CompareExpr,
|
||||||
|
ConstraintType,
|
||||||
|
Expr,
|
||||||
|
ExpressionStmt,
|
||||||
|
FrameColumn,
|
||||||
|
FrameType,
|
||||||
|
Function,
|
||||||
|
GetExpr,
|
||||||
|
IfStmt,
|
||||||
|
LiteralExpr,
|
||||||
|
LogicalExpr,
|
||||||
|
MidasType,
|
||||||
|
ReturnStmt,
|
||||||
|
SetExpr,
|
||||||
|
Stmt,
|
||||||
|
TypeAssign,
|
||||||
|
UnaryExpr,
|
||||||
|
VariableExpr,
|
||||||
|
)
|
||||||
|
|
||||||
|
unary_ops: dict[Type[ast.unaryop], str] = {
|
||||||
|
ast.Invert: "~",
|
||||||
|
ast.Not: "not",
|
||||||
|
ast.UAdd: "+",
|
||||||
|
ast.USub: "-",
|
||||||
|
}
|
||||||
|
binary_ops: dict[Type[ast.operator], str] = {
|
||||||
|
ast.Add: "+",
|
||||||
|
ast.Sub: "-",
|
||||||
|
ast.Mult: "*",
|
||||||
|
ast.MatMult: "@",
|
||||||
|
ast.Div: "/",
|
||||||
|
ast.Mod: "%",
|
||||||
|
ast.LShift: "<<",
|
||||||
|
ast.RShift: ">>",
|
||||||
|
ast.BitOr: "|",
|
||||||
|
ast.BitXor: "^",
|
||||||
|
ast.BitAnd: "&",
|
||||||
|
ast.FloorDiv: "//",
|
||||||
|
ast.Pow: "**",
|
||||||
|
}
|
||||||
|
compare_ops: dict[Type[ast.cmpop], str] = {
|
||||||
|
ast.Eq: "==",
|
||||||
|
ast.NotEq: "!=",
|
||||||
|
ast.Lt: "<",
|
||||||
|
ast.LtE: "<=",
|
||||||
|
ast.Gt: ">",
|
||||||
|
ast.GtE: ">=",
|
||||||
|
ast.Is: "is",
|
||||||
|
ast.IsNot: "is not",
|
||||||
|
ast.In: "in",
|
||||||
|
ast.NotIn: "not in",
|
||||||
|
}
|
||||||
|
boolean_ops: dict[Type[ast.boolop], str] = {
|
||||||
|
ast.And: "and",
|
||||||
|
ast.Or: "or",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PythonAstJsonSerializer(
|
||||||
|
Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict]
|
||||||
|
):
|
||||||
|
"""An AST serializer which produces a JSON-compatible structure"""
|
||||||
|
|
||||||
|
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||||
|
return [stmt.accept(self) for stmt in stmts]
|
||||||
|
|
||||||
|
def _serialize_optional(
|
||||||
|
self, element: Optional[Stmt | Expr | MidasType]
|
||||||
|
) -> Optional[dict]:
|
||||||
|
if element is None:
|
||||||
|
return None
|
||||||
|
return element.accept(self)
|
||||||
|
|
||||||
|
def _serialize_list(
|
||||||
|
self, elements: Sequence[Stmt | Expr | MidasType]
|
||||||
|
) -> list[dict]:
|
||||||
|
return [element.accept(self) for element in elements]
|
||||||
|
|
||||||
|
def visit_base_type(self, node: BaseType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": node.base,
|
||||||
|
"param": self._serialize_optional(node.param),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": node.type.accept(self),
|
||||||
|
"constraint": ast.unparse(node.constraint),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: FrameColumn) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FrameColumn",
|
||||||
|
"name": node.name,
|
||||||
|
"type": self._serialize_optional(node.type),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: FrameType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": self._serialize_list(node.columns),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ExpressionStmt",
|
||||||
|
"expr": stmt.expr.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_argument(self, arg: Function.Argument) -> dict:
|
||||||
|
return {
|
||||||
|
"name": arg.name,
|
||||||
|
"type": self._serialize_optional(arg.type),
|
||||||
|
"default": self._serialize_optional(arg.default),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_function(self, stmt: Function) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "Function",
|
||||||
|
"name": stmt.name,
|
||||||
|
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
|
||||||
|
"args": [self._serialize_argument(arg) for arg in stmt.args],
|
||||||
|
"sink": (
|
||||||
|
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
|
||||||
|
),
|
||||||
|
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
|
||||||
|
"kw_sink": (
|
||||||
|
self._serialize_argument(stmt.kw_sink)
|
||||||
|
if stmt.kw_sink is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"returns": self._serialize_optional(stmt.returns),
|
||||||
|
"body": self._serialize_list(stmt.body),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: TypeAssign) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "TypeAssign",
|
||||||
|
"name": stmt.name,
|
||||||
|
"type": stmt.type.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: AssignStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "AssignStmt",
|
||||||
|
"targets": self._serialize_list(stmt.targets),
|
||||||
|
"value": stmt.value.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: ReturnStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ReturnStmt",
|
||||||
|
"value": self._serialize_optional(stmt.value),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: IfStmt) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "IfStmt",
|
||||||
|
"test": stmt.test.accept(self),
|
||||||
|
"body": self._serialize_list(stmt.body),
|
||||||
|
"orelse": self._serialize_list(stmt.orelse),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": expr.left.accept(self),
|
||||||
|
"operator": binary_ops[expr.operator.__class__],
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: CompareExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "CompareExpr",
|
||||||
|
"left": expr.left.accept(self),
|
||||||
|
"operator": compare_ops[expr.operator.__class__],
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "UnaryExpr",
|
||||||
|
"operator": unary_ops[expr.operator.__class__],
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": expr.callee.accept(self),
|
||||||
|
"arguments": self._serialize_list(expr.arguments),
|
||||||
|
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "GetExpr",
|
||||||
|
"object": expr.object.accept(self),
|
||||||
|
"name": expr.name,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": expr.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": expr.name,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "LogicalExpr",
|
||||||
|
"left": expr.left.accept(self),
|
||||||
|
"operator": boolean_ops[expr.operator.__class__],
|
||||||
|
"right": expr.right.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_set_expr(self, expr: SetExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "SetExpr",
|
||||||
|
"object": expr.object.accept(self),
|
||||||
|
"name": expr.name,
|
||||||
|
"value": expr.value.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "CastExpr",
|
||||||
|
"type": expr.type.accept(self),
|
||||||
|
"expr": expr.expr.accept(self),
|
||||||
|
}
|
||||||
@@ -31,22 +31,32 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"type-base": {
|
"type-base": {
|
||||||
"begin": "<",
|
"begin": "(\\()([a-zA-Z_][a-zA-Z_\\d]*)(\\))",
|
||||||
"end": ">",
|
"end": "$",
|
||||||
"beginCaptures": {
|
"beginCaptures": {
|
||||||
"0": {
|
"1": {
|
||||||
"name": "punctuation.definition.base.begin.midas"
|
"name": "punctuation.definition.base.begin.midas"
|
||||||
}
|
},
|
||||||
},
|
"2": {
|
||||||
"endCaptures": {
|
"name": "variable.name"
|
||||||
"0": {
|
},
|
||||||
|
"3": {
|
||||||
"name": "punctuation.definition.base.end.midas"
|
"name": "punctuation.definition.base.end.midas"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"patterns": [
|
"patterns": [
|
||||||
{"include": "source.python"}
|
{ "include": "#type-cond" }
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"type-cond": {
|
||||||
|
"begin": "where",
|
||||||
|
"end": "$",
|
||||||
|
"beginCaptures": {
|
||||||
|
"0": {
|
||||||
|
"name": "keyword.control.where.midas"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"type-body": {
|
"type-body": {
|
||||||
"begin": "\\{",
|
"begin": "\\{",
|
||||||
"end": "\\}",
|
"end": "\\}",
|
||||||
@@ -61,7 +71,8 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"patterns": [
|
"patterns": [
|
||||||
{"include": "#type-prop"}
|
{"include": "#type-prop"},
|
||||||
|
{"include": "#comment"}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"type-prop": {
|
"type-prop": {
|
||||||
@@ -78,44 +89,67 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"op-def": {
|
"extend-def": {
|
||||||
"match": "\\b(op)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(\\S+)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(=)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>",
|
"begin": "\\b(extend)\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s+(\\{)",
|
||||||
"captures": {
|
"end": "\\}",
|
||||||
"1": {
|
|
||||||
"name": "keyword.control.op.midas"
|
|
||||||
},
|
|
||||||
"2": {
|
|
||||||
"name" : "variable.name"
|
|
||||||
},
|
|
||||||
"3": {
|
|
||||||
"name" : "keyword.operator"
|
|
||||||
},
|
|
||||||
"4": {
|
|
||||||
"name" : "variable.name"
|
|
||||||
},
|
|
||||||
"5": {
|
|
||||||
"name" : "keyword.operator.assignment"
|
|
||||||
},
|
|
||||||
"6": {
|
|
||||||
"name" : "variable.name"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"patterns": [
|
|
||||||
{ "include": "#type-base" },
|
|
||||||
{ "include": "#type-body" }
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"constr-def": {
|
|
||||||
"begin": "(constraint)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\s*(=)",
|
|
||||||
"end": "$",
|
|
||||||
"beginCaptures": {
|
"beginCaptures": {
|
||||||
"1": {
|
"1": {
|
||||||
"name": "keyword.control.constr.midas"
|
"name": "keyword.control.extend.midas"
|
||||||
},
|
},
|
||||||
"2": {
|
"2": {
|
||||||
"name": "variable.name"
|
"name": "variable.name"
|
||||||
},
|
},
|
||||||
"3": {
|
"3": {
|
||||||
|
"name": "punctuation.definition.extend-body.begin.midas"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"endCaptures": {
|
||||||
|
"0": {
|
||||||
|
"name": "punctuation.definition.extend-body.end.midas"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"patterns": [
|
||||||
|
{"include": "#op-def"},
|
||||||
|
{"include": "#comment"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"op-def": {
|
||||||
|
"match": "\\b(op)\\s+(\\S+)\\s*\\(\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s*\\)\\s*(->)\\s*([a-zA-Z_][a-zA-Z_\\d]*)",
|
||||||
|
"captures": {
|
||||||
|
"1": {
|
||||||
|
"name": "keyword.control.op.midas"
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"name" : "keyword.operator"
|
||||||
|
},
|
||||||
|
"3": {
|
||||||
|
"name" : "variable.name"
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"name" : "keyword.operator.assignment"
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"name" : "variable.name"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"pred-def": {
|
||||||
|
"begin": "(predicate)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\(([a-zA-Z_][a-zA-Z_\\d]*):\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\)\\s*(=)",
|
||||||
|
"end": "$",
|
||||||
|
"beginCaptures": {
|
||||||
|
"1": {
|
||||||
|
"name": "keyword.control.pred.midas"
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"name": "variable.name"
|
||||||
|
},
|
||||||
|
"3": {
|
||||||
|
"name": "variable.name"
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"name": "variable.name"
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
"name": "keyword.operator.assignment"
|
"name": "keyword.operator.assignment"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -127,8 +161,8 @@
|
|||||||
"patterns": [
|
"patterns": [
|
||||||
{ "include": "#comment" },
|
{ "include": "#comment" },
|
||||||
{ "include": "#type-def" },
|
{ "include": "#type-def" },
|
||||||
{ "include": "#op-def" },
|
{ "include": "#extend-def" },
|
||||||
{ "include": "#constr-def" }
|
{ "include": "#pred-def" }
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user