Compare commits
104 Commits
v0.0.1-pro
...
bea3f399ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
bea3f399ad
|
|||
|
55060bfecd
|
|||
|
dd126f2559
|
|||
|
4151f5373d
|
|||
|
bd31713ab4
|
|||
|
f4dc57cb96
|
|||
|
261fd47494
|
|||
|
1b66a8553d
|
|||
|
65164abadb
|
|||
|
9d45163d9c
|
|||
|
ab0fa1de1a
|
|||
|
5d4df7978b
|
|||
|
86ad348b99
|
|||
|
29f691e38a
|
|||
|
f2c61d24e2
|
|||
|
112ed0e816
|
|||
|
7eb1e13b70
|
|||
|
893e1ba190
|
|||
|
1a1b0e8e15
|
|||
|
4ddde364ed
|
|||
|
4a3363a3d6
|
|||
|
0a3216e07d
|
|||
|
c29c0ed3ec
|
|||
|
fa7e56cb77
|
|||
|
13c19db818
|
|||
|
95b218fbed
|
|||
|
c3722c7438
|
|||
|
9dd547d6c1
|
|||
|
e2d5943517
|
|||
|
86e4763a12
|
|||
|
89ec63cb05
|
|||
|
e6375f1aa9
|
|||
|
d16e192a3a
|
|||
|
3f61f84e5a
|
|||
|
fd5399f50a
|
|||
|
8906ac3db8
|
|||
|
022aebf55b
|
|||
|
5dc6903425
|
|||
|
1b078b832c
|
|||
|
7515716864
|
|||
|
218b0c5b78
|
|||
|
928901ef9c
|
|||
|
4b62c78874
|
|||
|
f882eebaf5
|
|||
|
a872938405
|
|||
|
146be72fd7
|
|||
|
6de54e1da1
|
|||
|
c82b41a4df
|
|||
|
8304760fe0
|
|||
|
6bf91db757
|
|||
|
3f6b650a4b
|
|||
| ec079f32ca | |||
|
6524b3591a
|
|||
|
170101aa37
|
|||
|
0b3f33d7fe
|
|||
|
8a9b4f3989
|
|||
|
bbd0e3ae8d
|
|||
|
4d23e8840e
|
|||
|
c64d626d1c
|
|||
|
ecab1b74a4
|
|||
|
0bbdf04621
|
|||
|
939e5af4ce
|
|||
|
a735113466
|
|||
|
0e0a1b26f2
|
|||
|
e94db2181f
|
|||
|
9b59058881
|
|||
|
d0c54db33a
|
|||
|
5aedddfabb
|
|||
|
8d7c115432
|
|||
|
832c350b61
|
|||
|
3d599b3462
|
|||
|
4f799caaf5
|
|||
|
f4d2be3b1b
|
|||
|
7ce2840f03
|
|||
|
e2f3cabe15
|
|||
|
5a112332f2
|
|||
|
eb79cf6dc3
|
|||
|
8a9bb6ef4e
|
|||
|
6e0190a378
|
|||
| b5969e9a2b | |||
|
409d9f8fa6
|
|||
|
12d762429d
|
|||
|
53929ee514
|
|||
|
2f6e137f1a
|
|||
|
5224e79d9f
|
|||
|
bdcb12c58a
|
|||
|
5cb4d587e3
|
|||
|
8f9ec8d73b
|
|||
|
c1c50a448e
|
|||
|
19229db0b1
|
|||
|
f3b6bd146f
|
|||
|
98c3510bd4
|
|||
|
429d0d98fe
|
|||
|
db8fe5d3ff
|
|||
|
7477ec8d70
|
|||
|
adf7f4e7a2
|
|||
|
abf6787946
|
|||
|
e282b08597
|
|||
|
0a02b9d3d9
|
|||
| 875ca589e4 | |||
|
88f92d6e1f
|
|||
|
db4ed74365
|
|||
|
7cbf4fdece
|
|||
|
1fa9a09bfe
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,3 +4,6 @@ __pycache__
|
||||
venv
|
||||
.venv
|
||||
*.pyc
|
||||
uv.lock
|
||||
.python-version
|
||||
/out
|
||||
@@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Stmt(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnnotationStmt(Stmt):
|
||||
name: Token
|
||||
schema: Optional[SchemaExpr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_annotation_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_schema_expr(self, expr: SchemaExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
name: Token
|
||||
constraints: list[ConstraintExpr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintExpr(Expr):
|
||||
left: Expr
|
||||
op: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SchemaExpr(Expr):
|
||||
left: Token
|
||||
elements: list[Expr]
|
||||
right: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_schema_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SchemaElementExpr(Expr):
|
||||
name: Optional[Token]
|
||||
type: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_schema_element_expr(self)
|
||||
@@ -1,138 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Statements
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Stmt(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
bases: list[TypeExpr]
|
||||
body: Optional[TypeBodyExpr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
left: TypeExpr
|
||||
op: Token
|
||||
right: TypeExpr
|
||||
result: TypeExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintStmt(Stmt):
|
||||
name: Token
|
||||
constraint: ConstraintExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_stmt(self)
|
||||
|
||||
|
||||
# Expressions
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Expr(ABC):
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_body_expr(self, expr: TypeBodyExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
name: Token
|
||||
constraints: list[ConstraintExpr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintExpr(Expr):
|
||||
left: Expr
|
||||
op: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeBodyExpr(Expr):
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_body_expr(self)
|
||||
@@ -1,360 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
import io
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import core.ast.annotations as a
|
||||
import core.ast.midas as m
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, last: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if last else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(last=True):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
|
||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
|
||||
self._write_line("AnnotationStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("schema", stmt.schema, last=True)
|
||||
|
||||
def visit_type_expr(self, expr: a.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line("constraints", last=True)
|
||||
with self._child_level():
|
||||
for i, constraint in enumerate(expr.constraints):
|
||||
self._idx = i
|
||||
if i == len(expr.constraints) - 1:
|
||||
self._mark_last()
|
||||
constraint.accept(self)
|
||||
|
||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
|
||||
self._write_line("ConstraintExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.op.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_schema_expr(self, expr: a.SchemaExpr):
|
||||
self._write_line("SchemaExpr")
|
||||
with self._child_level():
|
||||
for i, elmt in enumerate(expr.elements):
|
||||
self._idx = i
|
||||
if i == len(expr.elements) - 1:
|
||||
self._mark_last()
|
||||
elmt.accept(self)
|
||||
|
||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
|
||||
self._write_line("SchemaElementExpr")
|
||||
with self._child_level():
|
||||
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
|
||||
self._write_line(f"name: {name_text}")
|
||||
self._write_optional_child("type", expr.type, last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
|
||||
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
|
||||
def print(self, expr: a.Expr | a.Stmt):
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
|
||||
schema: str = ""
|
||||
if stmt.schema is not None:
|
||||
schema = stmt.schema.accept(self)
|
||||
return f"{stmt.name.lexeme}{schema}"
|
||||
|
||||
def visit_type_expr(self, expr: a.TypeExpr) -> str:
|
||||
parts: list[str] = [expr.name.lexeme]
|
||||
for constraint in expr.constraints:
|
||||
parts.append("(" + constraint.accept(self) + ")")
|
||||
return " + ".join(parts)
|
||||
|
||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
|
||||
parts: list[str] = [
|
||||
expr.left.accept(self),
|
||||
expr.op.lexeme,
|
||||
expr.right.accept(self),
|
||||
]
|
||||
return " ".join(parts)
|
||||
|
||||
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
|
||||
res: str = expr.left.lexeme
|
||||
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
|
||||
res += expr.right.lexeme
|
||||
return res
|
||||
|
||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
|
||||
parts: list[str] = []
|
||||
if expr.name is not None:
|
||||
parts.append(expr.name.lexeme)
|
||||
|
||||
if expr.type is None:
|
||||
parts.append("_")
|
||||
else:
|
||||
parts.append(expr.type.accept(self))
|
||||
return ": ".join(parts)
|
||||
|
||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
|
||||
return "_"
|
||||
|
||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
|
||||
return str(expr.value)
|
||||
|
||||
|
||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("bases")
|
||||
with self._child_level():
|
||||
for i, base in enumerate(stmt.bases):
|
||||
self._idx = i
|
||||
if i == len(stmt.bases) - 1:
|
||||
self._mark_last()
|
||||
base.accept(self)
|
||||
self._write_optional_child("body", stmt.body, last=True)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self._write_line("OpStmt")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.left.accept(self)
|
||||
|
||||
self._write_line(f'op: "{stmt.op.lexeme}"')
|
||||
|
||||
self._write_line("right")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.right.accept(self)
|
||||
|
||||
self._write_line("result", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
|
||||
self._write_line("ConstraintStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line("constraints", last=True)
|
||||
with self._child_level():
|
||||
for i, constraint in enumerate(expr.constraints):
|
||||
self._idx = i
|
||||
if i == len(expr.constraints) - 1:
|
||||
self._mark_last()
|
||||
constraint.accept(self)
|
||||
|
||||
def visit_constraint_expr(self, expr: m.ConstraintExpr):
|
||||
self._write_line("ConstraintExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.op.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level():
|
||||
self._mark_last()
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
||||
self._write_line("TypeBodyExpr")
|
||||
with self._child_level():
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, property in enumerate(expr.properties):
|
||||
self._idx = i
|
||||
if i == len(expr.properties) - 1:
|
||||
self._mark_last()
|
||||
property.accept(self)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt):
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
||||
bases: list[str] = [
|
||||
b.accept(self)
|
||||
for b in stmt.bases
|
||||
]
|
||||
|
||||
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
|
||||
if stmt.body is not None:
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
res += stmt.body.accept(self)
|
||||
self.level -= 1
|
||||
res += "\n" + self.indented("}")
|
||||
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
return f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||
left: str = stmt.left.accept(self)
|
||||
op: str = stmt.op.lexeme
|
||||
right: str = stmt.right.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return self.indented(f"op <{left}> {op} <{right}> = <{result}>")
|
||||
|
||||
def visit_constraint_stmt(self, stmt: m.ConstraintStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
constraint: str = stmt.constraint.accept(self)
|
||||
return self.indented(f"constraint {name} = {constraint}")
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
parts: list[str] = [expr.name.lexeme]
|
||||
for constraint in expr.constraints:
|
||||
parts.append("(" + constraint.accept(self) + ")")
|
||||
return " + ".join(parts)
|
||||
|
||||
def visit_constraint_expr(self, expr: m.ConstraintExpr):
|
||||
parts: list[str] = [
|
||||
expr.left.accept(self),
|
||||
expr.op.lexeme,
|
||||
expr.right.accept(self),
|
||||
]
|
||||
return " ".join(parts)
|
||||
|
||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
||||
properties: list[str] = [
|
||||
self.indented(prop.accept(self))
|
||||
for prop in expr.properties
|
||||
]
|
||||
return "\n".join(properties)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||
return str(expr.value)
|
||||
150
docs/architecture.typ
Normal file
150
docs/architecture.typ
Normal file
@@ -0,0 +1,150 @@
|
||||
#import "@preview/cetz:0.5.2": canvas, draw
|
||||
|
||||
#let diagram-only = false
|
||||
|
||||
#set document(
|
||||
title: [Midas Architecture],
|
||||
//author: "Louis Heredero",
|
||||
)
|
||||
|
||||
#set text(
|
||||
font: "Source Sans 3",
|
||||
)
|
||||
|
||||
#let diagram = canvas({
|
||||
let framed = draw.content.with(
|
||||
padding: (x: .8em, y: 1em),
|
||||
frame: "rect",
|
||||
stroke: black,
|
||||
)
|
||||
let arrow = draw.line.with(mark: (end: ">", fill: black))
|
||||
framed(
|
||||
(0, 0),
|
||||
name: "python-parser",
|
||||
)[Python parser]
|
||||
|
||||
draw.content(
|
||||
(rel: (0, 1), to: "python-parser.north"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
name: "source-py",
|
||||
)[_`source.py`_]
|
||||
arrow("source-py", "python-parser")
|
||||
|
||||
framed(
|
||||
(rel: (3, 0), to: "python-parser.east"),
|
||||
anchor: "west",
|
||||
name: "custom-parser",
|
||||
align(center)[Custom python\ parser],
|
||||
)
|
||||
|
||||
arrow("python-parser", "custom-parser", name: "arrow-python-ast")
|
||||
draw.content(
|
||||
"arrow-python-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[`ast.Module`]
|
||||
|
||||
framed(
|
||||
(rel: (-3, -2), to: "custom-parser.south"),
|
||||
anchor: "east",
|
||||
name: "python-resolver",
|
||||
)[Python Resolver]
|
||||
arrow(
|
||||
"custom-parser",
|
||||
((), "|-", "python-resolver.east"),
|
||||
"python-resolver",
|
||||
name: "arrow-python-custom-ast",
|
||||
)
|
||||
draw.content(
|
||||
(rel: (1.5, 0), to: "arrow-python-custom-ast.end"),
|
||||
padding: 5pt,
|
||||
anchor: "south",
|
||||
)[P-AST#footnote[#strong[P]ython *AST*]<fn-past>]
|
||||
draw.content(
|
||||
"python-resolver.west",
|
||||
padding: 5pt,
|
||||
anchor: "south-east",
|
||||
)[Resolved P-AST@fn-past]
|
||||
|
||||
draw.circle(
|
||||
(rel: (1, -2), to: "custom-parser.south-east"),
|
||||
radius: .4,
|
||||
name: "midas-loader",
|
||||
)
|
||||
arrow(
|
||||
"custom-parser",
|
||||
"midas-loader",
|
||||
name: "arrow-load-midas",
|
||||
mark: (end: (symbol: ">", fill: black), start: "o"),
|
||||
)
|
||||
draw.content(
|
||||
"arrow-load-midas",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[```python midas.using("types.midas")```]
|
||||
|
||||
framed(
|
||||
(rel: (0, -2), to: "midas-loader.south"),
|
||||
name: "midas-parser",
|
||||
)[Midas lexer/parser]
|
||||
arrow("midas-loader", "midas-parser", name: "arrow-midas-source")
|
||||
draw.content(
|
||||
"arrow-midas-source",
|
||||
anchor: "west",
|
||||
padding: 5pt,
|
||||
)[_`types.midas`_]
|
||||
|
||||
|
||||
framed(
|
||||
(rel: (-2, 0), to: "midas-parser.west"),
|
||||
anchor: "east",
|
||||
name: "midas-resolver",
|
||||
)[Midas Resolver]
|
||||
arrow("midas-parser", "midas-resolver", name: "arrow-midas-ast")
|
||||
draw.content(
|
||||
"arrow-midas-ast",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[M-AST#footnote[#strong[M]idas *AST*]<fn-mast>]
|
||||
|
||||
framed(
|
||||
(rel: (-3, 0), to: "midas-resolver.west"),
|
||||
anchor: "east",
|
||||
name: "checker",
|
||||
)[Checker]
|
||||
arrow("midas-resolver", "checker", name: "arrow-type-ctx")
|
||||
arrow(
|
||||
"python-resolver",
|
||||
((), "-|", "checker.north"),
|
||||
"checker",
|
||||
)
|
||||
draw.content(
|
||||
"arrow-type-ctx",
|
||||
anchor: "south",
|
||||
padding: 5pt,
|
||||
)[Types context]
|
||||
})
|
||||
|
||||
#show: doc => if diagram-only {
|
||||
set page(width: auto, height: auto, margin: .5cm)
|
||||
diagram
|
||||
} else { doc }
|
||||
|
||||
#align(center, title())
|
||||
|
||||
#v(1cm)
|
||||
|
||||
#figure(
|
||||
diagram,
|
||||
caption: [Midas type-checker architecture],
|
||||
)
|
||||
|
||||
== Components
|
||||
|
||||
- *Python parser*: builtin Python AST parser, extracts abstract syntax from the raw Python source (```python ast.parse(...)```)
|
||||
- *Custom python parser*: converts the raw Python AST into custom, more suitable constructs, especially for type annotations
|
||||
- *Python resolver*: resolves bindings and references, tracks binding scopes
|
||||
- *Midas lexer/parser*: parses a Midas type definition file and extracts its AST
|
||||
- *Midas resolver*: walks the AST and fills the environment with the defined types and operations
|
||||
- *Checker*: evaluates expressions and checks type coherence
|
||||
@@ -21,7 +21,7 @@ lat + lon # Invalid operation
|
||||
# Registered operations are permitted
|
||||
lat1: Latitude = lat[0]
|
||||
lat2: Latitude = lat[1]
|
||||
lat_diff: LatitudeDiff = lat2 - lat1 # Valid operation
|
||||
lat_diff: Difference[Latitude] = lat2 - lat1 # Valid operation
|
||||
|
||||
# In addition to the type, a column can have one or more constraints, either defined inline or in a separate file
|
||||
df2: Frame[
|
||||
|
||||
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
73
examples/00_syntax_prototype/03_custom_types_v2.midas
Normal file
@@ -0,0 +1,73 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom(float)
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude(float) where (-90 <= _ <= 90)
|
||||
type Longitude(float) where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||
type Difference[T](T)
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
}
|
||||
|
||||
// Define operations on our custom type
|
||||
extend GeoLocation {
|
||||
// This type is compatible with the `-` operation with another GeoLocation
|
||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||
// in a Difference of GeoLocations
|
||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||
}
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
}
|
||||
|
||||
// Simple operation defined on our custom types
|
||||
extend Latitude {
|
||||
op __sub__(Latitude) -> Difference[Latitude]
|
||||
}
|
||||
|
||||
extend Longitude {
|
||||
op __sub__(Longitude) -> Difference[Longitude]
|
||||
}
|
||||
|
||||
// Predefined custom predicates that can be referenced in other definitions
|
||||
predicate Positive(v: float) = v >= 0
|
||||
predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person {
|
||||
name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: int? where (0 <= _ < 150)
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
|
||||
home: GeoLocation
|
||||
}
|
||||
|
||||
// Custom complex type derived from another complex type, with a constraint
|
||||
// on a property
|
||||
// Multiple proposed syntaxes, not yet defined
|
||||
|
||||
// Explicit, but new keyword
|
||||
type EquatorialPerson refines Person where Equatorial(_.home)
|
||||
|
||||
// Explicit with existing keyword, might be confusing if expectations regarding 'is'
|
||||
type EquatorialPerson is Person where Equatorial(_.home)
|
||||
|
||||
// Consistent and Python-friendly but can be confused with structural extension
|
||||
type EquatorialPerson(Person) where Equatorial(_.home)
|
||||
|
||||
// Allow new properties, probably not useful
|
||||
type EquatorialPerson extends Person where Equatorial(_.home)
|
||||
15
examples/00_syntax_prototype/04_functions.py
Normal file
15
examples/00_syntax_prototype/04_functions.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def func(
|
||||
col1: Column[float + (0 <= _ <= 1)],
|
||||
col2: Column[float + (0 <= _ <= 1)],
|
||||
) -> Column[float + (0 <= _ <= 2)]:
|
||||
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||
return result
|
||||
|
||||
|
||||
def func2(a: int, /, b: float, *, c: str):
|
||||
pass
|
||||
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
11
examples/01_simple_type_checking/01_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b # -> int
|
||||
|
||||
c = "invalid" # -> can't assign str to int variable
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
14
examples/01_simple_type_checking/02_simple_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter(float)
|
||||
type Second(float)
|
||||
type MeterPerSecond(float)
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
}
|
||||
8
examples/01_simple_type_checking/02_simple_types.py
Normal file
8
examples/01_simple_type_checking/02_simple_types.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
midas.using("02_simple_types.midas")
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
16
examples/01_simple_type_checking/03_control_flow.py
Normal file
16
examples/01_simple_type_checking/03_control_flow.py
Normal file
@@ -0,0 +1,16 @@
|
||||
def minimum(x: int, y: int):
|
||||
if x < y:
|
||||
return x
|
||||
else:
|
||||
return y
|
||||
|
||||
a = 15
|
||||
b = 72
|
||||
c = minimum(a, b)
|
||||
|
||||
def factorial(n: int) -> int:
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
|
||||
category = "Category 1" if a < 10 else "Category 2"
|
||||
146
gen/gen.py
Normal file
146
gen/gen.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
HEADER = '''"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify {defs_path} instead and run {gen_path}
|
||||
"""'''
|
||||
|
||||
SECTION_TEMPLATE = """{banner}
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class {base}(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
{visitor_methods}
|
||||
|
||||
|
||||
{classes}"""
|
||||
|
||||
TEMPLATE = """{header}
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
{imports}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
{sections}
|
||||
"""
|
||||
|
||||
VISITOR_METHOD_TEMPLATE = """
|
||||
@abstractmethod
|
||||
def visit_{func_name}(self, {param}: {cls}) -> T: ...
|
||||
"""
|
||||
|
||||
CLASS_TEMPLATE = """
|
||||
@dataclass(frozen=True)
|
||||
class {cls}({base}):
|
||||
{body}
|
||||
|
||||
def accept(self, visitor: {base}.Visitor[T]) -> T:
|
||||
return visitor.visit_{func_name}(self)
|
||||
"""
|
||||
|
||||
SECTION_REGEX = re.compile(
|
||||
r"^###>\s*(?P<base>[^\n]*?)\s*\|\s*(?P<name>[^\n]*?)(\s*\|\s*(?P<param>[^\n]*?))?\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
IMPORTS_REGEX = re.compile(
|
||||
r"^###>\s*Imports\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def snake_case(text: str) -> str:
|
||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||
|
||||
|
||||
def make_visitor_method(cls: str, param: str):
|
||||
method: str = VISITOR_METHOD_TEMPLATE.format(
|
||||
func_name=snake_case(cls), param=param, cls=cls
|
||||
)
|
||||
return method.strip("\n")
|
||||
|
||||
|
||||
def make_class(name: str, cls: str, base: str):
|
||||
body: str = cls.split("\n", 1)[1]
|
||||
func_name: str = snake_case(name)
|
||||
cls_def: str = CLASS_TEMPLATE.format(
|
||||
cls=name,
|
||||
base=base,
|
||||
body=body,
|
||||
func_name=func_name,
|
||||
)
|
||||
return cls_def.strip("\n")
|
||||
|
||||
|
||||
def make_banner(text: str) -> str:
|
||||
middle: str = f"# {text} #"
|
||||
rule: str = "#" * len(middle)
|
||||
return "\n".join((rule, middle, rule))
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
visitor_methods: list[str] = []
|
||||
classes: list[str] = []
|
||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||
for cls in definitions:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
visitor_methods.append(make_visitor_method(name, param))
|
||||
classes.append(make_class(name, cls, base))
|
||||
|
||||
return SECTION_TEMPLATE.format(
|
||||
banner=make_banner(full_name),
|
||||
base=base,
|
||||
visitor_methods="\n\n".join(visitor_methods),
|
||||
classes="\n\n\n".join(classes),
|
||||
)
|
||||
|
||||
|
||||
def generate(definitions_path: Path, out_path: Path):
|
||||
root_dir: Path = Path(__file__).parent.parent
|
||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||
src: str = definitions_path.read_text()
|
||||
sections: list[str] = []
|
||||
|
||||
imports: str = ""
|
||||
if m := IMPORTS_REGEX.search(src):
|
||||
imports = m.group("body").strip("\n")
|
||||
|
||||
for section_m in SECTION_REGEX.finditer(src):
|
||||
full_name: str = section_m.group("name")
|
||||
base: str = section_m.group("base")
|
||||
param: str = section_m.group("param") or base.lower()
|
||||
body: str = section_m.group("body")
|
||||
sections.append(make_section(full_name, base, param, body))
|
||||
|
||||
result: str = TEMPLATE.format(
|
||||
header=HEADER.format(
|
||||
defs_path=rel_path,
|
||||
gen_path=Path(__file__).relative_to(root_dir),
|
||||
),
|
||||
imports=imports,
|
||||
sections="\n\n\n".join(sections),
|
||||
)
|
||||
out_path.write_text(result)
|
||||
|
||||
|
||||
def main():
|
||||
root: Path = Path(__file__).parent.parent
|
||||
defs_dir: Path = root / "gen"
|
||||
ast_dir: Path = root / "midas" / "ast"
|
||||
generate(defs_dir / "midas.py", ast_dir / "midas.py")
|
||||
generate(defs_dir / "python.py", ast_dir / "python.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
110
gen/midas.py
Normal file
110
gen/midas.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.token import Token
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class SimpleTypeStmt:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
base: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
|
||||
|
||||
class ComplexTypeStmt:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
type: TypeExpr
|
||||
operations: list[OpStmt]
|
||||
|
||||
|
||||
class OpStmt:
|
||||
name: Token
|
||||
operand: TypeExpr
|
||||
result: TypeExpr
|
||||
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
subject: Token
|
||||
type: TypeExpr
|
||||
condition: Expr
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
class SimpleTypeExpr:
|
||||
name: Token
|
||||
optional: bool
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class BinaryExpr:
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class UnaryExpr:
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
|
||||
class GetExpr:
|
||||
expr: Expr
|
||||
name: Token
|
||||
|
||||
|
||||
class VariableExpr:
|
||||
name: Token
|
||||
|
||||
|
||||
class GroupingExpr:
|
||||
expr: Expr
|
||||
|
||||
|
||||
class LiteralExpr:
|
||||
value: Any
|
||||
|
||||
|
||||
class WildcardExpr:
|
||||
token: Token
|
||||
|
||||
|
||||
class TemplateExpr:
|
||||
type: TypeExpr
|
||||
|
||||
|
||||
class TypeExpr:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
optional: bool
|
||||
|
||||
|
||||
###<
|
||||
148
gen/python.py
Normal file
148
gen/python.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821, F401]
|
||||
|
||||
###> Imports
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> MidasType | Type annotations | node
|
||||
class BaseType:
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
|
||||
class FrameColumn:
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
|
||||
class FrameType:
|
||||
columns: list[FrameColumn]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class ExpressionStmt:
|
||||
expr: Expr
|
||||
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
|
||||
class TypeAssign:
|
||||
name: str
|
||||
type: MidasType
|
||||
|
||||
|
||||
class AssignStmt:
|
||||
targets: list[Expr]
|
||||
value: Expr
|
||||
|
||||
|
||||
class ReturnStmt:
|
||||
value: Optional[Expr]
|
||||
|
||||
|
||||
class IfStmt:
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
class BinaryExpr:
|
||||
left: Expr
|
||||
operator: ast.operator
|
||||
right: Expr
|
||||
|
||||
|
||||
class CompareExpr:
|
||||
left: Expr
|
||||
operator: ast.cmpop
|
||||
right: Expr
|
||||
|
||||
|
||||
class UnaryExpr:
|
||||
operator: ast.unaryop
|
||||
right: Expr
|
||||
|
||||
|
||||
class CallExpr:
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
|
||||
class GetExpr:
|
||||
object: Expr
|
||||
name: str
|
||||
|
||||
|
||||
class LiteralExpr:
|
||||
value: Any
|
||||
|
||||
|
||||
class VariableExpr:
|
||||
name: str
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
left: Expr
|
||||
operator: ast.boolop
|
||||
right: Expr
|
||||
|
||||
|
||||
class SetExpr:
|
||||
object: Expr
|
||||
name: str
|
||||
value: Expr
|
||||
|
||||
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
|
||||
|
||||
class TernaryExpr:
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
|
||||
###<
|
||||
@@ -1,102 +0,0 @@
|
||||
from lexer.base import Lexer
|
||||
from lexer.keyword import ANNOTATION_KEYWORDS
|
||||
from lexer.token import TokenType
|
||||
|
||||
|
||||
class AnnotationLexer(Lexer):
|
||||
def scan_token(self) -> None:
|
||||
char: str = self.advance()
|
||||
match char:
|
||||
case "(":
|
||||
self.add_token(TokenType.LEFT_PAREN)
|
||||
case ")":
|
||||
self.add_token(TokenType.RIGHT_PAREN)
|
||||
case "[":
|
||||
self.add_token(TokenType.LEFT_BRACKET)
|
||||
case "]":
|
||||
self.add_token(TokenType.RIGHT_BRACKET)
|
||||
case "<":
|
||||
self.add_token(
|
||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
||||
)
|
||||
case ">":
|
||||
self.add_token(
|
||||
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
|
||||
)
|
||||
case "=":
|
||||
self.add_token(
|
||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
||||
)
|
||||
case "!":
|
||||
if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
else:
|
||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
||||
case ":":
|
||||
self.add_token(TokenType.COLON)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_":
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "#":
|
||||
self.scan_comment()
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
# Consume all whitespace characters until EOL or EOF
|
||||
while (
|
||||
self.peek().isspace()
|
||||
and self.peek() != "\n"
|
||||
and not self.is_at_end()
|
||||
):
|
||||
self.advance()
|
||||
self.add_token(TokenType.WHITESPACE)
|
||||
case _:
|
||||
if char.isdigit():
|
||||
self.scan_number()
|
||||
elif char.isalpha():
|
||||
self.scan_identifier()
|
||||
else:
|
||||
self.error("Unexpected character")
|
||||
return None
|
||||
|
||||
def scan_number(self):
|
||||
"""Scan the rest of number and add it as a token
|
||||
|
||||
This method handles both simple integers and floats. Scientific notation
|
||||
and base prefixes (0x, 0b, 0o) are not supported
|
||||
"""
|
||||
while self.peek().isdigit():
|
||||
self.advance()
|
||||
|
||||
if self.peek() == "." and self.peek_next().isdigit():
|
||||
self.advance()
|
||||
while self.peek().isdigit():
|
||||
self.advance()
|
||||
|
||||
value: float = float(self.source[self.start : self.idx])
|
||||
self.add_token(TokenType.NUMBER, value)
|
||||
|
||||
def scan_identifier(self):
|
||||
"""Scan the rest of an identifier and add it as a token
|
||||
|
||||
An identifier starts with a letter, followed by any number of
|
||||
alphanumerical characters or underscores
|
||||
"""
|
||||
while self.peek().isalnum() or self.peek() == "_":
|
||||
self.advance()
|
||||
|
||||
lexeme: str = self.source[self.start : self.idx]
|
||||
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
self.add_token(token_type)
|
||||
|
||||
def scan_comment(self):
|
||||
"""Scan the rest of a comment and add it as a token
|
||||
|
||||
A comment starts with a `#` character and ends at the EOL/EOF
|
||||
"""
|
||||
while self.peek() != "\n" and not self.is_at_end():
|
||||
self.advance()
|
||||
self.add_token(TokenType.COMMENT)
|
||||
@@ -1,16 +0,0 @@
|
||||
from lexer.token import TokenType
|
||||
|
||||
ANNOTATION_KEYWORDS: dict[str, TokenType] = {
|
||||
"True": TokenType.TRUE,
|
||||
"False": TokenType.FALSE,
|
||||
"None": TokenType.NONE,
|
||||
}
|
||||
|
||||
MIDAS_KEYWORDS: dict[str, TokenType] = {
|
||||
"type": TokenType.TYPE,
|
||||
"op": TokenType.OP,
|
||||
"constraint": TokenType.CONSTRAINT,
|
||||
"true": TokenType.TRUE,
|
||||
"false": TokenType.FALSE,
|
||||
"none": TokenType.NONE,
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
from lexer.position import Position
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
# Punctuation
|
||||
LEFT_PAREN = auto()
|
||||
RIGHT_PAREN = auto()
|
||||
LEFT_BRACKET = auto()
|
||||
RIGHT_BRACKET = auto()
|
||||
LEFT_BRACE = auto()
|
||||
RIGHT_BRACE = auto()
|
||||
COLON = auto()
|
||||
COMMA = auto()
|
||||
UNDERSCORE = auto()
|
||||
|
||||
# Operators
|
||||
PLUS = auto()
|
||||
MINUS = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
LESS_EQUAL = auto()
|
||||
EQUAL = auto()
|
||||
EQUAL_EQUAL = auto()
|
||||
BANG_EQUAL = auto()
|
||||
|
||||
# Literals
|
||||
IDENTIFIER = auto()
|
||||
NUMBER = auto()
|
||||
TRUE = auto()
|
||||
FALSE = auto()
|
||||
NONE = auto()
|
||||
|
||||
# Keywords
|
||||
TYPE = auto()
|
||||
OP = auto()
|
||||
CONSTRAINT = auto()
|
||||
|
||||
# Misc
|
||||
COMMENT = auto()
|
||||
WHITESPACE = auto()
|
||||
EOF = auto()
|
||||
NEWLINE = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Token:
|
||||
"""A scanned token"""
|
||||
|
||||
type: TokenType
|
||||
lexeme: str
|
||||
value: Any
|
||||
position: Position
|
||||
37
midas/ast/location.py
Normal file
37
midas/ast/location.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Protocol
|
||||
|
||||
|
||||
class HasLocation(Protocol):
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Location:
|
||||
lineno: int
|
||||
col_offset: int
|
||||
end_lineno: Optional[int]
|
||||
end_col_offset: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_ast(obj: HasLocation) -> Location:
|
||||
return Location(
|
||||
lineno=obj.lineno,
|
||||
col_offset=obj.col_offset,
|
||||
end_lineno=obj.end_lineno,
|
||||
end_col_offset=obj.end_col_offset,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def span(start: Location, end: Location) -> Location:
|
||||
return Location(
|
||||
lineno=start.lineno,
|
||||
col_offset=start.col_offset,
|
||||
end_lineno=end.lineno,
|
||||
end_col_offset=end.end_col_offset,
|
||||
)
|
||||
251
midas/ast/midas.py
Normal file
251
midas/ast/midas.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/midas.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimpleTypeStmt(Stmt):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
base: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_simple_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexTypeStmt(Stmt):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
type: TypeExpr
|
||||
operations: list[OpStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_extend_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
name: Token
|
||||
operand: TypeExpr
|
||||
result: TypeExpr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
subject: Token
|
||||
type: TypeExpr
|
||||
condition: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_predicate_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_grouping_expr(self, expr: GroupingExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimpleTypeExpr(Expr):
|
||||
name: Token
|
||||
optional: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_simple_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
left: Expr
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_binary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnaryExpr(Expr):
|
||||
operator: Token
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
expr: Expr
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_get_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableExpr(Expr):
|
||||
name: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_variable_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GroupingExpr(Expr):
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_grouping_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WildcardExpr(Expr):
|
||||
token: Token
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateExpr(Expr):
|
||||
type: TypeExpr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_template_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
optional: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
610
midas/ast/printer.py
Normal file
610
midas/ast/printer.py
Normal file
@@ -0,0 +1,610 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class Expr(Protocol):
|
||||
def accept(self, printer: AstPrinter) -> None: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Expr)
|
||||
|
||||
|
||||
class AstPrinter(Generic[T]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: T):
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
@contextmanager
|
||||
def _child_level(self, single: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if single else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str, *, last: bool = False):
|
||||
if last:
|
||||
self._mark_last()
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[T], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
self._write_line(label)
|
||||
with self._child_level(single=True):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
# Statements
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||
self._write_line("SimpleTypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("template", stmt.template)
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
stmt.base.accept(self)
|
||||
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||
self._write_line("ComplexTypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("template", stmt.template)
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(stmt.properties):
|
||||
self._idx = i
|
||||
if i == len(stmt.properties) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_line("operations", last=True)
|
||||
with self._child_level():
|
||||
for i, op in enumerate(stmt.operations):
|
||||
self._idx = i
|
||||
if i == len(stmt.operations) - 1:
|
||||
self._mark_last()
|
||||
op.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self._write_line("OpStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
|
||||
self._write_line("operand")
|
||||
with self._child_level(single=True):
|
||||
stmt.operand.accept(self)
|
||||
|
||||
self._write_line("result", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
self._write_line("PredicateStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_line("condition", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.condition.accept(self)
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||
self._write_line("SimpleTypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line(f"optional: {expr.optional}", last=True)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.lexeme}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr")
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
self._write_line("GroupingExpr")
|
||||
with self._child_level():
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"value: {expr.value}", last=True)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||
self._write_line("TemplateExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_optional_child("template", expr.template)
|
||||
self._write_line(f"optional: {expr.optional}", last=True)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt):
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
|
||||
if stmt.constraint is not None:
|
||||
res += " where " + stmt.constraint.accept(self)
|
||||
return self.indented(res)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for prop in stmt.properties:
|
||||
res += prop.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
if stmt.constraint is not None:
|
||||
res += " where " + stmt.constraint.accept(self)
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
res: str = self.indented(f"extend {stmt.type.accept(self)}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for op in stmt.operations:
|
||||
res += op.accept(self)
|
||||
self.level -= 1
|
||||
res += "\n" + self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||
operand: str = stmt.operand.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
subject: str = stmt.subject.lexeme
|
||||
type: str = stmt.type.accept(self)
|
||||
condition: str = stmt.condition.accept(self)
|
||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{left} {operator} {right}"
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr):
|
||||
operator: str = expr.operator.lexeme
|
||||
right: str = expr.right.accept(self)
|
||||
return f"{operator}{right}"
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
name: str = expr.name.lexeme
|
||||
return f"{expr_}.{name}"
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr):
|
||||
return expr.name.lexeme
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr):
|
||||
expr_: str = expr.expr.accept(self)
|
||||
return f"({expr_})"
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||
return str(expr.value)
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr):
|
||||
return f"[{expr.type.accept(self)}]"
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
template: str = expr.template.accept(self) if expr.template is not None else ""
|
||||
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self._write_line("BaseType")
|
||||
with self._child_level():
|
||||
self._write_line(f"base: {node.base}")
|
||||
self._write_optional_child("param", node.param, last=True)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
node.type.accept(self)
|
||||
self._write_line(f"constraint: {ast.unparse(node.constraint)}", last=True)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self._write_line("FrameColumn")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {node.name}")
|
||||
self._write_optional_child("type", node.type, last=True)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self._write_line("FrameType")
|
||||
with self._child_level():
|
||||
self._write_line("columns", last=True)
|
||||
with self._child_level():
|
||||
for i, col in enumerate(node.columns):
|
||||
self._idx = i
|
||||
if i == len(node.columns) - 1:
|
||||
self._mark_last()
|
||||
col.accept(self)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self._write_line("Function")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
|
||||
self._write_line("posonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.posonlyargs):
|
||||
self._idx = i
|
||||
if i == len(stmt.posonlyargs) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.args):
|
||||
self._idx = i
|
||||
if i == len(stmt.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_line("kwonlyargs")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(stmt.kwonlyargs):
|
||||
self._idx = i
|
||||
if i == len(stmt.kwonlyargs) - 1:
|
||||
self._mark_last()
|
||||
self._print_argument(arg)
|
||||
|
||||
self._write_optional_child("returns", stmt.returns)
|
||||
self._write_line("body", last=True)
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _print_argument(self, arg: p.Function.Argument) -> None:
|
||||
self._write_line("FunctionArgument")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {arg.name}")
|
||||
self._write_optional_child("type", arg.type, last=True)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self._write_line("TypeAssign")
|
||||
with self._child_level():
|
||||
self._write_line(f"name: {stmt.name}")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self._write_line("AssignStmt")
|
||||
with self._child_level():
|
||||
self._write_line("targets")
|
||||
with self._child_level():
|
||||
for i, target in enumerate(stmt.targets):
|
||||
self._idx = i
|
||||
if i == len(stmt.targets) - 1:
|
||||
self._mark_last()
|
||||
target.accept(self)
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self._write_line("ReturnStmt")
|
||||
with self._child_level():
|
||||
self._write_optional_child("value", stmt.value, last=True)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self._write_line("IfStmt")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
stmt.test.accept(self)
|
||||
self._write_line("body")
|
||||
with self._child_level():
|
||||
for i, body_stmt in enumerate(stmt.body):
|
||||
self._idx = i
|
||||
if i == len(stmt.body) - 1:
|
||||
self._mark_last()
|
||||
body_stmt.accept(self)
|
||||
self._write_line("orelse", last=True)
|
||||
with self._child_level():
|
||||
for i, else_stmt in enumerate(stmt.orelse):
|
||||
self._idx = i
|
||||
if i == len(stmt.orelse) - 1:
|
||||
self._mark_last()
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self._write_line("BinaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self._write_line("CompareExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self._write_line("UnaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self._write_line("CallExpr")
|
||||
with self._child_level():
|
||||
self._write_line("callee")
|
||||
with self._child_level(single=True):
|
||||
expr.callee.accept(self)
|
||||
|
||||
self._write_line("arguments")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(expr.arguments):
|
||||
self._idx = i
|
||||
if i == len(expr.arguments) - 1:
|
||||
self._mark_last()
|
||||
arg.accept(self)
|
||||
|
||||
self._write_line("keywords", last=True)
|
||||
with self._child_level():
|
||||
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||
self._idx = i
|
||||
if i == len(expr.keywords) - 1:
|
||||
self._mark_last()
|
||||
self._write_line(name)
|
||||
with self._child_level(single=True):
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self._write_line("GetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}", last=True)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"value: {expr.value}")
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
self._write_line("VariableExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"name: {expr.name}")
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
self._write_line("left")
|
||||
with self._child_level(single=True):
|
||||
expr.left.accept(self)
|
||||
|
||||
self._write_line(f"operator: {expr.operator.__class__.__name__}")
|
||||
|
||||
self._write_line("right", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self._write_line("SetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}")
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.value.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
self._write_line("expr", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self._write_line("TernaryExpr")
|
||||
with self._child_level():
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
327
midas/ast/python.py
Normal file
327
midas/ast/python.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
This file was generated by a script. Any manual changes might be overwritten.
|
||||
Please modify gen/python.py instead and run gen/gen.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MidasType(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_base_type(self, node: BaseType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, node: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_column(self, node: FrameColumn) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_frame_type(self, node: FrameType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseType(MidasType):
|
||||
base: str
|
||||
param: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_base_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(MidasType):
|
||||
type: MidasType
|
||||
constraint: ast.expr
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameColumn(MidasType):
|
||||
name: Optional[str]
|
||||
type: Optional[MidasType]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_column(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrameType(MidasType):
|
||||
columns: list[FrameColumn]
|
||||
|
||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||
return visitor.visit_frame_type(self)
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Stmt(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function(self, stmt: Function) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_assign_stmt(self, stmt: AssignStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExpressionStmt(Stmt):
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_expression_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Function(Stmt):
|
||||
name: str
|
||||
posonlyargs: list[Argument]
|
||||
args: list[Argument]
|
||||
sink: Optional[Argument]
|
||||
kwonlyargs: list[Argument]
|
||||
kw_sink: Optional[Argument]
|
||||
returns: Optional[MidasType]
|
||||
body: list[Stmt]
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: str
|
||||
type: Optional[MidasType]
|
||||
default: Optional[Expr]
|
||||
|
||||
@property
|
||||
def all_args(self) -> list[Argument]:
|
||||
return self.posonlyargs + self.args + self.kwonlyargs
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_function(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeAssign(Stmt):
|
||||
name: str
|
||||
type: MidasType
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_assign(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssignStmt(Stmt):
|
||||
targets: list[Expr]
|
||||
value: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_assign_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReturnStmt(Stmt):
|
||||
value: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_return_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IfStmt(Stmt):
|
||||
test: Expr
|
||||
body: list[Stmt]
|
||||
orelse: list[Stmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_if_stmt(self)
|
||||
|
||||
|
||||
###############
|
||||
# Expressions #
|
||||
###############
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Expr(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_compare_expr(self, expr: CompareExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
left: Expr
|
||||
operator: ast.operator
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_binary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompareExpr(Expr):
|
||||
left: Expr
|
||||
operator: ast.cmpop
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_compare_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnaryExpr(Expr):
|
||||
operator: ast.unaryop
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_unary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallExpr(Expr):
|
||||
callee: Expr
|
||||
arguments: list[Expr]
|
||||
keywords: dict[str, Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_call_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetExpr(Expr):
|
||||
object: Expr
|
||||
name: str
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_get_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LiteralExpr(Expr):
|
||||
value: Any
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_literal_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableExpr(Expr):
|
||||
name: str
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_variable_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
left: Expr
|
||||
operator: ast.boolop
|
||||
right: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetExpr(Expr):
|
||||
object: Expr
|
||||
name: str
|
||||
value: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_set_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CastExpr(Expr):
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_cast_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TernaryExpr(Expr):
|
||||
test: Expr
|
||||
if_true: Expr
|
||||
if_false: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_ternary_expr(self)
|
||||
546
midas/checker/checker.py
Normal file
546
midas/checker/checker.py
Normal file
@@ -0,0 +1,546 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.types import BaseType, Function, SimpleType, Type, UnitType, UnknownType
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
|
||||
def __init__(self, locals: dict[p.Expr, int], file_path: Path):
|
||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
||||
self.file_path: Path = file_path
|
||||
self.ctx: MidasResolver = MidasResolver()
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = locals
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
|
||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=self.file_path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def evaluate(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
return expr.accept(self)
|
||||
|
||||
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.warning(block[i + 1].location, "Unreachable statement")
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
|
||||
Returns:
|
||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
||||
"""
|
||||
self.diagnostics = []
|
||||
for stmt in statements:
|
||||
stmt.accept(self)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
return self.diagnostics
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
|
||||
"""Parse a Midas import statement
|
||||
|
||||
The statement should be written as `midas.using("path/to/types.midas")`
|
||||
|
||||
Args:
|
||||
expr (p.CallExpr): the import call expression
|
||||
|
||||
Returns:
|
||||
Optional[Path]: the path to the imported file, or None if the expression is malformed
|
||||
"""
|
||||
match expr:
|
||||
case p.CallExpr(
|
||||
callee=p.GetExpr(
|
||||
object=p.VariableExpr(name="midas"),
|
||||
name="using",
|
||||
),
|
||||
arguments=[
|
||||
p.LiteralExpr(value=path),
|
||||
],
|
||||
):
|
||||
return Path(path)
|
||||
return None
|
||||
|
||||
def import_midas(self, path: Path) -> None:
|
||||
"""Import Midas definitions from a path
|
||||
|
||||
Args:
|
||||
path (Path): the import path
|
||||
"""
|
||||
self.logger.debug(f"Importing type definitions from {path}")
|
||||
path = (self.file_path.parent / path).resolve()
|
||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.ctx.resolve(stmts)
|
||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.evaluate(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.evaluate_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: set[Type] = set(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = list(return_types)[0]
|
||||
elif len(return_types) > 1:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value: Type = self.evaluate(stmt.value)
|
||||
for target in stmt.targets:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
||||
continue
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value)
|
||||
else:
|
||||
# TODO: implement real comparison method
|
||||
if var_type != value:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Cannot assign {value} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = stmt.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.evaluate_block(stmt.body, env)
|
||||
else_returned: bool = self.evaluate_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.evaluate(expr.left)
|
||||
right: Type = self.evaluate(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.evaluate(expr.left)
|
||||
right: Type = self.evaluate(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
if path := self.parse_midas_import(expr):
|
||||
self.import_midas(path)
|
||||
return UnknownType()
|
||||
callee: Type = self.evaluate(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if arg.type != arg.argument.type:
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.ctx.get_type("bool")
|
||||
case int():
|
||||
return self.ctx.get_type("int")
|
||||
case float():
|
||||
return self.ctx.get_type("float")
|
||||
case str():
|
||||
return self.ctx.get_type("str")
|
||||
case _:
|
||||
self.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return expr.type.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = expr.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if true_type != false_type:
|
||||
self.error(expr.location, f"Type mismatch in ternary if branches: true={true_type} != false={false_type}")
|
||||
return UnknownType()
|
||||
return true_type
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
return self.ctx.get_type(node.base)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
call (p.CallExpr): the call expression
|
||||
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.evaluate(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
33
midas/checker/diagnostic.py
Normal file
33
midas/checker/diagnostic.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
|
||||
|
||||
class DiagnosticType(StrEnum):
|
||||
ERROR = "Error"
|
||||
WARNING = "Warning"
|
||||
INFO = "Info"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
file_path: Path
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
message: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
self.location.end_lineno is not None
|
||||
and self.location.end_col_offset is not None
|
||||
):
|
||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||
loc: str = (
|
||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
||||
)
|
||||
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
||||
142
midas/checker/environment.py
Normal file
142
midas/checker/environment.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.types import Type
|
||||
|
||||
|
||||
class Environment:
|
||||
"""
|
||||
A scoped environment in which variables are defined
|
||||
|
||||
Each environment can inherit from a parent/enclosing environment.
|
||||
"""
|
||||
|
||||
def __init__(self, enclosing: Optional[Environment] = None) -> None:
|
||||
self.enclosing: Optional[Environment] = enclosing
|
||||
self.values: dict[str, Type] = {}
|
||||
self.return_types: list[Type] = []
|
||||
|
||||
self._children: list[Environment] = []
|
||||
if enclosing is not None:
|
||||
enclosing._children.append(self)
|
||||
|
||||
def define(self, name: str, value: Type) -> None:
|
||||
"""Define a variable in this environment
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
value (Type): the value
|
||||
"""
|
||||
self.values[name] = value
|
||||
|
||||
def get(self, name: str) -> Optional[Type]:
|
||||
"""Get a variable in the closest environment which has a definition for it
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the value of the variable, or None if it was not found
|
||||
"""
|
||||
if name in self.values:
|
||||
return self.values[name]
|
||||
if self.enclosing is not None:
|
||||
return self.enclosing.get(name)
|
||||
# raise NameError(f"Undefined variable '{name}'")
|
||||
return None
|
||||
|
||||
def assign(self, name: str, value: Type) -> bool:
|
||||
"""Assign a new value to a variable in the environment it was defined in
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
value (Type): the new value
|
||||
|
||||
Returns:
|
||||
bool: True if the variable was assigned in this environment or an ancestor, False otherwise
|
||||
"""
|
||||
if name not in self.values:
|
||||
if self.enclosing is None:
|
||||
return False
|
||||
if self.enclosing.assign(name, value):
|
||||
return True
|
||||
self.values[name] = value
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
"""Clear all definitions in this environment"""
|
||||
self.values = {}
|
||||
|
||||
def get_at(self, distance: int, name: str) -> Optional[Type]:
|
||||
"""Get the value of a variable at a given distance
|
||||
|
||||
A distance of 0 looks up in this environment, 1 in the parent environment, etc.
|
||||
This methods expects `distance` to be valid. An error will be raised if
|
||||
the stack does not extend far enough to reach `distance`
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
name (str): the name of the variable
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the value at the given distance, or None if it is not defined in that environment
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
return self.ancestor(distance).values.get(name)
|
||||
|
||||
def assign_at(self, distance: int, name: str, value: Type) -> None:
|
||||
"""Assign a new value to a variable at a given distance
|
||||
|
||||
A distance of 0 assigns in this environment, 1 in the parent environment, etc.
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
name (str): the name of the variable
|
||||
value (Type): the new value
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
self.ancestor(distance).values[name] = value
|
||||
|
||||
def ancestor(self, distance: int) -> Environment:
|
||||
"""Get the ancestor at a given distance
|
||||
|
||||
A distance of 0 references this environment, 1 the parent environment, etc.
|
||||
|
||||
Args:
|
||||
distance (int): the scope distance
|
||||
|
||||
Returns:
|
||||
Environment: the environment
|
||||
|
||||
Raises:
|
||||
AssertionError: if the stack does not extend far enough to reach `distance`
|
||||
"""
|
||||
env: Environment = self
|
||||
for _ in range(distance):
|
||||
assert env.enclosing is not None
|
||||
env = env.enclosing
|
||||
return env
|
||||
|
||||
def flat_dict(self) -> dict[str, Type]:
|
||||
"""Get the current environment including definitions in its ancestor as a flat dictionary
|
||||
|
||||
This method recursively combines this environment definitions with its ancestor's
|
||||
|
||||
Returns:
|
||||
dict: the combined environment
|
||||
"""
|
||||
if self.enclosing is None:
|
||||
return self.values
|
||||
return self.enclosing.flat_dict() | self.values
|
||||
|
||||
def dump(self) -> dict:
|
||||
return {
|
||||
"values": self.values,
|
||||
"return_types": self.return_types,
|
||||
"children": [child.dump() for child in self._children],
|
||||
}
|
||||
31
midas/checker/operators.py
Normal file
31
midas/checker/operators.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import ast
|
||||
from typing import Type
|
||||
|
||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "__add__",
|
||||
ast.Sub: "__sub__",
|
||||
ast.Mult: "__mul__",
|
||||
ast.MatMult: "__matmul__",
|
||||
ast.Div: "__truediv__",
|
||||
ast.Mod: "__mod__",
|
||||
ast.Pow: "__pow__",
|
||||
ast.LShift: "__lshift__",
|
||||
ast.RShift: "__rshift__",
|
||||
ast.BitOr: "__or__",
|
||||
ast.BitXor: "__xor__",
|
||||
ast.BitAnd: "__and__",
|
||||
ast.FloorDiv: "__floordiv__",
|
||||
}
|
||||
|
||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "__eq__",
|
||||
# ast.NotEq: "__noteq__",
|
||||
ast.Lt: "__lt__",
|
||||
ast.LtE: "__le__",
|
||||
ast.Gt: "__gt__",
|
||||
ast.GtE: "__ge__",
|
||||
# ast.Is: "__is__",
|
||||
# ast.IsNot: "__isnot__",
|
||||
# ast.In: "__in__",
|
||||
# ast.NotIn: "__notin__",
|
||||
}
|
||||
42
midas/checker/types.py
Normal file
42
midas/checker/types.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class BaseType:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class SimpleType:
|
||||
name: str
|
||||
base: BaseType | SimpleType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnknownType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnitType:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
name: str
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||
0
midas/cli/__init__.py
Normal file
0
midas/cli/__init__.py
Normal file
57
midas/cli/highlight.css
Normal file
57
midas/cli/highlight.css
Normal file
@@ -0,0 +1,57 @@
|
||||
html,
|
||||
body {
|
||||
margin: 0;
|
||||
font-size: 14pt;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
#code {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: monospace;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.line {
|
||||
display: flex;
|
||||
|
||||
&:nth-child(odd) {
|
||||
background-color: rgb(247, 247, 247);
|
||||
}
|
||||
|
||||
.no {
|
||||
width: 4em;
|
||||
text-align: right;
|
||||
padding: 0.2em 0.4em;
|
||||
border-right: solid black 1px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.txt {
|
||||
flex-grow: 1;
|
||||
padding: 0.2em 0.8em;
|
||||
}
|
||||
}
|
||||
|
||||
span {
|
||||
--col: transparent;
|
||||
--opacity: 0.1;
|
||||
--border: 0px;
|
||||
background-color: rgba(var(--col), var(--opacity));
|
||||
outline: solid rgb(var(--col)) var(--border);
|
||||
outline-offset: 2px;
|
||||
border-radius: 2px;
|
||||
|
||||
&:hover:not(:has(*:hover)) {
|
||||
--opacity: 0.8;
|
||||
--border: 2px;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
&.keyword {
|
||||
color: rgb(211, 72, 9);
|
||||
}
|
||||
}
|
||||
300
midas/cli/highlighter.py
Normal file
300
midas/cli/highlighter.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
|
||||
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||
|
||||
|
||||
class Highlightable(Protocol, Generic[H]):
|
||||
def accept(self, visitor: H): ...
|
||||
|
||||
|
||||
class Locatable(Protocol):
|
||||
@property
|
||||
@abstractmethod
|
||||
def location(self) -> Optional[Location]: ...
|
||||
|
||||
|
||||
class Highlighter(ABC):
|
||||
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||
EXTRA_CSS_PATH: Optional[Path] = None
|
||||
|
||||
def __init__(self, source: str) -> None:
|
||||
self.source: str = source
|
||||
self.lines: list[str] = self.source.splitlines()
|
||||
self.openings: dict[tuple[int, int], list[str]] = {}
|
||||
self.closings: dict[tuple[int, int], list[str]] = {}
|
||||
|
||||
def format_css(self, path: Path) -> list[str]:
|
||||
css: str = path.read_text()
|
||||
css = "\n".join((" " + line).rstrip() for line in css.splitlines())
|
||||
return [
|
||||
" <style>",
|
||||
css,
|
||||
" </style>",
|
||||
]
|
||||
|
||||
def dump(self, buf: TextIO):
|
||||
base_css: list[str] = self.format_css(self.BASE_CSS_PATH)
|
||||
extra_css: list[str] = (
|
||||
self.format_css(self.EXTRA_CSS_PATH)
|
||||
if self.EXTRA_CSS_PATH is not None
|
||||
else []
|
||||
)
|
||||
lines: list[str] = [
|
||||
"<!DOCTYPE html>",
|
||||
'<html lang="en">',
|
||||
"<head>",
|
||||
' <meta charset="UTF-8">',
|
||||
' <meta name="viewport" content="width=device-width, initial-scale=1.0">',
|
||||
" <title>Highlighted file</title>",
|
||||
*base_css,
|
||||
*extra_css,
|
||||
"</head>",
|
||||
"<body>",
|
||||
' <div id="code">',
|
||||
]
|
||||
for l, line in enumerate(self.lines):
|
||||
lineno: int = l + 1
|
||||
line_buf: str = (
|
||||
f'<div class="line" id="l{lineno}"><div class="no">{lineno}</div><div class="txt">'
|
||||
)
|
||||
for c, char in enumerate(line):
|
||||
pos: tuple[int, int] = (lineno, c)
|
||||
closings: list[str] = self.closings.get(pos, [])
|
||||
openings: list[str] = self.openings.get(pos, [])
|
||||
line_buf += "".join(closings + openings)
|
||||
line_buf += char
|
||||
line_buf += "".join(self.closings.get((lineno, len(line)), []))
|
||||
line_buf += "</div></div>"
|
||||
lines.append(" " + line_buf)
|
||||
lines.extend(
|
||||
[
|
||||
" </div>",
|
||||
"</body>",
|
||||
"</html>",
|
||||
]
|
||||
)
|
||||
|
||||
buf.write("\n".join(lines))
|
||||
|
||||
def wrap(self, node: Locatable, cls: str, message: Optional[str] = None):
|
||||
if node.location is None:
|
||||
return
|
||||
if node.location.end_lineno is None or node.location.end_col_offset is None:
|
||||
return
|
||||
start_pos: tuple[int, int] = (node.location.lineno, node.location.col_offset)
|
||||
end_pos: tuple[int, int] = (
|
||||
node.location.end_lineno,
|
||||
node.location.end_col_offset,
|
||||
)
|
||||
opening: str = f'<span class="{cls}" title="{cls}">'
|
||||
closing: str = "</span>"
|
||||
if message is not None:
|
||||
opening = f'<span class="with-msg">{opening}'
|
||||
closing = f'{closing}<span class="message">{message}</span></span>'
|
||||
|
||||
self.openings.setdefault(start_pos, []).append(opening)
|
||||
self.closings.setdefault(end_pos, []).insert(0, closing)
|
||||
if start_pos[0] != end_pos[0]:
|
||||
for l in range(start_pos[0], end_pos[0]):
|
||||
c: int = len(self.lines[l - 1])
|
||||
self.closings.setdefault((l, c), []).insert(0, closing)
|
||||
self.openings.setdefault((l + 1, 0), []).append(opening)
|
||||
|
||||
|
||||
class PythonHighlighter(
|
||||
Highlighter,
|
||||
p.MidasType.Visitor[None],
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[None],
|
||||
):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_python.css"
|
||||
|
||||
def highlight(self, node: Highlightable[PythonHighlighter]):
|
||||
node.accept(self)
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> None:
|
||||
self.wrap(node, "base-type")
|
||||
if node.param is not None:
|
||||
self.wrap(node.param, "param")
|
||||
node.param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||
self.wrap(node, "constraint-type")
|
||||
node.type.accept(self)
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> None:
|
||||
self.wrap(node, "frame-column")
|
||||
if node.type is not None:
|
||||
node.type.accept(self)
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> None:
|
||||
self.wrap(node, "frame-type")
|
||||
for column in node.columns:
|
||||
column.accept(self)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
self.wrap(stmt, "function")
|
||||
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
|
||||
self._highlight_function_argument(arg)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
|
||||
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
|
||||
self.wrap(arg, "argument")
|
||||
if arg.type is not None:
|
||||
arg.type.accept(self)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
for target in stmt.targets:
|
||||
target.accept(self)
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
self.wrap(stmt, "return")
|
||||
if stmt.value is not None:
|
||||
stmt.value.accept(self)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
self.wrap(stmt, "if")
|
||||
stmt.test.accept(self)
|
||||
for body_stmt in stmt.body:
|
||||
body_stmt.accept(self)
|
||||
for else_stmt in stmt.orelse:
|
||||
else_stmt.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self.wrap(expr, "call")
|
||||
expr.callee.accept(self)
|
||||
for arg in expr.arguments:
|
||||
arg.accept(self)
|
||||
for arg in expr.keywords.values():
|
||||
arg.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||
|
||||
|
||||
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
||||
|
||||
def highlight(self, node: Highlightable[MidasHighlighter]):
|
||||
node.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||
self.wrap(stmt, "simple-type")
|
||||
if stmt.template is not None:
|
||||
stmt.template.accept(self)
|
||||
stmt.base.accept(self)
|
||||
if stmt.constraint is not None:
|
||||
self.wrap(stmt.constraint, "constraint")
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
|
||||
self.wrap(stmt, "complex-type")
|
||||
if stmt.template is not None:
|
||||
stmt.template.accept(self)
|
||||
for prop in stmt.properties:
|
||||
prop.accept(self)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
|
||||
self.wrap(stmt, "property")
|
||||
stmt.type.accept(self)
|
||||
if stmt.constraint is not None:
|
||||
self.wrap(stmt.constraint, "constraint")
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self.wrap(stmt, "extend")
|
||||
stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
op.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self.wrap(stmt, "op")
|
||||
stmt.operand.accept(self)
|
||||
stmt.result.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
self.wrap(stmt, "predicate")
|
||||
stmt.type.accept(self)
|
||||
stmt.condition.accept(self)
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
|
||||
self.wrap(expr, "simple-type-expr")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.wrap(expr, "logical-expr")
|
||||
expr.left.accept(self)
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||
self.wrap(expr, "binary-expr")
|
||||
expr.left.accept(self)
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
self.wrap(expr, "unary-expr")
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.wrap(expr, "get-expr")
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||
self.wrap(expr, "variable")
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||
self.wrap(expr, "template")
|
||||
expr.type.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr) -> None:
|
||||
self.wrap(expr, "type")
|
||||
if expr.template is not None:
|
||||
expr.template.accept(self)
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||
|
||||
def highlight(self, diagnostics: list[Diagnostic]):
|
||||
for diagnostic in diagnostics:
|
||||
self.wrap(diagnostic, str(diagnostic.type).lower(), diagnostic.message)
|
||||
39
midas/cli/hl_diagnostic.css
Normal file
39
midas/cli/hl_diagnostic.css
Normal file
@@ -0,0 +1,39 @@
|
||||
span {
|
||||
--opacity: 0.4;
|
||||
|
||||
&.error {
|
||||
--col: 255, 0, 0;
|
||||
}
|
||||
&.warning {
|
||||
--col: 250, 160, 0;
|
||||
}
|
||||
&.info {
|
||||
--col: 150, 190, 250;
|
||||
}
|
||||
|
||||
&.with-msg {
|
||||
position: relative;
|
||||
|
||||
.message {
|
||||
display: none;
|
||||
}
|
||||
|
||||
&:hover:not(:has(.with-msg:hover)) {
|
||||
.message {
|
||||
display: inline-block;
|
||||
}
|
||||
}
|
||||
|
||||
.message {
|
||||
position: absolute;
|
||||
top: calc(100% + 0.2em);
|
||||
left: -.2em;
|
||||
background-color: black;
|
||||
color: white;
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: .2em;
|
||||
z-index: 10;
|
||||
width: 300%;
|
||||
}
|
||||
}
|
||||
}
|
||||
55
midas/cli/hl_midas.css
Normal file
55
midas/cli/hl_midas.css
Normal file
@@ -0,0 +1,55 @@
|
||||
span {
|
||||
&.comment {
|
||||
--col: 200, 200, 200;
|
||||
color: rgb(110, 110, 110);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
&.simple-type {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.complex-type {
|
||||
--col: 233, 206, 108;
|
||||
}
|
||||
|
||||
&.constraint {
|
||||
--col: 233, 108, 108;
|
||||
}
|
||||
|
||||
&.property {
|
||||
--col: 233, 108, 176;
|
||||
}
|
||||
|
||||
&.extend {
|
||||
--col: 108, 197, 233;
|
||||
}
|
||||
|
||||
&.op {
|
||||
--col: 108, 148, 233;
|
||||
}
|
||||
|
||||
&.predicate {
|
||||
--col: 193, 108, 233;
|
||||
}
|
||||
|
||||
&.simple-type-expr {
|
||||
--col: 150, 150, 150;
|
||||
}
|
||||
|
||||
&.logical-expr,
|
||||
&.binary-expr,
|
||||
&.unary-expr,
|
||||
&.get-expr {
|
||||
--col: 123, 215, 193;
|
||||
}
|
||||
|
||||
&.template {
|
||||
--col: 163, 117, 71;
|
||||
}
|
||||
|
||||
&.type {
|
||||
--col: 200, 200, 200;
|
||||
font-weight: bold;
|
||||
}
|
||||
}
|
||||
29
midas/cli/hl_python.css
Normal file
29
midas/cli/hl_python.css
Normal file
@@ -0,0 +1,29 @@
|
||||
span {
|
||||
&.base-type {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.param {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
|
||||
&.constraint-type {
|
||||
--col: 174, 200, 195;
|
||||
}
|
||||
|
||||
&.frame-column {
|
||||
--col: 216, 231, 81;
|
||||
}
|
||||
|
||||
&.frame-type {
|
||||
--col: 231, 46, 40;
|
||||
}
|
||||
|
||||
&.function {
|
||||
--col: 215, 103, 224;
|
||||
}
|
||||
|
||||
&.argument {
|
||||
--col: 103, 192, 224;
|
||||
}
|
||||
}
|
||||
180
midas/cli/main.py
Normal file
180
midas/cli/main.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO, get_args
|
||||
|
||||
import click
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.types import Type
|
||||
from midas.cli.highlighter import (
|
||||
DiagnosticsHighlighter,
|
||||
Highlighter,
|
||||
MidasHighlighter,
|
||||
PythonHighlighter,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.utils import UniversalJSONDumper
|
||||
|
||||
|
||||
@click.group()
|
||||
def midas():
|
||||
click.echo("Welcome to Midas!")
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-l", "--highlight", type=click.File("w"))
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def compile(highlight: Optional[TextIO], file: TextIO):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
source: str = file.read()
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
checker = Checker(resolver.locals, file_path=Path(file.name).resolve())
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
print(diagnostic)
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
)
|
||||
)
|
||||
if highlight is not None:
|
||||
highlighter = DiagnosticsHighlighter(source)
|
||||
highlighter.highlight(diagnostics)
|
||||
highlighter.dump(highlight)
|
||||
|
||||
|
||||
@midas.group()
|
||||
def utils():
|
||||
pass
|
||||
|
||||
|
||||
def dump_python_ast(tree: ast.Module) -> str:
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
printer = PythonAstPrinter()
|
||||
dump: str = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
return dump
|
||||
|
||||
|
||||
def dump_midas_ast(source: str, filename: str) -> str:
|
||||
lexer = MidasLexer(source, file=filename)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
if len(parser.errors) != 0:
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
raise RuntimeError("A parsing error occurred")
|
||||
printer = MidasAstPrinter()
|
||||
dump: str = ""
|
||||
for stmt in stmts:
|
||||
dump += printer.print(stmt)
|
||||
dump += "\n"
|
||||
return dump
|
||||
|
||||
|
||||
@utils.command()
|
||||
@click.option("-o", "--output", type=click.File("w"))
|
||||
@click.option("-p", "--parse", is_flag=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def dump_ast(output: Optional[TextIO], parse: bool, file: TextIO):
|
||||
source: str = file.read()
|
||||
|
||||
dump: str
|
||||
if file.name.endswith(".py"):
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
if parse:
|
||||
dump = dump_python_ast(tree)
|
||||
else:
|
||||
dump = ast.dump(tree, indent=4)
|
||||
elif file.name.endswith(".midas"):
|
||||
dump = dump_midas_ast(source, file.name)
|
||||
else:
|
||||
raise ValueError("Unsupported file type")
|
||||
|
||||
if output is None:
|
||||
click.echo(dump)
|
||||
else:
|
||||
output.write(dump)
|
||||
|
||||
|
||||
def highlight_python(source: str, path: str) -> Highlighter:
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
highlighter = PythonHighlighter(source)
|
||||
for stmt in stmts:
|
||||
highlighter.highlight(stmt)
|
||||
return highlighter
|
||||
|
||||
|
||||
def highlight_midas(source: str, path: str) -> Highlighter:
|
||||
lexer = MidasLexer(source, file=path)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
highlighter = MidasHighlighter(source)
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocatableToken:
|
||||
token: Token
|
||||
|
||||
@property
|
||||
def location(self) -> Location:
|
||||
return self.token.get_location()
|
||||
|
||||
for stmt in stmts:
|
||||
highlighter.highlight(stmt)
|
||||
for token in tokens:
|
||||
if token.type == TokenType.COMMENT:
|
||||
highlighter.wrap(LocatableToken(token), "comment")
|
||||
elif token.is_keyword:
|
||||
highlighter.wrap(LocatableToken(token), "keyword")
|
||||
return highlighter
|
||||
|
||||
|
||||
@utils.command()
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def highlight(output: TextIO, file: TextIO):
|
||||
source: str = file.read()
|
||||
highlighter: Highlighter
|
||||
|
||||
if file.name.endswith(".py"):
|
||||
highlighter = highlight_python(source, file.name)
|
||||
elif file.name.endswith(".midas"):
|
||||
highlighter = highlight_midas(source, file.name)
|
||||
else:
|
||||
raise ValueError("Unsupported file type")
|
||||
highlighter.dump(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
midas()
|
||||
0
midas/lexer/__init__.py
Normal file
0
midas/lexer/__init__.py
Normal file
@@ -1,8 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token, TokenType
|
||||
from midas.lexer.position import Position
|
||||
from midas.lexer.token import Token, TokenType
|
||||
|
||||
|
||||
class MidasSyntaxError(Exception):
|
||||
def __init__(self, pos: Position, message: str):
|
||||
super().__init__(f"[ERROR] Error at {pos}: {message}")
|
||||
self.pos: Position = pos
|
||||
self.message: str = message
|
||||
|
||||
|
||||
class Lexer(ABC):
|
||||
@@ -38,9 +45,9 @@ class Lexer(ABC):
|
||||
msg (str): the error message
|
||||
|
||||
Raises:
|
||||
SyntaxError
|
||||
MidasSyntaxError
|
||||
"""
|
||||
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}")
|
||||
raise MidasSyntaxError(self.start_pos, msg)
|
||||
|
||||
def process(self) -> list[Token]:
|
||||
"""Scan tokens out of the source text
|
||||
@@ -49,7 +56,7 @@ class Lexer(ABC):
|
||||
list[Token]: all the tokens that could be scanned
|
||||
|
||||
Raises:
|
||||
SyntaxError: if a syntax error is found
|
||||
MidasSyntaxError: if a syntax error is found
|
||||
"""
|
||||
self.scan_tokens()
|
||||
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))
|
||||
@@ -1,6 +1,5 @@
|
||||
from lexer.base import Lexer
|
||||
from lexer.keyword import MIDAS_KEYWORDS
|
||||
from lexer.token import TokenType
|
||||
from midas.lexer.base import Lexer
|
||||
from midas.lexer.token import KEYWORDS, TokenType
|
||||
|
||||
|
||||
class MidasLexer(Lexer):
|
||||
@@ -31,30 +30,32 @@ class MidasLexer(Lexer):
|
||||
self.add_token(
|
||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
||||
)
|
||||
case "!":
|
||||
if self.match("="):
|
||||
case "!" if self.match("="):
|
||||
self.add_token(TokenType.BANG_EQUAL)
|
||||
else:
|
||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
||||
case ":":
|
||||
self.add_token(TokenType.COLON)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_":
|
||||
case ".":
|
||||
self.add_token(TokenType.DOT)
|
||||
case "&":
|
||||
self.add_token(TokenType.AND)
|
||||
case "?":
|
||||
self.add_token(TokenType.QMARK)
|
||||
# case ",":
|
||||
# self.add_token(TokenType.COMMA)
|
||||
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "+":
|
||||
self.add_token(TokenType.PLUS)
|
||||
case "-" if self.match(">"):
|
||||
self.add_token(TokenType.ARROW)
|
||||
# case "+":
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
case "*":
|
||||
self.add_token(TokenType.STAR)
|
||||
case "/":
|
||||
if self.match("/"):
|
||||
# case "*":
|
||||
# self.add_token(TokenType.STAR)
|
||||
case "/" if self.match("/"):
|
||||
self.scan_comment()
|
||||
elif self.match("*"):
|
||||
case "/" if self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
else:
|
||||
self.add_token(TokenType.SLASH)
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
@@ -69,7 +70,7 @@ class MidasLexer(Lexer):
|
||||
case _:
|
||||
if char.isdigit():
|
||||
self.scan_number()
|
||||
elif char.isalpha():
|
||||
elif self.is_identifier_char(char, start=True):
|
||||
self.scan_identifier()
|
||||
else:
|
||||
self.error("Unexpected character")
|
||||
@@ -98,11 +99,11 @@ class MidasLexer(Lexer):
|
||||
An identifier starts with a letter, followed by any number of
|
||||
alphanumerical characters or underscores
|
||||
"""
|
||||
while self.peek().isalnum() or self.peek() == "_":
|
||||
while self.is_identifier_char(self.peek(), start=False):
|
||||
self.advance()
|
||||
|
||||
lexeme: str = self.source[self.start : self.idx]
|
||||
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||
self.add_token(token_type)
|
||||
|
||||
def scan_comment(self):
|
||||
@@ -129,3 +130,12 @@ class MidasLexer(Lexer):
|
||||
if not self.is_at_end():
|
||||
self.advance()
|
||||
self.add_token(TokenType.COMMENT)
|
||||
|
||||
def is_identifier_char(self, char: str, *, start: bool) -> bool:
|
||||
if char == "_":
|
||||
return True
|
||||
if char.isalpha():
|
||||
return True
|
||||
if not start and char.isdigit():
|
||||
return True
|
||||
return False
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
@dataclass(frozen=True)
|
||||
class Position:
|
||||
"""A simple structure to store the position of a token"""
|
||||
|
||||
file: Optional[str]
|
||||
line: int
|
||||
column: int
|
||||
104
midas/lexer/token.py
Normal file
104
midas/lexer/token.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.lexer.position import Position
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
# Punctuation
|
||||
LEFT_PAREN = auto()
|
||||
RIGHT_PAREN = auto()
|
||||
LEFT_BRACKET = auto()
|
||||
RIGHT_BRACKET = auto()
|
||||
LEFT_BRACE = auto()
|
||||
RIGHT_BRACE = auto()
|
||||
COLON = auto()
|
||||
# COMMA = auto()
|
||||
UNDERSCORE = auto()
|
||||
ARROW = auto()
|
||||
AND = auto()
|
||||
QMARK = auto()
|
||||
DOT = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
MINUS = auto()
|
||||
# STAR = auto()
|
||||
# SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
LESS_EQUAL = auto()
|
||||
EQUAL = auto()
|
||||
EQUAL_EQUAL = auto()
|
||||
BANG_EQUAL = auto()
|
||||
|
||||
# Literals
|
||||
IDENTIFIER = auto()
|
||||
NUMBER = auto()
|
||||
TRUE = auto()
|
||||
FALSE = auto()
|
||||
NONE = auto()
|
||||
|
||||
# Keywords
|
||||
TYPE = auto()
|
||||
OP = auto()
|
||||
PREDICATE = auto()
|
||||
EXTEND = auto()
|
||||
WHERE = auto()
|
||||
|
||||
# Misc
|
||||
COMMENT = auto()
|
||||
WHITESPACE = auto()
|
||||
EOF = auto()
|
||||
NEWLINE = auto()
|
||||
|
||||
|
||||
KEYWORDS: dict[str, TokenType] = {
|
||||
"type": TokenType.TYPE,
|
||||
"op": TokenType.OP,
|
||||
"predicate": TokenType.PREDICATE,
|
||||
"extend": TokenType.EXTEND,
|
||||
"where": TokenType.WHERE,
|
||||
"true": TokenType.TRUE,
|
||||
"false": TokenType.FALSE,
|
||||
"none": TokenType.NONE,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Token:
|
||||
"""A scanned token"""
|
||||
|
||||
type: TokenType
|
||||
lexeme: str
|
||||
value: Any
|
||||
position: Position
|
||||
|
||||
def get_location(self) -> Location:
|
||||
lineno: int = self.position.line
|
||||
col_offset: int = self.position.column - 1
|
||||
end_lineno = lineno
|
||||
end_col_offset = col_offset
|
||||
for c in self.lexeme:
|
||||
end_col_offset += 1
|
||||
if c == "\n":
|
||||
end_lineno += 1
|
||||
end_col_offset = 0
|
||||
return Location(
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
end_lineno=end_lineno,
|
||||
end_col_offset=end_col_offset,
|
||||
)
|
||||
|
||||
def location_to(self, to: Token) -> Location:
|
||||
return Location.span(self.get_location(), to.get_location())
|
||||
|
||||
@property
|
||||
def is_keyword(self) -> bool:
|
||||
return self.lexeme in KEYWORDS
|
||||
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.errors import ParsingError
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
419
midas/parser/midas.py
Normal file
419
midas/parser/midas.py
Normal file
@@ -0,0 +1,419 @@
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexTypeStmt,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
SimpleTypeExpr,
|
||||
SimpleTypeStmt,
|
||||
Stmt,
|
||||
TemplateExpr,
|
||||
TypeExpr,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {
|
||||
TokenType.TYPE,
|
||||
TokenType.OP,
|
||||
TokenType.EXTEND,
|
||||
TokenType.PREDICATE,
|
||||
}
|
||||
|
||||
def parse(self) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
while not self.is_at_end():
|
||||
stmt: Optional[Stmt] = self.declaration()
|
||||
if stmt is None:
|
||||
print("Early stop")
|
||||
break
|
||||
statements.append(stmt)
|
||||
return statements
|
||||
|
||||
def synchronize(self):
|
||||
"""Skip tokens until a synchronization boundary is found
|
||||
|
||||
This method allows gracefully recovering from a parse error
|
||||
to a safe place and continue parsing
|
||||
"""
|
||||
self.advance()
|
||||
while not self.is_at_end():
|
||||
if self.previous().type == TokenType.NEWLINE:
|
||||
return
|
||||
if self.peek().type in self.SYNC_BOUNDARY:
|
||||
return
|
||||
self.advance()
|
||||
|
||||
def declaration(self) -> Optional[Stmt]:
|
||||
"""Try and parse a declaration
|
||||
|
||||
Any parsing error is caught and None is returned
|
||||
|
||||
Returns:
|
||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||
"""
|
||||
try:
|
||||
if self.match(TokenType.TYPE):
|
||||
return self.type_declaration()
|
||||
if self.match(TokenType.EXTEND):
|
||||
return self.extend_declaration()
|
||||
if self.match(TokenType.PREDICATE):
|
||||
return self.predicate_declaration()
|
||||
raise self.error(self.peek(), "Unexpected token")
|
||||
except ParsingError:
|
||||
self.synchronize()
|
||||
return None
|
||||
|
||||
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration can either be a simple type alias or a new complex type.
|
||||
In either case, it can have an optional template expression after its name, wrapped in brackets.
|
||||
A simple type alias is derived from a base type expression, and can have a optional constraint expression preceded by the `where` keyword.
|
||||
A full simple type alias is thus written:
|
||||
```
|
||||
type Name[Template](TypeExpr) where Condition
|
||||
```
|
||||
|
||||
A new complex type has a set of properties which are named, have a type and an optional constraint expression (also preceded by the `where` keyword).
|
||||
A full complex type definition is thus written:
|
||||
```
|
||||
type Name[Template] {
|
||||
prop1: TypeExpr1 where Condition1
|
||||
prop2: TypeExpr2 where Condition2
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Returns:
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
template: Optional[TemplateExpr] = None
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
template = self.template_expr()
|
||||
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
base: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
|
||||
constraint: Optional[Expr] = None
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint = self.constraint()
|
||||
return SimpleTypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
base=base,
|
||||
constraint=constraint,
|
||||
)
|
||||
else:
|
||||
properties: list[PropertyStmt] = self.type_properties()
|
||||
return ComplexTypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
def template_expr(self) -> TemplateExpr:
|
||||
"""Parse a generic template expression
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
|
||||
)
|
||||
type: TypeExpr = self.type_expr()
|
||||
right: Token = self.consume(
|
||||
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
|
||||
)
|
||||
return TemplateExpr(location=left.location_to(right), type=type)
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
"""Parse a type expression
|
||||
|
||||
A type is an identifier, optionally followed by a template expression.
|
||||
It can also optionally be followed by a '?' to indicate a nullable type
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
template: Optional[TemplateExpr] = None
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
template = self.template_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
return TypeExpr(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
optional=optional,
|
||||
)
|
||||
|
||||
def simple_type_expr(self) -> SimpleTypeExpr:
|
||||
"""Parse a simple type expression
|
||||
|
||||
A simple type is just an identifier optionally followed by a '?'
|
||||
|
||||
Returns:
|
||||
SimpleTypeExpr: the parsed simple type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
return SimpleTypeExpr(
|
||||
location=name.location_to(self.previous()), name=name, optional=optional
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
"""Parse a constraint
|
||||
|
||||
A constraint is basically a logical predicate
|
||||
|
||||
Returns:
|
||||
Expr: the parsed constraint expression
|
||||
"""
|
||||
return self.and_()
|
||||
|
||||
def and_(self) -> Expr:
|
||||
"""Parse a logical AND expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.equality()
|
||||
while self.match(TokenType.AND):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.equality()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = LogicalExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def equality(self) -> Expr:
|
||||
"""Parse a logical equality expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.comparison()
|
||||
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.comparison()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def comparison(self) -> Expr:
|
||||
"""Parse a logical comparison expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.unary()
|
||||
while self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(expr.location, right.location)
|
||||
expr = BinaryExpr(
|
||||
location=location, left=expr, operator=operator, right=right
|
||||
)
|
||||
return expr
|
||||
|
||||
def unary(self) -> Expr:
|
||||
"""Parse a unary expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
if self.match(TokenType.MINUS):
|
||||
operator: Token = self.previous()
|
||||
right: Expr = self.unary()
|
||||
location: Location = Location.span(operator.get_location(), right.location)
|
||||
return UnaryExpr(location=location, operator=operator, right=right)
|
||||
return self.reference()
|
||||
|
||||
def reference(self) -> Expr:
|
||||
"""Parse an attribute access expression or a simpler expression
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
expr: Expr = self.primary()
|
||||
while self.match(TokenType.DOT):
|
||||
name: Token = self.consume(
|
||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||
)
|
||||
location: Location = Location.span(expr.location, name.get_location())
|
||||
expr = GetExpr(location=location, expr=expr, name=name)
|
||||
return expr
|
||||
|
||||
def primary(self) -> Expr:
|
||||
"""Parse a primary expression
|
||||
|
||||
This includes literals (booleans, numbers, etc.), wildcards, identifiers and grouped expressions
|
||||
|
||||
Returns:
|
||||
Expr: the parsed expression
|
||||
"""
|
||||
token: Token = self.peek()
|
||||
if self.match(TokenType.FALSE):
|
||||
return LiteralExpr(location=token.get_location(), value=False)
|
||||
if self.match(TokenType.TRUE):
|
||||
return LiteralExpr(location=token.get_location(), value=True)
|
||||
if self.match(TokenType.NONE):
|
||||
return LiteralExpr(location=token.get_location(), value=None)
|
||||
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match(TokenType.IDENTIFIER):
|
||||
return VariableExpr(location=token.get_location(), name=token)
|
||||
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return WildcardExpr(location=token.get_location(), token=token)
|
||||
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
expr: Expr = self.constraint()
|
||||
right: Token = self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
return GroupingExpr(location=token.location_to(right), expr=expr)
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def type_properties(self) -> list[PropertyStmt]:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
||||
properties: list[PropertyStmt] = []
|
||||
names: set[str] = set()
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
prop: PropertyStmt = self.property_stmt()
|
||||
if prop.name.lexeme in names:
|
||||
raise self.error(prop.name, "Duplicate property")
|
||||
names.add(prop.name.lexeme)
|
||||
properties.append(prop)
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return properties
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
A type property statement is written `name: Type` or `name: Type where Condition`
|
||||
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
constraint: Optional[Expr] = None
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint = self.constraint()
|
||||
return PropertyStmt(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
"""Parse an extension definition
|
||||
|
||||
An extension is written `extend Type { operations }`
|
||||
|
||||
Returns:
|
||||
ExtendStmt: the parsed extension statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
type: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||
operations: list[OpStmt] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||
operations.append(self.op_declaration())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||
location: Location = keyword.location_to(self.previous())
|
||||
return ExtendStmt(location=location, type=type, operations=operations)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
"""Parse an operation definition
|
||||
|
||||
An operation is written `op name(Type) -> Type`
|
||||
|
||||
Returns:
|
||||
OpStmt: the parsed operation statement
|
||||
"""
|
||||
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
||||
operand: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: TypeExpr = self.type_expr()
|
||||
|
||||
return OpStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
operand=operand,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def predicate_declaration(self) -> PredicateStmt:
|
||||
"""Parse a predicate declaration
|
||||
|
||||
A predicate is written `predicate Name(subject: Type) = constraint_expression`
|
||||
|
||||
Returns:
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
condition: Expr = self.constraint()
|
||||
return PredicateStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
subject=subject,
|
||||
type=type,
|
||||
condition=condition,
|
||||
)
|
||||
492
midas/parser/python.py
Normal file
492
midas/parser/python.py
Normal file
@@ -0,0 +1,492 @@
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExpressionStmt,
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
Stmt,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
)
|
||||
|
||||
|
||||
class InvalidSyntaxError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedSyntaxError(Exception):
|
||||
def __init__(self, expr: ast.expr) -> None:
|
||||
super().__init__(
|
||||
f"Unsupported syntax at L{expr.lineno}:{expr.col_offset}: {ast.unparse(expr)}"
|
||||
)
|
||||
|
||||
|
||||
class PythonParser:
|
||||
CAST_FUNCTION = "cast"
|
||||
|
||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
for stmt in node.body:
|
||||
try:
|
||||
parsed: None | Stmt | list[Stmt] = self.parse_stmt(stmt)
|
||||
if isinstance(parsed, Stmt):
|
||||
statements.append(parsed)
|
||||
elif parsed is not None:
|
||||
statements.extend(parsed)
|
||||
except UnsupportedSyntaxError as e:
|
||||
print(f"{e}, skipping")
|
||||
continue
|
||||
return statements
|
||||
|
||||
def parse_stmt(self, node: ast.stmt) -> None | Stmt | list[Stmt]:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign():
|
||||
return self.parse_annotation_assign(node)
|
||||
|
||||
case ast.Assign():
|
||||
return self.parse_assign(node)
|
||||
|
||||
case ast.AugAssign():
|
||||
return self.parse_aug_assign(node)
|
||||
|
||||
case ast.FunctionDef():
|
||||
return self.parse_function(node)
|
||||
|
||||
case ast.Expr(value=expr):
|
||||
return ExpressionStmt(
|
||||
location=location,
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
|
||||
case ast.Return(value=value):
|
||||
return ReturnStmt(
|
||||
location=location,
|
||||
value=self.parse_expr(value) if value is not None else None,
|
||||
)
|
||||
|
||||
case ast.If():
|
||||
return self.parse_if(node)
|
||||
|
||||
case _:
|
||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||
return None
|
||||
|
||||
def parse_annotation_assign(self, node: ast.AnnAssign) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.AnnAssign(
|
||||
target=ast.Name(id=target),
|
||||
annotation=annotation,
|
||||
value=value,
|
||||
simple=1,
|
||||
):
|
||||
type = self._parse_type(annotation)
|
||||
statements.append(
|
||||
TypeAssign(
|
||||
location=loc,
|
||||
name=target,
|
||||
type=type,
|
||||
)
|
||||
)
|
||||
|
||||
if value is not None:
|
||||
statements.append(
|
||||
AssignStmt(
|
||||
location=loc,
|
||||
targets=[
|
||||
VariableExpr(
|
||||
location=Location.from_ast(node.target), name=target
|
||||
),
|
||||
],
|
||||
value=self.parse_expr(value),
|
||||
),
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported annotation: {ast.unparse(node)}")
|
||||
return statements
|
||||
|
||||
def parse_assign(self, node: ast.Assign) -> AssignStmt:
|
||||
targets: list[Expr] = []
|
||||
for target in node.targets:
|
||||
targets.append(self.parse_expr(target))
|
||||
value: Expr = self.parse_expr(node.value)
|
||||
return AssignStmt(
|
||||
location=Location.from_ast(node),
|
||||
targets=targets,
|
||||
value=value,
|
||||
)
|
||||
|
||||
def parse_aug_assign(self, node: ast.AugAssign) -> AssignStmt:
|
||||
location: Location = Location.from_ast(node)
|
||||
target: Expr = self.parse_expr(node.target)
|
||||
value: Expr = self.parse_expr(node.value)
|
||||
return AssignStmt(
|
||||
location=location,
|
||||
targets=[target],
|
||||
value=BinaryExpr(
|
||||
location=location,
|
||||
left=target,
|
||||
operator=node.op,
|
||||
right=value,
|
||||
),
|
||||
)
|
||||
|
||||
def parse_if(self, node: ast.If) -> IfStmt:
|
||||
body: list[Stmt] = []
|
||||
for stmt in node.body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
orelse: list[Stmt] = []
|
||||
for stmt in node.orelse:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
orelse.append(stmts)
|
||||
elif stmts is not None:
|
||||
orelse.extend(stmts)
|
||||
|
||||
return IfStmt(
|
||||
location=Location.from_ast(node),
|
||||
test=self.parse_expr(node.test),
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
)
|
||||
|
||||
def parse_function(self, node: ast.FunctionDef) -> Function:
|
||||
loc: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.FunctionDef(
|
||||
name=name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=posonlyargs,
|
||||
args=args,
|
||||
vararg=sink,
|
||||
kwonlyargs=kwonlyargs,
|
||||
kwarg=kw_sink,
|
||||
defaults=defaults,
|
||||
kw_defaults=kw_defaults,
|
||||
),
|
||||
returns=returns,
|
||||
body=raw_body,
|
||||
):
|
||||
|
||||
def parse_args(
|
||||
args_list: list[ast.arg], defaults: list[Optional[Expr]]
|
||||
) -> list[Function.Argument]:
|
||||
return [
|
||||
self._parse_function_argument(arg, default)
|
||||
for arg, default in zip(args_list, defaults)
|
||||
]
|
||||
|
||||
body: list[Stmt] = []
|
||||
for stmt in raw_body:
|
||||
stmts = self.parse_stmt(stmt)
|
||||
if isinstance(stmts, Stmt):
|
||||
body.append(stmts)
|
||||
elif stmts is not None:
|
||||
body.extend(stmts)
|
||||
|
||||
parsed_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) for default in defaults
|
||||
]
|
||||
n_posargs: int = len(posonlyargs)
|
||||
n_args: int = len(args)
|
||||
n_all_posargs = n_posargs + n_args
|
||||
parsed_defaults = [
|
||||
None,
|
||||
] * (n_all_posargs - len(defaults)) + parsed_defaults
|
||||
|
||||
posargs_defaults: list[Optional[Expr]] = parsed_defaults[:n_posargs]
|
||||
args_defaults: list[Optional[Expr]] = parsed_defaults[n_posargs:]
|
||||
kwargs_defaults: list[Optional[Expr]] = [
|
||||
self.parse_expr(default) if default is not None else None
|
||||
for default in kw_defaults
|
||||
]
|
||||
|
||||
return Function(
|
||||
location=loc,
|
||||
name=name,
|
||||
posonlyargs=parse_args(posonlyargs, posargs_defaults),
|
||||
args=parse_args(args, args_defaults),
|
||||
sink=(
|
||||
self._parse_function_argument(sink, None)
|
||||
if sink is not None
|
||||
else None
|
||||
),
|
||||
kwonlyargs=parse_args(kwonlyargs, kwargs_defaults),
|
||||
kw_sink=(
|
||||
self._parse_function_argument(kw_sink, None)
|
||||
if kw_sink is not None
|
||||
else None
|
||||
),
|
||||
returns=self._parse_type(returns) if returns is not None else None,
|
||||
body=body,
|
||||
)
|
||||
case _:
|
||||
print(f"Unsupported function definition: {ast.unparse(node)}")
|
||||
|
||||
def _parse_function_argument(
|
||||
self, arg: ast.arg, default: Optional[Expr]
|
||||
) -> Function.Argument:
|
||||
loc: Location = Location.from_ast(arg)
|
||||
name: str = arg.arg
|
||||
type: Optional[MidasType] = None
|
||||
if arg.annotation is not None:
|
||||
type = self._parse_type(arg.annotation)
|
||||
return Function.Argument(
|
||||
location=loc,
|
||||
name=name,
|
||||
type=type,
|
||||
default=default,
|
||||
)
|
||||
|
||||
def _parse_type(self, type_expr: ast.expr) -> MidasType:
|
||||
loc: Location = Location.from_ast(type_expr)
|
||||
match type_expr:
|
||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||
return self._parse_frame_type(schema)
|
||||
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=self._parse_type(param),
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base=name,
|
||||
param=None,
|
||||
)
|
||||
|
||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||
left = self._parse_type(left_expr)
|
||||
match left:
|
||||
case None:
|
||||
raise InvalidSyntaxError()
|
||||
|
||||
# If chained constraints, separate base type and rebuild constraint
|
||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||
constraint = ast.BinOp(
|
||||
left=left_constraint,
|
||||
op=ast.Add(),
|
||||
right=right_expr,
|
||||
)
|
||||
ast.copy_location(constraint, type_expr)
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left_type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
case _:
|
||||
return ConstraintType(
|
||||
location=loc,
|
||||
type=left,
|
||||
constraint=right_expr,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
|
||||
def _parse_frame_type(self, schema: ast.expr) -> FrameType:
|
||||
loc: Location = Location.from_ast(schema)
|
||||
columns: list[FrameColumn] = []
|
||||
|
||||
match schema:
|
||||
case ast.Tuple(elts=cols):
|
||||
for col in cols:
|
||||
columns.append(self._parse_frame_column(col))
|
||||
|
||||
case ast.Slice() | ast.Name():
|
||||
columns.append(self._parse_frame_column(schema))
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(schema)
|
||||
|
||||
return FrameType(location=loc, columns=columns)
|
||||
|
||||
def _parse_frame_column(self, column: ast.expr) -> FrameColumn:
|
||||
loc: Location = Location.from_ast(column)
|
||||
match column:
|
||||
case ast.Name():
|
||||
return FrameColumn(
|
||||
location=loc,
|
||||
name=None,
|
||||
type=self._parse_type(column),
|
||||
)
|
||||
|
||||
case ast.Slice(lower=ast.Name(id=name), upper=type_expr):
|
||||
if name == "_":
|
||||
name = None
|
||||
|
||||
type: Optional[MidasType] = None
|
||||
match type_expr:
|
||||
case None:
|
||||
raise InvalidSyntaxError("Missing column type")
|
||||
case ast.Name(id="_"):
|
||||
type = None
|
||||
case ast.expr():
|
||||
type = self._parse_type(type_expr)
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
return FrameColumn(location=loc, name=name, type=type)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(column)
|
||||
|
||||
def parse_expr(self, node: ast.expr) -> Expr:
|
||||
location: Location = Location.from_ast(node)
|
||||
match node:
|
||||
case ast.BoolOp():
|
||||
return self.parse_bool_op(node)
|
||||
|
||||
case ast.BinOp(left=left, op=op, right=right):
|
||||
return BinaryExpr(
|
||||
location=location,
|
||||
left=self.parse_expr(left),
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
|
||||
case ast.UnaryOp(op=op, operand=right):
|
||||
return UnaryExpr(
|
||||
location=location,
|
||||
operator=op,
|
||||
right=self.parse_expr(right),
|
||||
)
|
||||
|
||||
case ast.Compare():
|
||||
return self.parse_compare(node)
|
||||
|
||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||
return self.parse_cast(node)
|
||||
|
||||
case ast.Call():
|
||||
return self.parse_call(node)
|
||||
|
||||
case ast.IfExp():
|
||||
return self.parse_ternary(node)
|
||||
|
||||
case ast.Constant(value=value):
|
||||
return LiteralExpr(location=location, value=value)
|
||||
|
||||
case ast.Attribute(value=object, attr=name):
|
||||
return GetExpr(
|
||||
location=location,
|
||||
object=self.parse_expr(object),
|
||||
name=name,
|
||||
)
|
||||
|
||||
case ast.Name(id=name):
|
||||
return VariableExpr(location=location, name=name)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(node)
|
||||
|
||||
def parse_bool_op(self, node: ast.BoolOp) -> LogicalExpr:
|
||||
op: ast.boolop = node.op
|
||||
rights: list[Expr] = [self.parse_expr(expr) for expr in node.values]
|
||||
expr: LogicalExpr = LogicalExpr(
|
||||
location=Location.span(
|
||||
rights[0].location,
|
||||
rights[1].location,
|
||||
),
|
||||
left=rights[0],
|
||||
operator=op,
|
||||
right=rights[1],
|
||||
)
|
||||
for right in rights[2:]:
|
||||
expr = LogicalExpr(
|
||||
location=Location.span(expr.location, right.location),
|
||||
left=expr,
|
||||
operator=op,
|
||||
right=right,
|
||||
)
|
||||
return expr
|
||||
|
||||
def parse_compare(self, node: ast.Compare) -> Expr:
|
||||
ops: list[ast.cmpop] = node.ops
|
||||
left: Expr = self.parse_expr(node.left)
|
||||
rights: list[Expr] = [self.parse_expr(expr) for expr in node.comparators]
|
||||
expr: Expr = CompareExpr(
|
||||
location=Location.span(
|
||||
left.location,
|
||||
rights[0].location,
|
||||
),
|
||||
left=left,
|
||||
operator=ops[0],
|
||||
right=rights[0],
|
||||
)
|
||||
for i, right in enumerate(rights[1:]):
|
||||
comparison = CompareExpr(
|
||||
location=Location.span(rights[i].location, right.location),
|
||||
left=rights[i],
|
||||
operator=ops[i],
|
||||
right=right,
|
||||
)
|
||||
expr = LogicalExpr(
|
||||
location=Location.span(expr.location, comparison.location),
|
||||
left=expr,
|
||||
operator=ast.And(),
|
||||
right=comparison,
|
||||
)
|
||||
return expr
|
||||
|
||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||
match node:
|
||||
case ast.Call(args=[type, expr], keywords=[]):
|
||||
return CastExpr(
|
||||
location=Location.from_ast(node),
|
||||
type=self._parse_type(type),
|
||||
expr=self.parse_expr(expr),
|
||||
)
|
||||
case _:
|
||||
raise InvalidSyntaxError(
|
||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
||||
)
|
||||
|
||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||
return CallExpr(
|
||||
location=Location.from_ast(node),
|
||||
callee=self.parse_expr(node.func),
|
||||
arguments=[self.parse_expr(arg) for arg in node.args],
|
||||
keywords={
|
||||
arg.arg: self.parse_expr(arg.value)
|
||||
for arg in node.keywords
|
||||
if arg.arg is not None # Should always be True, type checker happy
|
||||
},
|
||||
)
|
||||
|
||||
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
|
||||
return TernaryExpr(
|
||||
location=Location.from_ast(node),
|
||||
test=self.parse_expr(node.test),
|
||||
if_true=self.parse_expr(node.body),
|
||||
if_false=self.parse_expr(node.orelse),
|
||||
)
|
||||
71
midas/resolver/builtin.py
Normal file
71
midas/resolver/builtin.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import BaseType, Type, UnitType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||
ctx.define_operation(
|
||||
left=t1,
|
||||
operator=operator,
|
||||
right=t2,
|
||||
result=t3,
|
||||
)
|
||||
|
||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||
ctx.define_operation(
|
||||
left=type,
|
||||
operator=op,
|
||||
right=type,
|
||||
result=type,
|
||||
)
|
||||
|
||||
|
||||
def define_builtins(ctx: MidasResolver):
|
||||
"""Define builtin types and operations"""
|
||||
unit = ctx.define_type("None", UnitType())
|
||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||
int = ctx.define_type("int", BaseType(name="int"))
|
||||
float = ctx.define_type("float", BaseType(name="float"))
|
||||
str = ctx.define_type("str", BaseType(name="str"))
|
||||
|
||||
basic_op(ctx, int, "__add__") # int + int = int
|
||||
basic_op(ctx, int, "__sub__") # int - int = int
|
||||
basic_op(ctx, int, "__mul__") # int * int = int
|
||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||
basic_op(ctx, int, "__mod__") # int % int = int
|
||||
basic_op(ctx, int, "__and__") # int & int = int
|
||||
basic_op(ctx, int, "__or__") # int | int = int
|
||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||
basic_op(ctx, float, "__add__") # float + float = float
|
||||
basic_op(ctx, float, "__sub__") # float - float = float
|
||||
basic_op(ctx, float, "__mul__") # float * float = float
|
||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||
basic_op(ctx, str, "__add__") # str + str = str
|
||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||
|
||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||
|
||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||
153
midas/resolver/midas.py
Normal file
153
midas/resolver/midas.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import BaseType, SimpleType, Type
|
||||
from midas.resolver.builtin import define_builtins
|
||||
|
||||
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
||||
|
||||
define_builtins(self)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
type: Optional[Type] = self._types.get(name)
|
||||
if type is None:
|
||||
raise NameError(f"Undefined type {name}")
|
||||
return type
|
||||
|
||||
def get_operation_result(
|
||||
self, left: Type, operator: str, right: Type
|
||||
) -> Optional[Type]:
|
||||
"""Get the resulting type of an operation
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
result: Optional[Type] = self._operations.get(operation)
|
||||
return result
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||
"""Define an operation in the registry
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
result (Type): the result type
|
||||
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
if operation in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[operation] = result
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||
# TODO generics, optional, constraint
|
||||
base: Type = self.get_type(stmt.base.name.lexeme)
|
||||
match base:
|
||||
case BaseType() | SimpleType():
|
||||
type = SimpleType(
|
||||
name=stmt.name.lexeme,
|
||||
base=base,
|
||||
)
|
||||
self.define_type(type.name, type)
|
||||
case _:
|
||||
raise TypeError(f"Invalid base {base} for simple type")
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
base: Type = stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
right: Type = op.operand.accept(self)
|
||||
result: Type = op.result.accept(self)
|
||||
self.define_operation(
|
||||
left=base,
|
||||
operator=op.name.lexeme,
|
||||
right=right,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type:
|
||||
return self.get_type(expr.name.lexeme)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ...
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ...
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr) -> Type:
|
||||
return self.get_type(expr.name.lexeme)
|
||||
187
midas/resolver/resolver.py
Normal file
187
midas/resolver/resolver.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import midas.ast.python as p
|
||||
|
||||
|
||||
class ResolverError(Exception): ...
|
||||
|
||||
|
||||
class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
"""A variable assignment and reference resolver
|
||||
|
||||
This class keeps track of which scope a variable is defined in and which
|
||||
scope is referred to when a variable is referenced
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.scopes: list[dict[str, bool]] = []
|
||||
|
||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||
"""Resolve the given statements or expressions"""
|
||||
|
||||
for obj in objects:
|
||||
obj.accept(self)
|
||||
|
||||
def begin_scope(self):
|
||||
"""Begin a new scope inside the current one"""
|
||||
self.scopes.append({})
|
||||
|
||||
def end_scope(self):
|
||||
"""Close the current scope"""
|
||||
self.scopes.pop()
|
||||
|
||||
def declare(self, name: str) -> None:
|
||||
"""Declare a variable in the current scope
|
||||
|
||||
This method must be called *before* evaluating the variable initializer
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
|
||||
Raises:
|
||||
ResolverError: if the variable has already been declared in the current scope
|
||||
"""
|
||||
if len(self.scopes) == 0:
|
||||
return
|
||||
scope: dict[str, bool] = self.scopes[-1]
|
||||
if name in scope:
|
||||
raise ResolverError(
|
||||
f"A variable with the name {name} is already declared in this scope"
|
||||
)
|
||||
scope[name] = False
|
||||
|
||||
def define(self, name: str) -> None:
|
||||
"""Define a variable in the current scope
|
||||
|
||||
This method must be called *after* evaluating the variable initializer
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
"""
|
||||
if len(self.scopes) == 0:
|
||||
return
|
||||
self.scopes[-1][name] = True
|
||||
|
||||
def resolve_local(self, expr: p.Expr, name: str) -> None:
|
||||
"""Resolve a variable reference and store the scope distance
|
||||
|
||||
This method associates to the variable expression a number representing
|
||||
the "distance" of the variable declaration, i.e. the number of scope
|
||||
levels to go "up" to find the closest declaration for that variable.
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the variable expression
|
||||
name (str): the name of the variable
|
||||
"""
|
||||
for i, scope in enumerate(reversed(self.scopes)):
|
||||
if name in scope:
|
||||
self.locals[expr] = i
|
||||
return
|
||||
|
||||
def resolve_function(self, function: p.Function) -> None:
|
||||
"""Resolve a function definition
|
||||
|
||||
This method creates a new scope for the function, resolves all the
|
||||
parameter declarations and then the body.
|
||||
|
||||
Args:
|
||||
function (p.Function): the function to resolve
|
||||
"""
|
||||
self.begin_scope()
|
||||
for param in function.all_args:
|
||||
self.declare(param.name)
|
||||
self.define(param.name)
|
||||
self.resolve(*function.body)
|
||||
self.end_scope()
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
stmt.expr.accept(self)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
# Declare before resolving body to allow recursion
|
||||
self.declare(stmt.name)
|
||||
self.define(stmt.name)
|
||||
self.resolve_function(stmt)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
self.declare(stmt.name)
|
||||
# NOTE: resolve type here?
|
||||
self.define(stmt.name)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
self.resolve(stmt.value)
|
||||
for target in stmt.targets:
|
||||
match target:
|
||||
case p.VariableExpr(name=name):
|
||||
self.resolve_local(target, name)
|
||||
# TODO: declare if not found
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
if stmt.value is not None:
|
||||
self.resolve(stmt.value)
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not resolved in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
self.resolve(stmt.test)
|
||||
|
||||
# Body
|
||||
self.begin_scope()
|
||||
self.resolve(*stmt.body)
|
||||
self.end_scope()
|
||||
|
||||
# Else
|
||||
self.begin_scope()
|
||||
self.resolve(*stmt.orelse)
|
||||
self.end_scope()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> None:
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> None:
|
||||
self.resolve(expr.callee)
|
||||
for arg in expr.arguments:
|
||||
self.resolve(arg)
|
||||
for arg in expr.keywords.values():
|
||||
self.resolve(arg)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> None:
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
pass
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name) is False:
|
||||
raise ResolverError(
|
||||
f"Cannot use local variable '{expr.name}' in its own initializer"
|
||||
) # aka. UnboundLocalError
|
||||
self.resolve_local(expr, expr.name)
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None:
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self.resolve(expr.value)
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self.resolve(expr.expr)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||
self.resolve(expr.test)
|
||||
self.resolve(expr.if_true)
|
||||
self.resolve(expr.if_false)
|
||||
54
midas/utils.py
Normal file
54
midas/utils.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
AllowRepeat = Callable[[object], bool]
|
||||
|
||||
|
||||
class UniversalJSONDumper:
|
||||
@classmethod
|
||||
def dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: Optional[list[str | tuple[str, str]]] = None,
|
||||
allow_repeat: Optional[AllowRepeat] = None,
|
||||
) -> Any:
|
||||
if include_keys is None:
|
||||
include_keys = []
|
||||
return cls._dump(obj, include_keys, allow_repeat, [])
|
||||
|
||||
@classmethod
|
||||
def _dump(
|
||||
cls,
|
||||
obj: Any,
|
||||
include_keys: list[str | tuple[str, str]],
|
||||
allow_repeat: Optional[AllowRepeat],
|
||||
visited: list[Any],
|
||||
) -> Any:
|
||||
if obj in visited:
|
||||
return None
|
||||
match obj:
|
||||
case str() | int() | float() | None:
|
||||
return obj
|
||||
case list() | set() | tuple():
|
||||
return [
|
||||
cls._dump(child, include_keys, allow_repeat, visited)
|
||||
for child in obj
|
||||
]
|
||||
case dict():
|
||||
return {
|
||||
str(k): cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
case object():
|
||||
if allow_repeat is None or not allow_repeat(obj):
|
||||
visited.append(obj)
|
||||
return {
|
||||
"_type": obj.__class__.__name__,
|
||||
} | {
|
||||
k: cls._dump(v, include_keys, allow_repeat, visited)
|
||||
for k, v in obj.__dict__.items()
|
||||
if not k.startswith("_")
|
||||
or k in include_keys
|
||||
or (obj.__class__.__name__, k) in include_keys
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported value: {obj}")
|
||||
@@ -1,152 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.ast.annotations import (
|
||||
AnnotationStmt,
|
||||
ConstraintExpr,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
SchemaElementExpr,
|
||||
SchemaExpr,
|
||||
Stmt,
|
||||
TypeExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.base import Parser
|
||||
from parser.errors import ParsingError
|
||||
|
||||
|
||||
class AnnotationParser(Parser):
|
||||
"""A simple parser for custom type annotations"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = set()
|
||||
|
||||
def parse(self) -> Optional[Stmt]:
|
||||
stmt: Optional[Stmt] = None
|
||||
try:
|
||||
stmt = self.annotation()
|
||||
except ParsingError:
|
||||
self.synchronize()
|
||||
if not self.is_at_end():
|
||||
self.error(self.peek(), "Extra tokens")
|
||||
return stmt
|
||||
|
||||
def synchronize(self):
|
||||
"""Skip tokens until a synchronization boundary is found
|
||||
|
||||
This method allows gracefully recovering from a parse error
|
||||
to a safe place and continue parsing
|
||||
"""
|
||||
self.advance()
|
||||
while not self.is_at_end():
|
||||
if self.peek().type in self.SYNC_BOUNDARY:
|
||||
return
|
||||
self.advance()
|
||||
|
||||
def annotation(self) -> AnnotationStmt:
|
||||
"""Parse an annotation
|
||||
|
||||
An annotation is written as `Type` or `Type[Schema]`
|
||||
|
||||
Returns:
|
||||
AnnotationStmt: the parsed annotation statement
|
||||
"""
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
|
||||
schema: Optional[SchemaExpr] = None
|
||||
if self.match(TokenType.LEFT_BRACKET):
|
||||
schema = self.schema()
|
||||
return AnnotationStmt(name=name, schema=schema)
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
"""Parse a type expression
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
constraints: list[ConstraintExpr] = []
|
||||
|
||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
||||
constraints.append(self.constraint_expr())
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
||||
|
||||
return TypeExpr(name=name, constraints=constraints)
|
||||
|
||||
def constraint_expr(self) -> ConstraintExpr:
|
||||
"""Parse a type constraint
|
||||
|
||||
Returns:
|
||||
ConstraintExpr: the parsed type constraint expression
|
||||
"""
|
||||
|
||||
left: Expr = self.constraint_value()
|
||||
op: Token = self.constraint_operator()
|
||||
right: Expr = self.constraint_value()
|
||||
return ConstraintExpr(left=left, op=op, right=right)
|
||||
|
||||
def constraint_value(self) -> Expr:
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return WildcardExpr(self.previous())
|
||||
return self.literal()
|
||||
|
||||
def literal(self) -> LiteralExpr:
|
||||
if self.match(TokenType.FALSE):
|
||||
return LiteralExpr(False)
|
||||
if self.match(TokenType.TRUE):
|
||||
return LiteralExpr(True)
|
||||
if self.match(TokenType.NONE):
|
||||
return LiteralExpr(None)
|
||||
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(self.previous().value)
|
||||
|
||||
raise self.error(self.peek(), "Expected literal")
|
||||
|
||||
def constraint_operator(self) -> Token:
|
||||
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
|
||||
return self.previous()
|
||||
raise self.error(self.peek(), "Expected constraint operator")
|
||||
|
||||
def schema(self) -> SchemaExpr:
|
||||
"""Parse a schema definition
|
||||
|
||||
A comma separated list of schema elements
|
||||
|
||||
Returns:
|
||||
SchemaExpr: the parsed schema expression
|
||||
"""
|
||||
left: Token = self.previous()
|
||||
elements: list[Expr] = []
|
||||
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
||||
elements.append(self.schema_element())
|
||||
if not self.check(TokenType.RIGHT_BRACKET):
|
||||
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
|
||||
|
||||
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
|
||||
return SchemaExpr(left=left, elements=elements, right=right)
|
||||
|
||||
def schema_element(self) -> SchemaElementExpr:
|
||||
"""Parse a schema element
|
||||
|
||||
An anonymous element (`_`), a type, an untyped named column (`name: _`),
|
||||
or a named column (`name: Type`)
|
||||
|
||||
Returns:
|
||||
SchemaElementExpr: the parsed schema element expression
|
||||
"""
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return SchemaElementExpr(name=None, type=None)
|
||||
|
||||
if not self.check(TokenType.IDENTIFIER):
|
||||
raise self.error(self.peek(), "Expected schema element")
|
||||
|
||||
name: Optional[Token] = None
|
||||
type: Optional[TypeExpr] = None
|
||||
if self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
if not self.match(TokenType.UNDERSCORE):
|
||||
type = self.type_expr()
|
||||
return SchemaElementExpr(name=name, type=type)
|
||||
217
parser/midas.py
217
parser/midas.py
@@ -1,217 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.ast.midas import (
|
||||
ConstraintExpr,
|
||||
ConstraintStmt,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
OpStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
TypeBodyExpr,
|
||||
TypeExpr,
|
||||
TypeStmt,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.token import Token, TokenType
|
||||
from parser.base import Parser
|
||||
from parser.errors import ParsingError
|
||||
|
||||
|
||||
class MidasParser(Parser):
|
||||
"""A simple parser for midas type definitions"""
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {TokenType.TYPE, TokenType.OP, TokenType.CONSTRAINT}
|
||||
|
||||
def parse(self) -> list[Stmt]:
|
||||
statements: list[Stmt] = []
|
||||
while not self.is_at_end():
|
||||
stmt: Optional[Stmt] = self.declaration()
|
||||
if stmt is None:
|
||||
print("Early stop")
|
||||
break
|
||||
statements.append(stmt)
|
||||
return statements
|
||||
|
||||
def synchronize(self):
|
||||
"""Skip tokens until a synchronization boundary is found
|
||||
|
||||
This method allows gracefully recovering from a parse error
|
||||
to a safe place and continue parsing
|
||||
"""
|
||||
self.advance()
|
||||
while not self.is_at_end():
|
||||
if self.previous().type == TokenType.NEWLINE:
|
||||
return
|
||||
if self.peek().type in self.SYNC_BOUNDARY:
|
||||
return
|
||||
self.advance()
|
||||
|
||||
def declaration(self) -> Optional[Stmt]:
|
||||
"""Try and parse a declaration
|
||||
|
||||
Any parsing error is caught and None is returned
|
||||
|
||||
Returns:
|
||||
Optional[Stmt]: the parsed Midas statement, or None if a ParsingError was raised
|
||||
"""
|
||||
try:
|
||||
if self.match(TokenType.TYPE):
|
||||
return self.type_declaration()
|
||||
if self.match(TokenType.OP):
|
||||
return self.op_declaration()
|
||||
if self.match(TokenType.CONSTRAINT):
|
||||
return self.constraint_declaration()
|
||||
raise self.error(self.peek(), "Unexpected token")
|
||||
except ParsingError:
|
||||
self.synchronize()
|
||||
return None
|
||||
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration is written `type Name<TypeExpr, ...>` optionally followed by a brace-wrapped body
|
||||
|
||||
Returns:
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
self.consume(TokenType.LESS, "Expected '<' after type name")
|
||||
bases: list[TypeExpr] = []
|
||||
while not self.check(TokenType.GREATER) and not self.is_at_end():
|
||||
bases.append(self.type_expr())
|
||||
if not self.check(TokenType.GREATER):
|
||||
self.consume(TokenType.COMMA, "Expected ',' between type bases")
|
||||
self.consume(TokenType.GREATER, "Expected '>' after base type")
|
||||
|
||||
body: Optional[TypeBodyExpr] = None
|
||||
|
||||
if self.check(TokenType.LEFT_BRACE):
|
||||
body = self.type_body_expr()
|
||||
return TypeStmt(name=name, bases=bases, body=body)
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
"""Parse a type expression
|
||||
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
constraints: list[ConstraintExpr] = []
|
||||
|
||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
||||
constraints.append(self.constraint_expr())
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
||||
|
||||
return TypeExpr(name=name, constraints=constraints)
|
||||
|
||||
def constraint_expr(self) -> ConstraintExpr:
|
||||
"""Parse a type constraint
|
||||
|
||||
Returns:
|
||||
ConstraintExpr: the parsed type constraint expression
|
||||
"""
|
||||
|
||||
left: Expr = self.constraint_value()
|
||||
op: Token = self.constraint_operator()
|
||||
right: Expr = self.constraint_value()
|
||||
return ConstraintExpr(left=left, op=op, right=right)
|
||||
|
||||
def constraint_value(self) -> Expr:
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
return WildcardExpr(self.previous())
|
||||
return self.literal()
|
||||
|
||||
def literal(self) -> LiteralExpr:
|
||||
if self.match(TokenType.FALSE):
|
||||
return LiteralExpr(False)
|
||||
if self.match(TokenType.TRUE):
|
||||
return LiteralExpr(True)
|
||||
if self.match(TokenType.NONE):
|
||||
return LiteralExpr(None)
|
||||
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(self.previous().value)
|
||||
|
||||
raise self.error(self.peek(), "Expected literal")
|
||||
|
||||
def constraint_operator(self) -> Token:
|
||||
if self.match(
|
||||
TokenType.LESS,
|
||||
TokenType.LESS_EQUAL,
|
||||
TokenType.GREATER,
|
||||
TokenType.GREATER_EQUAL,
|
||||
TokenType.EQUAL_EQUAL,
|
||||
TokenType.BANG_EQUAL,
|
||||
):
|
||||
return self.previous()
|
||||
raise self.error(self.peek(), "Expected constraint operator")
|
||||
|
||||
def type_body_expr(self) -> TypeBodyExpr:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
TypeBodyExpr: the parsed type body expression
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
||||
properties: list[PropertyStmt] = []
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
properties.append(self.property_stmt())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return TypeBodyExpr(properties=properties)
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
A type property statement is written `name: Type`
|
||||
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
return PropertyStmt(name=name, type=type)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
"""Parse an operation definition
|
||||
|
||||
An operation is written `op <Type1> operator <Type2> = <Type3>` where `operator` can be any single token
|
||||
|
||||
Returns:
|
||||
OpStmt: the parsed operation statement
|
||||
"""
|
||||
self.consume(TokenType.LESS, "Expected '<' before first type")
|
||||
left: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.GREATER, "Expected '>' after first type")
|
||||
|
||||
op: Token = self.advance()
|
||||
|
||||
self.consume(TokenType.LESS, "Expected '<' before second type")
|
||||
right: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.GREATER, "Expected '>' after second type")
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after second type")
|
||||
|
||||
self.consume(TokenType.LESS, "Expected '<' before result type")
|
||||
result: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.GREATER, "Expected '>' after result type")
|
||||
|
||||
return OpStmt(left=left, op=op, right=right, result=result)
|
||||
|
||||
def constraint_declaration(self) -> ConstraintStmt:
|
||||
"""Parse a type constraint declaration
|
||||
|
||||
A constraint is written `constraint Name = constraint_expression`
|
||||
|
||||
Returns:
|
||||
ConstraintStmt: the parsed constraint declaration statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected constraint name")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after constraint name")
|
||||
constraint: ConstraintExpr = self.constraint_expr()
|
||||
return ConstraintStmt(name=name, constraint=constraint)
|
||||
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[project]
|
||||
name = "midas"
|
||||
version = "0.1.0"
|
||||
description = "A static-first type checking framework for Python data-frames"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
||||
]
|
||||
classifiers = ["Programming Language :: Python :: 3"]
|
||||
dependencies = ["click>=8.4.1"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://git.kbk28.ch/HEL/midas"
|
||||
Repository = "https://git.kbk28.ch/HEL/midas"
|
||||
|
||||
[project.scripts]
|
||||
midas = "midas.cli.main:midas"
|
||||
|
||||
[build-system]
|
||||
requires = ['hatchling']
|
||||
build-backend = 'hatchling.build'
|
||||
@@ -1,26 +1,35 @@
|
||||
identifier ::= '[a-zA-Z][a-zA-Z_]*'
|
||||
// W3C EBNF syntax definition for Midas
|
||||
Identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
||||
|
||||
integer ::= '\d+'
|
||||
number ::= integer ["." integer]
|
||||
boolean ::= "False" | "True"
|
||||
none ::= "None"
|
||||
Integer ::= '\d+'
|
||||
Number ::= "-"? Integer ("." Integer)?
|
||||
Boolean ::= "False" | "True"
|
||||
None ::= "None"
|
||||
|
||||
value ::= number | boolean | none
|
||||
lambda-value ::= "_" | value
|
||||
lambda-operator ::= ">" | "<" | ">=" | "<=" | "==" | "!="
|
||||
lambda ::= lambda-value lambda-operator lambda-value
|
||||
Value ::= Number | Boolean | None
|
||||
|
||||
constraint ::= identifier | "(" lambda ")"
|
||||
base-type ::= identifier
|
||||
type ::= base-type { "+" constraint }
|
||||
ComparisonOp ::= ">" | "<" | ">=" | "<="
|
||||
EqualityOp ::= "==" | "!="
|
||||
|
||||
type-property ::= 'identifier' ":" 'type'
|
||||
type-body ::= "{" { 'type-property' } "}"
|
||||
Grouping ::= "(" Constraint ")"
|
||||
Primary ::= "_" | Value | Identifier | Grouping
|
||||
Reference ::= Primary ("." Identifier)*
|
||||
Unary ::= "-"? Unary | Reference
|
||||
Comparison ::= Unary (ComparisonOp Unary)*
|
||||
Equality ::= Comparison (EqualityOp Comparison)*
|
||||
Constraint ::= Equality ("&" Equality)*
|
||||
|
||||
operation-type ::= "<" 'type' ">"
|
||||
SimpleType ::= Identifier "?"?
|
||||
Template ::= "[" Type "]"
|
||||
Type ::= Identifier Template? "?"?
|
||||
|
||||
type-statement ::= "type" 'identifier' "<" 'type' {"," 'type'} ">" ['type-body']
|
||||
operation-statement ::= "op" 'operation-type' 'operator' 'operation-type' "=" 'operation-type'
|
||||
constraint-statement ::= "constraint" 'identifier' "=" 'lambda'
|
||||
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
|
||||
ComplexTypeBody ::= "{" TypeProperty* "}"
|
||||
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
||||
ExtendBody ::= "{" OpDefinition* "}"
|
||||
|
||||
statement ::= type-statement | operation-statement | constraint-statement
|
||||
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
|
||||
ExtendStatement ::= "extend" Type ExtendBody
|
||||
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
||||
|
||||
Statement ::= TypeStatement | ExtendStatement | PredicateStatement
|
||||
|
||||
172
syntax/midas.typ
172
syntax/midas.typ
@@ -1,4 +1,11 @@
|
||||
#import "@preview/fervojo:0.1.1": render
|
||||
#import "@preview/fervojo:0.1.1": default-css, render
|
||||
|
||||
#let extra-css = ```css
|
||||
svg.railroad .terminal rect {
|
||||
fill: #F7DCD4;
|
||||
}
|
||||
```
|
||||
#let css = default-css() + bytes(extra-css.text)
|
||||
|
||||
#let value = ```
|
||||
{[`value` <
|
||||
@@ -8,90 +15,157 @@
|
||||
>]}
|
||||
```
|
||||
|
||||
#let constraint = ```
|
||||
{[`constraint` <"_", 'value'> <">", "<", ">=", "<=", "==", "!="> <"_", 'value'>]}
|
||||
#let grouping = ```
|
||||
{[`grouping` "(" 'constraint' ")"]}
|
||||
```
|
||||
|
||||
#let type-with-constraints = ```
|
||||
{[`type-with-constraints` 'identifier' <!, ["+" "(" 'constraint' ")"] * !>]}
|
||||
#let primary = ```
|
||||
{[`primary` <"_", 'value', 'identifier', 'grouping'>]}
|
||||
```
|
||||
|
||||
#let reference = ```
|
||||
{[`reference` 'primary' <!, ["." 'identifier']*!>]}
|
||||
```
|
||||
|
||||
#let unary = ```
|
||||
{[`unary` <[<!, "-"> 'unary'], 'reference'>]}
|
||||
```
|
||||
|
||||
#let comparison = ```
|
||||
{[`comparison` 'unary'*<">", "<", ">=", "<=">]}
|
||||
```
|
||||
|
||||
#let equality = ```
|
||||
{[`equality` 'comparison'*<"==", "!=">]}
|
||||
```
|
||||
|
||||
#let constraint = ```
|
||||
{[`constraint` 'equality'*"&"]}
|
||||
```
|
||||
|
||||
#let simple-type = ```
|
||||
{[`simple-type` 'identifier' <!, "?">]}
|
||||
```
|
||||
|
||||
#let template = ```
|
||||
{[`template` "[" 'type' "]"]}
|
||||
```
|
||||
|
||||
#let type = ```
|
||||
{[`type` 'identifier' <!, 'template'> <!, "?">]}
|
||||
```
|
||||
|
||||
#let type-property = ```
|
||||
{[`type-property` 'identifier' ":" 'type-with-constraints']}
|
||||
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
|
||||
```
|
||||
|
||||
#let type-body = ```
|
||||
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
||||
```
|
||||
|
||||
#let operation-type = ```
|
||||
{[`operation-type` "<" 'type-with-constraints' ">"]}
|
||||
```
|
||||
|
||||
#let type-statement = ```
|
||||
{[`type-statement` "type" 'identifier' "<" 'type-with-constraints'*"," ">" <!, 'type-body'>]}
|
||||
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
|
||||
```
|
||||
|
||||
#let operation-statement = ```
|
||||
{[`operation-statement` "op" 'operation-type' "operator" 'operation-type' "=" 'operation-type']}
|
||||
#let op-definition = ```
|
||||
{[`op-definition` "op" 'identifier' "(" 'type' ")" "->" 'type']}
|
||||
```
|
||||
|
||||
#let constraint-statement = ```
|
||||
{[`constraint-statement` "constraint" 'identifier' "=" 'constraint']}
|
||||
#let extend-statement = ```
|
||||
{[`extend-statement` "extend" 'type' "{" <!, 'op-definition'*!> "}"]}
|
||||
```
|
||||
|
||||
#let predicate-statement = ```
|
||||
{[`predicate-statement` "predicate" 'identifier' "(" 'identifier' ":" 'type' ")" "=" 'constraint']}
|
||||
```
|
||||
|
||||
#let statement = ```
|
||||
{[`statement` <'type-statement', 'operation-statement', 'constraint-statement'>]}
|
||||
{[`statement` <'type-statement', 'extend-statement', 'predicate-statement'>]}
|
||||
```
|
||||
|
||||
#let rules = (
|
||||
value,
|
||||
constraint,
|
||||
type-with-constraints,
|
||||
type-property,
|
||||
type-body,
|
||||
operation-type,
|
||||
type-statement,
|
||||
operation-statement,
|
||||
constraint-statement,
|
||||
statement,
|
||||
value: value,
|
||||
grouping: grouping,
|
||||
primary: primary,
|
||||
reference: reference,
|
||||
unary: unary,
|
||||
comparison: comparison,
|
||||
equality: equality,
|
||||
constraint: constraint,
|
||||
simple-type: simple-type,
|
||||
template: template,
|
||||
type: type,
|
||||
type-property: type-property,
|
||||
type-body: type-body,
|
||||
type-statement: type-statement,
|
||||
op-definition: op-definition,
|
||||
extend-statement: extend-statement,
|
||||
predicate-statement: predicate-statement,
|
||||
statement: statement,
|
||||
)
|
||||
|
||||
#let inline = (
|
||||
"grouping",
|
||||
"value",
|
||||
"template",
|
||||
"simple-type",
|
||||
"type-property",
|
||||
"type-body",
|
||||
"op-definition",
|
||||
"type-statement",
|
||||
"extend-statement",
|
||||
"predicate-statement",
|
||||
)
|
||||
|
||||
#set text(font: "Source Sans 3")
|
||||
|
||||
= Midas type definition syntax
|
||||
#title[Midas type definition syntax]
|
||||
|
||||
#for rule in rules {
|
||||
render(rule)
|
||||
}
|
||||
= Outline
|
||||
|
||||
/*
|
||||
#let by-name = (
|
||||
value: value,
|
||||
constraint: constraint,
|
||||
type-with-constraints: type-with-constraints,
|
||||
type-property: type-property,
|
||||
type-body: type-body,
|
||||
operation-type: operation-type,
|
||||
type-statement: type-statement,
|
||||
operation-statement: operation-statement,
|
||||
constraint-statement: constraint-statement,
|
||||
#box(
|
||||
columns(
|
||||
2,
|
||||
outline(title: none),
|
||||
),
|
||||
height: 9cm,
|
||||
stroke: 1pt,
|
||||
inset: 1em,
|
||||
)
|
||||
|
||||
= Statements and expressions
|
||||
|
||||
#for (name, rule) in rules.pairs().rev() {
|
||||
[== #name]
|
||||
render(rule, css: css)
|
||||
}
|
||||
|
||||
#let substitute(base-rule) = {
|
||||
let new-rule = base-rule
|
||||
for (key, rule) in by-name.pairs() {
|
||||
new-rule = new-rule.replace("'" + key + "'", rule.text.slice(1, -1))
|
||||
for name in inline {
|
||||
let rule = rules.at(name)
|
||||
let replacement = rule.text.slice(1, -1).replace(regex("\[`.*?`"), "[")
|
||||
replacement = "[" + replacement + "#`" + name + "`]"
|
||||
new-rule = new-rule.replace(
|
||||
"'" + name + "'",
|
||||
replacement,
|
||||
)
|
||||
}
|
||||
if new-rule != base-rule {
|
||||
new-rule = substitute(new-rule)
|
||||
}
|
||||
return new-rule.replace(regex("`.*?`"), "")
|
||||
return new-rule
|
||||
}
|
||||
|
||||
#let combined = raw(substitute(statement.text))
|
||||
|
||||
|
||||
#set page(flipped: true)
|
||||
#render(combined)
|
||||
*/
|
||||
|
||||
= Combined rules
|
||||
|
||||
#for (name, rule) in rules.pairs() {
|
||||
if not name in inline {
|
||||
[== #name]
|
||||
let combined = substitute(rule.text)
|
||||
render(raw(combined), css: css)
|
||||
//raw(block: true, combined)
|
||||
}
|
||||
}
|
||||
|
||||
35
test.py
35
test.py
@@ -1,40 +1,21 @@
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from core.ast.printer import AnnotationAstPrinter, MidasAstPrinter
|
||||
from lexer.annotations import AnnotationLexer
|
||||
from lexer.midas import MidasLexer
|
||||
from lexer.token import Token
|
||||
from parser.annotations import AnnotationParser
|
||||
from parser.midas import MidasParser
|
||||
|
||||
|
||||
def test_annotation():
|
||||
# Frame annotation
|
||||
mod = importlib.import_module("examples.00_syntax_prototype.01_simple_types")
|
||||
|
||||
annotation: str = mod.__annotations__["df"]
|
||||
lexer: AnnotationLexer = AnnotationLexer(annotation, "01_simple_types.py")
|
||||
tokens: list[Token] = lexer.process()
|
||||
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
||||
|
||||
parser = AnnotationParser(tokens)
|
||||
parsed = parser.parse()
|
||||
print(parsed)
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
printer = AnnotationAstPrinter()
|
||||
if parsed is not None:
|
||||
print(printer.print(parsed))
|
||||
from midas.ast.printer import MidasAstPrinter
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
def test_midas():
|
||||
# Midas type definitions
|
||||
path: Path = Path("examples") / "00_syntax_prototype" / "02_custom_types.midas"
|
||||
path: Path = Path("examples") / "00_syntax_prototype" / "03_custom_types_v2.midas"
|
||||
definitions: str = path.read_text()
|
||||
midas_lexer: MidasLexer = MidasLexer(definitions, path.name)
|
||||
tokens: list[Token] = midas_lexer.process()
|
||||
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
|
||||
with open("tokens.json", "w") as f:
|
||||
json.dump([f"{t.type.name}('{t.lexeme}')" for t in tokens], f, indent=4)
|
||||
|
||||
parser = MidasParser(tokens)
|
||||
parsed = parser.parse()
|
||||
|
||||
143
tests/base.py
Normal file
143
tests/base.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import difflib
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Protocol
|
||||
|
||||
|
||||
class CaseResult(Protocol):
|
||||
def dumps(self) -> str: ...
|
||||
|
||||
|
||||
class Tester(ABC):
|
||||
"""A test runner to check for regressions in the lexer and parser"""
|
||||
|
||||
CASES_DIR: Path = Path(__file__).parent / "cases"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def namespace(self) -> str: ...
|
||||
|
||||
@property
|
||||
def base_dir(self) -> Path:
|
||||
return self.CASES_DIR / self.namespace
|
||||
|
||||
@abstractmethod
|
||||
def _list_tests(self) -> list[Path]: ...
|
||||
|
||||
def run_all_tests(self) -> bool:
|
||||
paths: list[Path] = self._list_tests()
|
||||
return self.run_tests(paths)
|
||||
|
||||
def run_tests(self, tests: list[Path]) -> bool:
|
||||
rule: str = "-" * 80
|
||||
n: int = len(tests)
|
||||
successes: int = 0
|
||||
failures: int = 0
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
||||
success: bool = self._run_test(test)
|
||||
if success:
|
||||
successes += 1
|
||||
else:
|
||||
failures += 1
|
||||
|
||||
print(rule)
|
||||
print(f"Success: {successes}/{n}")
|
||||
print(f"Failed: {failures}/{n}")
|
||||
print(rule)
|
||||
return failures == 0
|
||||
|
||||
def _run_test(self, path: Path) -> bool:
|
||||
result_path: Path = self._result_path(path)
|
||||
if not result_path.exists():
|
||||
print("Missing snapshot. Please run the update command first")
|
||||
return False
|
||||
result: CaseResult = self._exec_case(path)
|
||||
expected: str = result_path.read_text()
|
||||
actual: str = result.dumps()
|
||||
|
||||
if expected == actual:
|
||||
return True
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
expected.splitlines(keepends=True),
|
||||
actual.splitlines(keepends=True),
|
||||
fromfile="Snapshot",
|
||||
tofile="Result",
|
||||
)
|
||||
self._print_diff(diff)
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||
|
||||
def update_all_tests(self):
|
||||
paths: list[Path] = self._list_tests()
|
||||
return self.update_tests(paths)
|
||||
|
||||
def update_tests(self, tests: list[Path]):
|
||||
updated: int = 0
|
||||
for test in tests:
|
||||
if self._update_test(test):
|
||||
updated += 1
|
||||
print(f"Updated {updated}/{len(tests)} tests")
|
||||
|
||||
def _update_test(self, path: Path) -> bool:
|
||||
result: CaseResult = self._exec_case(path)
|
||||
result_path: Path = self._result_path(path)
|
||||
current: str = result_path.read_text() if result_path.exists() else ""
|
||||
new: str = result.dumps()
|
||||
if current == new:
|
||||
return False
|
||||
result_path.write_text(new)
|
||||
return True
|
||||
|
||||
def _result_path(self, test_path: Path) -> Path:
|
||||
return test_path.parent / (test_path.name + ".ref.json")
|
||||
|
||||
def _print_diff(self, diff: Iterator[str]):
|
||||
for line in diff:
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
print(f"\033[92m{line}\033[0m", end="")
|
||||
elif line.startswith("-") and not line.startswith("---"):
|
||||
print(f"\033[91m{line}\033[0m", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
print()
|
||||
|
||||
@classmethod
|
||||
def main(cls):
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="subcommand")
|
||||
|
||||
update = subparsers.add_parser("update")
|
||||
update.add_argument("-a", "--all", action="store_true")
|
||||
update.add_argument("FILE", type=Path, nargs="*")
|
||||
|
||||
run = subparsers.add_parser("run")
|
||||
run.add_argument("-a", "--all", action="store_true")
|
||||
run.add_argument("FILE", type=Path, nargs="*")
|
||||
args = parser.parse_args()
|
||||
|
||||
tester: Tester = cls()
|
||||
|
||||
match args.subcommand:
|
||||
case "update":
|
||||
if args.all:
|
||||
tester.update_all_tests()
|
||||
else:
|
||||
tester.update_tests(args.FILE)
|
||||
case "run":
|
||||
success: bool
|
||||
if args.all:
|
||||
success = tester.run_all_tests()
|
||||
else:
|
||||
success = tester.run_tests(args.FILE)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
14
tests/cases/checker/01_simple_types.py
Normal file
14
tests/cases/checker/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float,
|
||||
unknown: _,
|
||||
_
|
||||
]
|
||||
3
tests/cases/checker/01_simple_types.py.ref.json
Normal file
3
tests/cases/checker/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
}
|
||||
11
tests/cases/checker/02_simple_operations.py
Normal file
11
tests/cases/checker/02_simple_operations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
a: int = 3
|
||||
b: int = 4
|
||||
|
||||
c = a + b
|
||||
|
||||
c = "invalid"
|
||||
|
||||
d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
46
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
46
tests/cases/checker/02_simple_operations.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
12
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
||||
}
|
||||
]
|
||||
}
|
||||
18
tests/cases/checker/03_functions.py
Normal file
18
tests/cases/checker/03_functions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def foo(a: int, /, b: float, *, c: str):
|
||||
return True
|
||||
|
||||
|
||||
r1 = foo()
|
||||
r2 = foo(1)
|
||||
r3 = foo(1, 2.0)
|
||||
r4 = foo(1, b=2.0)
|
||||
r5 = foo(1, 2.0, "test")
|
||||
r6 = foo(1, 2.0, b=3.0)
|
||||
r7 = foo(a=1)
|
||||
r8 = foo(g="test")
|
||||
|
||||
r9a = foo(1, 2.0, c="test")
|
||||
r9b = foo(1, b=2.0, c="test")
|
||||
r9c = foo(1, c="test", b=2.0)
|
||||
|
||||
r10 = foo("a", 3, c=False)
|
||||
270
tests/cases/checker/03_functions.py.ref.json
Normal file
270
tests/cases/checker/03_functions.py.ref.json
Normal file
@@ -0,0 +1,270 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
5,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
5,
|
||||
10
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
5,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
5,
|
||||
10
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
11
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional argument: 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
6,
|
||||
11
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
7,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
7,
|
||||
16
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
8,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
8,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
17
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
23
|
||||
]
|
||||
},
|
||||
"message": "Too many positional arguments"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
9,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
9,
|
||||
24
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
10,
|
||||
19
|
||||
],
|
||||
"end": [
|
||||
10,
|
||||
22
|
||||
]
|
||||
},
|
||||
"message": "Multiple values for argument 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
10,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
10,
|
||||
23
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
11
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
12
|
||||
]
|
||||
},
|
||||
"message": "Unknown keyword argument 'a'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
11
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
17
|
||||
]
|
||||
},
|
||||
"message": "Unknown keyword argument 'g'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Missing required positional arguments: 'a' and 'b'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
5
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
18
|
||||
]
|
||||
},
|
||||
"message": "Missing required keyword argument: 'c'"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
18,
|
||||
10
|
||||
],
|
||||
"end": [
|
||||
18,
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Wrong type for argument 'a', expected BaseType(name='int'), got BaseType(name='str')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
18,
|
||||
15
|
||||
],
|
||||
"end": [
|
||||
18,
|
||||
16
|
||||
]
|
||||
},
|
||||
"message": "Wrong type for argument 'b', expected BaseType(name='float'), got BaseType(name='int')"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
18,
|
||||
20
|
||||
],
|
||||
"end": [
|
||||
18,
|
||||
25
|
||||
]
|
||||
},
|
||||
"message": "Wrong type for argument 'c', expected BaseType(name='str'), got BaseType(name='bool')"
|
||||
}
|
||||
]
|
||||
}
|
||||
14
tests/cases/checker/04_custom_types.midas
Normal file
14
tests/cases/checker/04_custom_types.midas
Normal file
@@ -0,0 +1,14 @@
|
||||
type Meter(float)
|
||||
type Second(float)
|
||||
type MeterPerSecond(float)
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
}
|
||||
8
tests/cases/checker/04_custom_types.py
Normal file
8
tests/cases/checker/04_custom_types.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
midas.using("04_custom_types.midas")
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
3
tests/cases/checker/04_custom_types.py.ref.json
Normal file
3
tests/cases/checker/04_custom_types.py.ref.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
}
|
||||
25
tests/cases/checker/05_control_flow.py
Normal file
25
tests/cases/checker/05_control_flow.py
Normal file
@@ -0,0 +1,25 @@
|
||||
def valid(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
def with_if(a: int, b: int) -> int:
|
||||
if a < b:
|
||||
return b - a
|
||||
else:
|
||||
return a - b
|
||||
|
||||
def unreachable1():
|
||||
return
|
||||
a = 0
|
||||
|
||||
def unreachable2(a: int) -> int:
|
||||
if a > 10:
|
||||
return a - 10
|
||||
else:
|
||||
return a
|
||||
b = 0
|
||||
|
||||
def mixed(a: int, b: int):
|
||||
if a < b:
|
||||
return b - a
|
||||
else:
|
||||
return "oops"
|
||||
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
46
tests/cases/checker/05_control_flow.py.ref.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
12,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
12,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Unreachable statement"
|
||||
},
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
19,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
19,
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Unreachable statement"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
21,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
25,
|
||||
21
|
||||
]
|
||||
},
|
||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||
}
|
||||
]
|
||||
}
|
||||
57
tests/cases/midas-parser/01_simple_types.midas
Normal file
57
tests/cases/midas-parser/01_simple_types.midas
Normal file
@@ -0,0 +1,57 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom(float)
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude(float) where (-90 <= _ <= 90)
|
||||
type Longitude(float) where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||
type Difference[T](T)
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
}
|
||||
|
||||
// Define operations on our custom type
|
||||
extend GeoLocation {
|
||||
// This type is compatible with the `-` operation with another GeoLocation
|
||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||
// in a Difference of GeoLocations
|
||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||
}
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
}
|
||||
|
||||
// Simple operation defined on our custom types
|
||||
extend Latitude {
|
||||
op __sub__(Latitude) -> Difference[Latitude]
|
||||
}
|
||||
|
||||
extend Longitude {
|
||||
op __sub__(Longitude) -> Difference[Longitude]
|
||||
}
|
||||
|
||||
// Predefined custom predicates that can be referenced in other definitions
|
||||
predicate Positive(v: float) = v >= 0
|
||||
predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person {
|
||||
name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: int? where (0 <= _ < 150)
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
|
||||
home: GeoLocation
|
||||
}
|
||||
2659
tests/cases/midas-parser/01_simple_types.midas.ref.json
Normal file
2659
tests/cases/midas-parser/01_simple_types.midas.ref.json
Normal file
File diff suppressed because it is too large
Load Diff
14
tests/cases/python-parser/01_simple_types.py
Normal file
14
tests/cases/python-parser/01_simple_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
df: Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float,
|
||||
unknown: _,
|
||||
_
|
||||
]
|
||||
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
85
tests/cases/python-parser/01_simple_types.py.ref.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "verified",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "bool",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "birth_year",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "(_ > 0) + (_ < 250)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "name",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "date",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "datetime",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "unknown",
|
||||
"type": null
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": null,
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "_",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
29
tests/cases/python-parser/02_custom_types.py
Normal file
29
tests/cases/python-parser/02_custom_types.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
import midas
|
||||
|
||||
midas.using("02_custom_types.midas")
|
||||
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
]
|
||||
|
||||
lat: Column[GeoLocation] = df["location"].lat
|
||||
lon: Column[GeoLocation] = df["location"].lon
|
||||
|
||||
lat + lon
|
||||
|
||||
lat1: Latitude = lat[0]
|
||||
lat2: Latitude = lat[1]
|
||||
lat_diff: Difference[Latitude] = lat2 - lat1
|
||||
|
||||
df2: Frame[
|
||||
age: int + (_ >= 0),
|
||||
height: float + (_ >= 0),
|
||||
]
|
||||
df2_bis: Frame[
|
||||
age: int + Positive,
|
||||
height: float + Positive,
|
||||
]
|
||||
162
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
162
tests/cases/python-parser/02_custom_types.py.ref.json
Normal file
@@ -0,0 +1,162 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": {
|
||||
"_type": "CallExpr",
|
||||
"callee": {
|
||||
"_type": "GetExpr",
|
||||
"object": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "midas"
|
||||
},
|
||||
"name": "using"
|
||||
},
|
||||
"arguments": [
|
||||
{
|
||||
"_type": "LiteralExpr",
|
||||
"value": "02_custom_types.midas"
|
||||
}
|
||||
],
|
||||
"keywords": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "location",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lon"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat_diff",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Difference",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat_diff"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat2"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df2",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "age",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "_ >= 0"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df2_bis",
|
||||
"type": {
|
||||
"_type": "FrameType",
|
||||
"columns": [
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "age",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "FrameColumn",
|
||||
"name": "height",
|
||||
"type": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "Positive"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
15
tests/cases/python-parser/03_functions.py
Normal file
15
tests/cases/python-parser/03_functions.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# type: ignore
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def func(
|
||||
col1: Column[float + (0 <= _ <= 1)],
|
||||
col2: Column[float + (0 <= _ <= 1)],
|
||||
) -> Column[float + (0 <= _ <= 2)]:
|
||||
result: Column[float + (0 <= _ <= 2)] = col1 + col2
|
||||
return result
|
||||
|
||||
|
||||
def func2(a: int, /, b: float, *, c: str):
|
||||
pass
|
||||
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
149
tests/cases/python-parser/03_functions.py.ref.json
Normal file
@@ -0,0 +1,149 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func",
|
||||
"posonlyargs": [],
|
||||
"args": [
|
||||
{
|
||||
"name": "col1",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
},
|
||||
{
|
||||
"name": "col2",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 1"
|
||||
}
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [],
|
||||
"kw_sink": null,
|
||||
"returns": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
},
|
||||
"body": [
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "result",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "ConstraintType",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"constraint": "0 <= _ <= 2"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "result"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "col1"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "col2"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ReturnStmt",
|
||||
"value": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "result"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"_type": "Function",
|
||||
"name": "func2",
|
||||
"posonlyargs": [
|
||||
{
|
||||
"name": "a",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "int",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"args": [
|
||||
{
|
||||
"name": "b",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "float",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"sink": null,
|
||||
"kwonlyargs": [
|
||||
{
|
||||
"name": "c",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "str",
|
||||
"param": null
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
],
|
||||
"kw_sink": null,
|
||||
"returns": null,
|
||||
"body": []
|
||||
}
|
||||
]
|
||||
}
|
||||
67
tests/checker.py
Normal file
67
tests/checker.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from tests.base import Tester
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
diagnostics: list[dict] = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class CheckerTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "checker"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.py"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
source: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
result: CaseResult = CaseResult()
|
||||
checker = Checker(resolver.locals, file_path=path)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
result.diagnostics.append(
|
||||
{
|
||||
"type": str(diagnostic.type),
|
||||
"location": {
|
||||
"start": (
|
||||
diagnostic.location.lineno,
|
||||
diagnostic.location.col_offset,
|
||||
),
|
||||
"end": (
|
||||
diagnostic.location.end_lineno,
|
||||
diagnostic.location.end_col_offset,
|
||||
),
|
||||
},
|
||||
"message": diagnostic.message,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CheckerTester.main()
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from lexer.annotations import AnnotationLexer
|
||||
from lexer.token import Token, TokenType
|
||||
|
||||
|
||||
def scan(source: str) -> list[Token]:
|
||||
return AnnotationLexer(source).process()
|
||||
|
||||
|
||||
def assert_n_tokens(tokens: list[Token], n: int):
|
||||
assert len(tokens) == n + 1
|
||||
assert tokens[-1].type == TokenType.EOF
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("(", TokenType.LEFT_PAREN),
|
||||
(")", TokenType.RIGHT_PAREN),
|
||||
("[", TokenType.LEFT_BRACKET),
|
||||
("]", TokenType.RIGHT_BRACKET),
|
||||
(":", TokenType.COLON),
|
||||
(",", TokenType.COMMA),
|
||||
("_", TokenType.UNDERSCORE),
|
||||
],
|
||||
)
|
||||
def test_punctuation(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("+", TokenType.PLUS),
|
||||
(">", TokenType.GREATER),
|
||||
(">=", TokenType.GREATER_EQUAL),
|
||||
("<", TokenType.LESS),
|
||||
("<=", TokenType.LESS_EQUAL),
|
||||
("=", TokenType.EQUAL),
|
||||
("==", TokenType.EQUAL_EQUAL),
|
||||
("!=", TokenType.BANG_EQUAL),
|
||||
],
|
||||
)
|
||||
def test_operators(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("a", TokenType.IDENTIFIER),
|
||||
("foo", TokenType.IDENTIFIER),
|
||||
("foo1", TokenType.IDENTIFIER),
|
||||
("foo_", TokenType.IDENTIFIER),
|
||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
||||
("True", TokenType.TRUE),
|
||||
("False", TokenType.FALSE),
|
||||
("None", TokenType.NONE),
|
||||
],
|
||||
)
|
||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("#", TokenType.COMMENT),
|
||||
("# This is a comment", TokenType.COMMENT),
|
||||
(" ", TokenType.WHITESPACE),
|
||||
("\t", TokenType.WHITESPACE),
|
||||
("\r", TokenType.WHITESPACE),
|
||||
(" \t \t", TokenType.WHITESPACE),
|
||||
("\n", TokenType.NEWLINE),
|
||||
],
|
||||
)
|
||||
def test_misc(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected_type,expected_value",
|
||||
[
|
||||
("0", TokenType.NUMBER, 0),
|
||||
("0.0", TokenType.NUMBER, 0),
|
||||
("1234.56", TokenType.NUMBER, 1234.56),
|
||||
],
|
||||
)
|
||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected_type
|
||||
assert tokens[0].value == expected_value
|
||||
|
||||
|
||||
def test_single_bang_error():
|
||||
with pytest.raises(SyntaxError):
|
||||
scan("!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src",
|
||||
[
|
||||
"-",
|
||||
"*",
|
||||
"/",
|
||||
"{",
|
||||
"}",
|
||||
"@",
|
||||
'"',
|
||||
"'",
|
||||
".",
|
||||
],
|
||||
)
|
||||
def test_unexpected_character(src: str):
|
||||
with pytest.raises(SyntaxError):
|
||||
scan(src)
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from lexer.midas import MidasLexer
|
||||
from lexer.token import Token, TokenType
|
||||
|
||||
|
||||
def scan(source: str) -> list[Token]:
|
||||
return MidasLexer(source).process()
|
||||
|
||||
|
||||
def assert_n_tokens(tokens: list[Token], n: int):
|
||||
assert len(tokens) == n + 1
|
||||
assert tokens[-1].type == TokenType.EOF
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("(", TokenType.LEFT_PAREN),
|
||||
(")", TokenType.RIGHT_PAREN),
|
||||
("[", TokenType.LEFT_BRACKET),
|
||||
("]", TokenType.RIGHT_BRACKET),
|
||||
("{", TokenType.LEFT_BRACE),
|
||||
("}", TokenType.RIGHT_BRACE),
|
||||
(":", TokenType.COLON),
|
||||
(",", TokenType.COMMA),
|
||||
("_", TokenType.UNDERSCORE),
|
||||
],
|
||||
)
|
||||
def test_punctuation(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("+", TokenType.PLUS),
|
||||
("-", TokenType.MINUS),
|
||||
("*", TokenType.STAR),
|
||||
("/", TokenType.SLASH),
|
||||
(">", TokenType.GREATER),
|
||||
(">=", TokenType.GREATER_EQUAL),
|
||||
("<", TokenType.LESS),
|
||||
("<=", TokenType.LESS_EQUAL),
|
||||
("=", TokenType.EQUAL),
|
||||
("==", TokenType.EQUAL_EQUAL),
|
||||
("!=", TokenType.BANG_EQUAL),
|
||||
],
|
||||
)
|
||||
def test_operators(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("a", TokenType.IDENTIFIER),
|
||||
("foo", TokenType.IDENTIFIER),
|
||||
("foo1", TokenType.IDENTIFIER),
|
||||
("foo_", TokenType.IDENTIFIER),
|
||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
||||
("true", TokenType.TRUE),
|
||||
("false", TokenType.FALSE),
|
||||
("none", TokenType.NONE),
|
||||
],
|
||||
)
|
||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("// This is a comment", TokenType.COMMENT),
|
||||
("/* This is a comment */", TokenType.COMMENT),
|
||||
(" ", TokenType.WHITESPACE),
|
||||
("\t", TokenType.WHITESPACE),
|
||||
("\r", TokenType.WHITESPACE),
|
||||
(" \t \t", TokenType.WHITESPACE),
|
||||
("\n", TokenType.NEWLINE),
|
||||
],
|
||||
)
|
||||
def test_misc(src: str, expected: TokenType):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected_type,expected_value",
|
||||
[
|
||||
("0", TokenType.NUMBER, 0),
|
||||
("0.0", TokenType.NUMBER, 0),
|
||||
("1234.56", TokenType.NUMBER, 1234.56),
|
||||
],
|
||||
)
|
||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
||||
tokens: list[Token] = scan(src)
|
||||
assert_n_tokens(tokens, 1)
|
||||
assert tokens[0].type == expected_type
|
||||
assert tokens[0].value == expected_value
|
||||
|
||||
|
||||
def test_single_bang_error():
|
||||
with pytest.raises(SyntaxError):
|
||||
scan("!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src",
|
||||
[
|
||||
"@",
|
||||
'"',
|
||||
"'",
|
||||
".",
|
||||
],
|
||||
)
|
||||
def test_unexpected_character(src: str):
|
||||
with pytest.raises(SyntaxError):
|
||||
scan(src)
|
||||
82
tests/midas.py
Normal file
82
tests/midas.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.midas import Stmt
|
||||
from midas.lexer.base import MidasSyntaxError
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from tests.base import Tester
|
||||
from tests.serializer.midas import MidasAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
tokens: Optional[list[dict]] = None
|
||||
stmts: Optional[list[dict]] = None
|
||||
errors: list[dict] = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class MidasTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "midas-parser"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.midas"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
lexer: MidasLexer = MidasLexer(content)
|
||||
tokens: list[Token] = []
|
||||
try:
|
||||
tokens = lexer.process()
|
||||
result.tokens = [
|
||||
{
|
||||
"type": token.type.name,
|
||||
"lexeme": token.lexeme,
|
||||
"line": token.position.line,
|
||||
"column": token.position.column,
|
||||
}
|
||||
for token in tokens
|
||||
]
|
||||
except MidasSyntaxError as e:
|
||||
result.errors.append(
|
||||
{
|
||||
"type": "SyntaxError",
|
||||
"line": e.pos.line,
|
||||
"column": e.pos.column,
|
||||
"message": e.message,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[Stmt] = parser.parse()
|
||||
result.stmts = MidasAstJsonSerializer().serialize(stmts)
|
||||
result.errors.extend(
|
||||
[
|
||||
{
|
||||
"line": e.token.position.line,
|
||||
"column": e.token.position.column,
|
||||
"message": e.message,
|
||||
}
|
||||
for e in parser.errors
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MidasTester.main()
|
||||
@@ -1,130 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ast.annotations import (
|
||||
AnnotationStmt,
|
||||
ConstraintExpr,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
SchemaElementExpr,
|
||||
SchemaExpr,
|
||||
Stmt,
|
||||
TypeExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.annotations import AnnotationLexer
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token
|
||||
from parser.annotations import AnnotationParser
|
||||
|
||||
|
||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
||||
def serialize(self, stmt: Stmt):
|
||||
return stmt.accept(self)
|
||||
|
||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str:
|
||||
schema: str = ""
|
||||
if stmt.schema is not None:
|
||||
schema = " " + stmt.schema.accept(self)
|
||||
return f"(annotation {stmt.name.lexeme}{schema})"
|
||||
|
||||
def visit_schema_expr(self, expr: SchemaExpr) -> str:
|
||||
elements: list[str] = [elmt.accept(self) for elmt in expr.elements]
|
||||
return f"(schema {' '.join(elements)})"
|
||||
|
||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
|
||||
name: str = expr.name.lexeme if expr.name is not None else "_"
|
||||
type: str = expr.type.accept(self) if expr.type is not None else "_"
|
||||
return f"({name} {type})"
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
||||
res: str = f"({expr.name.lexeme}"
|
||||
for constraint in expr.constraints:
|
||||
res += " " + constraint.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
||||
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
||||
return "(_)"
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
||||
return f"({expr.value})"
|
||||
|
||||
|
||||
def parse(source: str) -> Optional[Stmt]:
|
||||
tokens: list[Token] = AnnotationLexer(source).process()
|
||||
return AnnotationParser(tokens).parse()
|
||||
|
||||
|
||||
def must_parse(source: str) -> Stmt:
|
||||
stmt: Optional[Stmt] = parse(source)
|
||||
assert stmt is not None
|
||||
return stmt
|
||||
|
||||
|
||||
def ast_str(source: str) -> str:
|
||||
stmt: Stmt = must_parse(source)
|
||||
return AstSerializer().serialize(stmt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("Type", "(annotation Type)"),
|
||||
("Type[]", "(annotation Type (schema ))"),
|
||||
(
|
||||
"""
|
||||
Frame[
|
||||
verified: bool,
|
||||
birth_year: int,
|
||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
||||
name: str,
|
||||
date: datetime,
|
||||
float, # unnamed
|
||||
unknown: _, # untyped
|
||||
_ # unnamed and untyped
|
||||
]
|
||||
""",
|
||||
"(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_expressions(src: str, expected: str):
|
||||
assert ast_str(src) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,pos,should_fail",
|
||||
[
|
||||
("", (1, 1), True),
|
||||
("42", (1, 1), True),
|
||||
("True", (1, 1), True),
|
||||
("Type[", (1, 6), True),
|
||||
("Type[] Type2", (1, 8), False),
|
||||
("Type[bool:]", (1, 11), True),
|
||||
("Type[3]", (1, 6), True),
|
||||
("Type[bool float]", (1, 11), True),
|
||||
("Type[bool (_ < 2)]", (1, 11), True),
|
||||
("Type[bool + _ < 2)]", (1, 13), True),
|
||||
("Type[bool + (_ < 2]", (1, 19), True),
|
||||
("Type[bool + (< 2)]", (1, 14), True),
|
||||
("Type[bool + (_ + 2)]", (1, 16), True),
|
||||
("Type[bool + (Foo + Bar)]", (1, 14), True),
|
||||
# ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF
|
||||
("Type[bool, Type[]]", (1, 16), True),
|
||||
("Type[foo: 3]", (1, 11), True),
|
||||
],
|
||||
)
|
||||
def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool):
|
||||
tokens: list[Token] = AnnotationLexer(src).process()
|
||||
parser: AnnotationParser = AnnotationParser(tokens)
|
||||
stmt: Optional[Stmt] = parser.parse()
|
||||
if should_fail:
|
||||
assert stmt is None
|
||||
assert len(parser.errors) != 0
|
||||
error_pos: Position = parser.errors[0].token.position
|
||||
assert (error_pos.line, error_pos.column) == pos
|
||||
@@ -1,202 +0,0 @@
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ast.midas import (
|
||||
ConstraintExpr,
|
||||
ConstraintStmt,
|
||||
Expr,
|
||||
LiteralExpr,
|
||||
OpStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
TypeBodyExpr,
|
||||
TypeExpr,
|
||||
TypeStmt,
|
||||
WildcardExpr,
|
||||
)
|
||||
from lexer.midas import MidasLexer
|
||||
from lexer.position import Position
|
||||
from lexer.token import Token
|
||||
from parser.midas import MidasParser
|
||||
|
||||
|
||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
||||
def serialize(self, stmt: Stmt):
|
||||
return stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> str:
|
||||
res: str = f"(type_def {stmt.name.lexeme}"
|
||||
for base in stmt.bases:
|
||||
res += " " + base.accept(self)
|
||||
if stmt.body is not None:
|
||||
res += " " + stmt.body.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
||||
res: str = f"({expr.name.lexeme}"
|
||||
for constraint in expr.constraints:
|
||||
res += " " + constraint.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
||||
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
||||
return "(_)"
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
||||
return f"({expr.value})"
|
||||
|
||||
def visit_type_body_expr(self, expr: TypeBodyExpr) -> str:
|
||||
res: str = "(body"
|
||||
for prop in expr.properties:
|
||||
res += " " + prop.accept(self)
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> str:
|
||||
return f"(property {stmt.name.lexeme} {stmt.type.accept(self)})"
|
||||
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> str:
|
||||
left: str = stmt.left.accept(self)
|
||||
right: str = stmt.right.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return f"(op_def {left} {stmt.op.lexeme} {right} {result})"
|
||||
|
||||
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> str:
|
||||
return f"(constraint_def {stmt.name.lexeme} {stmt.constraint.accept(self)})"
|
||||
|
||||
|
||||
def parse(source: str) -> list[Stmt]:
|
||||
tokens: list[Token] = MidasLexer(source).process()
|
||||
return MidasParser(tokens).parse()
|
||||
|
||||
|
||||
def ast_str(source: str) -> list[str]:
|
||||
stmts: list[Stmt] = parse(source)
|
||||
return [AstSerializer().serialize(stmt) for stmt in stmts]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,expected",
|
||||
[
|
||||
("type Foo<>", "(type_def Foo)"),
|
||||
("type Foo<Bar>", "(type_def Foo (Bar))"),
|
||||
("type Foo<Bar, Baz>", "(type_def Foo (Bar) (Baz))"),
|
||||
(
|
||||
"type Foo<Bar + (_ < 2), Baz>",
|
||||
"(type_def Foo (Bar (constraint (_) < (2.0))) (Baz))",
|
||||
),
|
||||
(
|
||||
"""
|
||||
type Foo<> {
|
||||
foo: Bar
|
||||
}
|
||||
""",
|
||||
"(type_def Foo (body (property foo (Bar))))",
|
||||
),
|
||||
(
|
||||
"""
|
||||
type Foo<> {
|
||||
foo: Bar + (_ != none)
|
||||
foo2: Bar2 + (0 <= _) + (_ <= 100)
|
||||
}
|
||||
""",
|
||||
"(type_def Foo (body (property foo (Bar (constraint (_) != (None)))) (property foo2 (Bar2 (constraint (0.0) <= (_)) (constraint (_) <= (100.0))))))",
|
||||
),
|
||||
("op <A> + <B> = <C>", "(op_def (A) + (B) (C))"),
|
||||
(
|
||||
"op <A + (_ < 100)> + <B + (_ < 100)> = <C + (_ < 200)>",
|
||||
"(op_def (A (constraint (_) < (100.0))) + (B (constraint (_) < (100.0))) (C (constraint (_) < (200.0))))",
|
||||
),
|
||||
(
|
||||
"constraint Positive = _ >= 0",
|
||||
"(constraint_def Positive (constraint (_) >= (0.0)))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_expressions(src: str, expected: str | list[str]):
|
||||
if isinstance(expected, str):
|
||||
expected = [expected]
|
||||
assert ast_str(src) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"src,pos",
|
||||
[
|
||||
###
|
||||
# Misc
|
||||
###
|
||||
("42", (1, 1)),
|
||||
("true", (1, 1)),
|
||||
("foo", (1, 1)),
|
||||
###
|
||||
# Type statements
|
||||
###
|
||||
("type", (1, 5)),
|
||||
("type true", (1, 6)),
|
||||
("type Foo", (1, 9)),
|
||||
("type Foo<1>", (1, 10)),
|
||||
# ("type Foo<float,>", (1, 16)), # trailing comma is accepted, TODO: update parser or EBNF
|
||||
("type Foo<float, 1>", (1, 17)),
|
||||
("type Foo<float", (1, 15)),
|
||||
("type Foo<float> { 3 }", (1, 19)),
|
||||
(
|
||||
"""
|
||||
type Foo<float> {
|
||||
foo
|
||||
}
|
||||
""",
|
||||
(4, 1),
|
||||
),
|
||||
(
|
||||
"""
|
||||
type Foo<float> {
|
||||
foo: 3
|
||||
}
|
||||
""",
|
||||
(3, 10),
|
||||
),
|
||||
###
|
||||
# Operation statements
|
||||
###
|
||||
("op", (1, 3)),
|
||||
("op float", (1, 4)),
|
||||
("op <", (1, 5)),
|
||||
("op <float", (1, 10)),
|
||||
("op <float>", (1, 11)),
|
||||
("op <float> +", (1, 13)),
|
||||
("op <float> + float", (1, 14)),
|
||||
("op <float> + <", (1, 15)),
|
||||
("op <float> + <float", (1, 20)),
|
||||
("op <float> + <float>", (1, 21)),
|
||||
("op <float> + <float> =", (1, 23)),
|
||||
("op <float> + <float> = float", (1, 24)),
|
||||
("op <float> + <float> = <", (1, 25)),
|
||||
("op <float> + <float> = <float", (1, 30)),
|
||||
("op <float + 3> + <float> = <float>", (1, 13)),
|
||||
("op <float> + <float + 3> = <float>", (1, 23)),
|
||||
("op <float> + <float> = <float + 3>", (1, 33)),
|
||||
###
|
||||
# Constraint statements
|
||||
###
|
||||
("constraint", (1, 11)),
|
||||
("constraint 3", (1, 12)),
|
||||
("constraint Foo", (1, 15)),
|
||||
("constraint Foo =", (1, 17)),
|
||||
("constraint Foo = 3", (1, 19)),
|
||||
("constraint Foo = 3 <", (1, 21)),
|
||||
],
|
||||
)
|
||||
def test_parsing_error(src: str, pos: tuple[int, int]):
|
||||
src = textwrap.dedent(src)
|
||||
tokens: list[Token] = MidasLexer(src).process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmt: list[Stmt] = parser.parse()
|
||||
assert len(stmt) == 0
|
||||
assert len(parser.errors) != 0
|
||||
error_pos: Position = parser.errors[0].token.position
|
||||
assert (error_pos.line, error_pos.column) == pos
|
||||
46
tests/python.py
Normal file
46
tests/python.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.python import Stmt
|
||||
from midas.parser.python import PythonParser
|
||||
from tests.base import Tester
|
||||
from tests.serializer.python import PythonAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
stmts: Optional[list[dict]] = None
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
|
||||
|
||||
class PythonTester(Tester):
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return "python-parser"
|
||||
|
||||
def _list_tests(self) -> list[Path]:
|
||||
return list(self.base_dir.rglob("*.py"))
|
||||
|
||||
def _exec_case(self, path: Path) -> CaseResult:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Could not find test '{path}'")
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
result: CaseResult = CaseResult()
|
||||
content: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(content)
|
||||
|
||||
parser: PythonParser = PythonParser()
|
||||
stmts: list[Stmt] = parser.parse_module(tree)
|
||||
result.stmts = PythonAstJsonSerializer().serialize(stmts)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PythonTester.main()
|
||||
159
tests/serializer/midas.py
Normal file
159
tests/serializer/midas.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexTypeStmt,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
SimpleTypeExpr,
|
||||
SimpleTypeStmt,
|
||||
Stmt,
|
||||
TemplateExpr,
|
||||
TypeExpr,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
|
||||
|
||||
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
|
||||
return {
|
||||
"_type": "SimpleTypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"template": self._serialize_optional(stmt.template),
|
||||
"base": stmt.base.accept(self),
|
||||
"constraint": self._serialize_optional(stmt.constraint),
|
||||
}
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
|
||||
return {
|
||||
"_type": "ComplexTypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"template": self._serialize_optional(stmt.template),
|
||||
"properties": self._serialize_list(stmt.properties),
|
||||
}
|
||||
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
|
||||
return {
|
||||
"_type": "PropertyStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"constraint": self._serialize_optional(stmt.constraint),
|
||||
}
|
||||
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
||||
return {
|
||||
"_type": "ExtendStmt",
|
||||
"type": stmt.type.accept(self),
|
||||
"operations": self._serialize_list(stmt.operations),
|
||||
}
|
||||
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> dict:
|
||||
return {
|
||||
"_type": "OpStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"operand": stmt.operand.accept(self),
|
||||
"result": stmt.result.accept(self),
|
||||
}
|
||||
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
|
||||
return {
|
||||
"_type": "PredicateStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"subject": stmt.subject.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"condition": stmt.condition.accept(self),
|
||||
}
|
||||
|
||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
|
||||
return {
|
||||
"_type": "SimpleTypeExpr",
|
||||
"name": expr.name.lexeme,
|
||||
"optional": expr.optional,
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": expr.operator.lexeme,
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "BinaryExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": expr.operator.lexeme,
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "UnaryExpr",
|
||||
"operator": expr.operator.lexeme,
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
"expr": expr.expr.accept(self),
|
||||
"name": expr.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||
return {
|
||||
"_type": "VariableExpr",
|
||||
"name": expr.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_grouping_expr(self, expr: GroupingExpr) -> dict:
|
||||
return {
|
||||
"_type": "GroupingExpr",
|
||||
"expr": expr.expr.accept(self),
|
||||
}
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||
return {
|
||||
"_type": "LiteralExpr",
|
||||
"value": expr.value,
|
||||
}
|
||||
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
||||
return {"_type": "WildcardExpr"}
|
||||
|
||||
def visit_template_expr(self, expr: TemplateExpr) -> dict:
|
||||
return {
|
||||
"_type": "TemplateExpr",
|
||||
"type": expr.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> dict:
|
||||
return {
|
||||
"_type": "TypeExpr",
|
||||
"name": expr.name.lexeme,
|
||||
"template": self._serialize_optional(expr.template),
|
||||
"optional": expr.optional,
|
||||
}
|
||||
247
tests/serializer/python.py
Normal file
247
tests/serializer/python.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import ast
|
||||
from typing import Optional, Sequence, Type
|
||||
|
||||
from midas.ast.python import (
|
||||
AssignStmt,
|
||||
BaseType,
|
||||
BinaryExpr,
|
||||
CallExpr,
|
||||
CastExpr,
|
||||
CompareExpr,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExpressionStmt,
|
||||
FrameColumn,
|
||||
FrameType,
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
SetExpr,
|
||||
Stmt,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
)
|
||||
|
||||
unary_ops: dict[Type[ast.unaryop], str] = {
|
||||
ast.Invert: "~",
|
||||
ast.Not: "not",
|
||||
ast.UAdd: "+",
|
||||
ast.USub: "-",
|
||||
}
|
||||
binary_ops: dict[Type[ast.operator], str] = {
|
||||
ast.Add: "+",
|
||||
ast.Sub: "-",
|
||||
ast.Mult: "*",
|
||||
ast.MatMult: "@",
|
||||
ast.Div: "/",
|
||||
ast.Mod: "%",
|
||||
ast.LShift: "<<",
|
||||
ast.RShift: ">>",
|
||||
ast.BitOr: "|",
|
||||
ast.BitXor: "^",
|
||||
ast.BitAnd: "&",
|
||||
ast.FloorDiv: "//",
|
||||
ast.Pow: "**",
|
||||
}
|
||||
compare_ops: dict[Type[ast.cmpop], str] = {
|
||||
ast.Eq: "==",
|
||||
ast.NotEq: "!=",
|
||||
ast.Lt: "<",
|
||||
ast.LtE: "<=",
|
||||
ast.Gt: ">",
|
||||
ast.GtE: ">=",
|
||||
ast.Is: "is",
|
||||
ast.IsNot: "is not",
|
||||
ast.In: "in",
|
||||
ast.NotIn: "not in",
|
||||
}
|
||||
boolean_ops: dict[Type[ast.boolop], str] = {
|
||||
ast.And: "and",
|
||||
ast.Or: "or",
|
||||
}
|
||||
|
||||
|
||||
class PythonAstJsonSerializer(
|
||||
Stmt.Visitor[dict], Expr.Visitor[dict], MidasType.Visitor[dict]
|
||||
):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(
|
||||
self, element: Optional[Stmt | Expr | MidasType]
|
||||
) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(
|
||||
self, elements: Sequence[Stmt | Expr | MidasType]
|
||||
) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_base_type(self, node: BaseType) -> dict:
|
||||
return {
|
||||
"_type": "BaseType",
|
||||
"base": node.base,
|
||||
"param": self._serialize_optional(node.param),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||
return {
|
||||
"_type": "ConstraintType",
|
||||
"type": node.type.accept(self),
|
||||
"constraint": ast.unparse(node.constraint),
|
||||
}
|
||||
|
||||
def visit_frame_column(self, node: FrameColumn) -> dict:
|
||||
return {
|
||||
"_type": "FrameColumn",
|
||||
"name": node.name,
|
||||
"type": self._serialize_optional(node.type),
|
||||
}
|
||||
|
||||
def visit_frame_type(self, node: FrameType) -> dict:
|
||||
return {
|
||||
"_type": "FrameType",
|
||||
"columns": self._serialize_list(node.columns),
|
||||
}
|
||||
|
||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> dict:
|
||||
return {
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": stmt.expr.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_argument(self, arg: Function.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": self._serialize_optional(arg.type),
|
||||
"default": self._serialize_optional(arg.default),
|
||||
}
|
||||
|
||||
def visit_function(self, stmt: Function) -> dict:
|
||||
return {
|
||||
"_type": "Function",
|
||||
"name": stmt.name,
|
||||
"posonlyargs": [self._serialize_argument(arg) for arg in stmt.posonlyargs],
|
||||
"args": [self._serialize_argument(arg) for arg in stmt.args],
|
||||
"sink": (
|
||||
self._serialize_argument(stmt.sink) if stmt.sink is not None else None
|
||||
),
|
||||
"kwonlyargs": [self._serialize_argument(arg) for arg in stmt.kwonlyargs],
|
||||
"kw_sink": (
|
||||
self._serialize_argument(stmt.kw_sink)
|
||||
if stmt.kw_sink is not None
|
||||
else None
|
||||
),
|
||||
"returns": self._serialize_optional(stmt.returns),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
}
|
||||
|
||||
def visit_type_assign(self, stmt: TypeAssign) -> dict:
|
||||
return {
|
||||
"_type": "TypeAssign",
|
||||
"name": stmt.name,
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_assign_stmt(self, stmt: AssignStmt) -> dict:
|
||||
return {
|
||||
"_type": "AssignStmt",
|
||||
"targets": self._serialize_list(stmt.targets),
|
||||
"value": stmt.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_return_stmt(self, stmt: ReturnStmt) -> dict:
|
||||
return {
|
||||
"_type": "ReturnStmt",
|
||||
"value": self._serialize_optional(stmt.value),
|
||||
}
|
||||
|
||||
def visit_if_stmt(self, stmt: IfStmt) -> dict:
|
||||
return {
|
||||
"_type": "IfStmt",
|
||||
"test": stmt.test.accept(self),
|
||||
"body": self._serialize_list(stmt.body),
|
||||
"orelse": self._serialize_list(stmt.orelse),
|
||||
}
|
||||
|
||||
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "BinaryExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": binary_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_compare_expr(self, expr: CompareExpr) -> dict:
|
||||
return {
|
||||
"_type": "CompareExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": compare_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "UnaryExpr",
|
||||
"operator": unary_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||
return {
|
||||
"_type": "CallExpr",
|
||||
"callee": expr.callee.accept(self),
|
||||
"arguments": self._serialize_list(expr.arguments),
|
||||
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||
}
|
||||
|
||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||
return {
|
||||
"_type": "GetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
}
|
||||
|
||||
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
|
||||
return {
|
||||
"_type": "LiteralExpr",
|
||||
"value": expr.value,
|
||||
}
|
||||
|
||||
def visit_variable_expr(self, expr: VariableExpr) -> dict:
|
||||
return {
|
||||
"_type": "VariableExpr",
|
||||
"name": expr.name,
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
"left": expr.left.accept(self),
|
||||
"operator": boolean_ops[expr.operator.__class__],
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_set_expr(self, expr: SetExpr) -> dict:
|
||||
return {
|
||||
"_type": "SetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
"value": expr.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||
return {
|
||||
"_type": "CastExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"expr": expr.expr.accept(self),
|
||||
}
|
||||
@@ -31,22 +31,32 @@
|
||||
]
|
||||
},
|
||||
"type-base": {
|
||||
"begin": "<",
|
||||
"end": ">",
|
||||
"begin": "(\\()([a-zA-Z_][a-zA-Z_\\d]*)(\\))",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"0": {
|
||||
"1": {
|
||||
"name": "punctuation.definition.base.begin.midas"
|
||||
}
|
||||
},
|
||||
"endCaptures": {
|
||||
"0": {
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "punctuation.definition.base.end.midas"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "source.python"}
|
||||
{ "include": "#type-cond" }
|
||||
]
|
||||
},
|
||||
"type-cond": {
|
||||
"begin": "where",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"0": {
|
||||
"name": "keyword.control.where.midas"
|
||||
}
|
||||
}
|
||||
},
|
||||
"type-body": {
|
||||
"begin": "\\{",
|
||||
"end": "\\}",
|
||||
@@ -61,7 +71,8 @@
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "#type-prop"}
|
||||
{"include": "#type-prop"},
|
||||
{"include": "#comment"}
|
||||
]
|
||||
},
|
||||
"type-prop": {
|
||||
@@ -78,44 +89,67 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"op-def": {
|
||||
"match": "\\b(op)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(\\S+)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>\\s+(=)\\s+<([a-zA-Z_][a-zA-Z_\\d]*)>",
|
||||
"captures": {
|
||||
"1": {
|
||||
"name": "keyword.control.op.midas"
|
||||
},
|
||||
"2": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name" : "keyword.operator"
|
||||
},
|
||||
"4": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"5": {
|
||||
"name" : "keyword.operator.assignment"
|
||||
},
|
||||
"6": {
|
||||
"name" : "variable.name"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{ "include": "#type-base" },
|
||||
{ "include": "#type-body" }
|
||||
]
|
||||
},
|
||||
"constr-def": {
|
||||
"begin": "(constraint)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\s*(=)",
|
||||
"end": "$",
|
||||
"extend-def": {
|
||||
"begin": "\\b(extend)\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s+(\\{)",
|
||||
"end": "\\}",
|
||||
"beginCaptures": {
|
||||
"1": {
|
||||
"name": "keyword.control.constr.midas"
|
||||
"name": "keyword.control.extend.midas"
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "punctuation.definition.extend-body.begin.midas"
|
||||
}
|
||||
},
|
||||
"endCaptures": {
|
||||
"0": {
|
||||
"name": "punctuation.definition.extend-body.end.midas"
|
||||
}
|
||||
},
|
||||
"patterns": [
|
||||
{"include": "#op-def"},
|
||||
{"include": "#comment"}
|
||||
]
|
||||
},
|
||||
"op-def": {
|
||||
"match": "\\b(op)\\s+(\\S+)\\s*\\(\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\s*\\)\\s*(->)\\s*([a-zA-Z_][a-zA-Z_\\d]*)",
|
||||
"captures": {
|
||||
"1": {
|
||||
"name": "keyword.control.op.midas"
|
||||
},
|
||||
"2": {
|
||||
"name" : "keyword.operator"
|
||||
},
|
||||
"3": {
|
||||
"name" : "variable.name"
|
||||
},
|
||||
"4": {
|
||||
"name" : "keyword.operator.assignment"
|
||||
},
|
||||
"5": {
|
||||
"name" : "variable.name"
|
||||
}
|
||||
}
|
||||
},
|
||||
"pred-def": {
|
||||
"begin": "(predicate)\\s+([a-zA-Z_][a-zA-Z_\\d]*)\\(([a-zA-Z_][a-zA-Z_\\d]*):\\s*([a-zA-Z_][a-zA-Z_\\d]*)\\)\\s*(=)",
|
||||
"end": "$",
|
||||
"beginCaptures": {
|
||||
"1": {
|
||||
"name": "keyword.control.pred.midas"
|
||||
},
|
||||
"2": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"3": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"4": {
|
||||
"name": "variable.name"
|
||||
},
|
||||
"5": {
|
||||
"name": "keyword.operator.assignment"
|
||||
}
|
||||
},
|
||||
@@ -127,8 +161,8 @@
|
||||
"patterns": [
|
||||
{ "include": "#comment" },
|
||||
{ "include": "#type-def" },
|
||||
{ "include": "#op-def" },
|
||||
{ "include": "#constr-def" }
|
||||
{ "include": "#extend-def" },
|
||||
{ "include": "#pred-def" }
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user