23 Commits

Author SHA1 Message Date
12d762429d fix(parser): complete EBNF and railroad diagrams 2026-05-21 15:46:40 +02:00
53929ee514 test(parser): remove pytest tests 2026-05-21 15:07:19 +02:00
2f6e137f1a tests(parser): update snapshot with new syntax 2026-05-21 15:04:32 +02:00
5224e79d9f fix(parser): update pretty printer 2026-05-21 14:45:52 +02:00
bdcb12c58a fix(parser): update AST printer 2026-05-21 14:27:38 +02:00
5cb4d587e3 feat(parser)!: adapt parser for revised syntax 2026-05-21 13:57:38 +02:00
8f9ec8d73b feat(parser): add more nodes for constraint parsing 2026-05-21 13:54:58 +02:00
c1c50a448e fix(parser): allow underscores in identifier
modify the lexer to allow underscores in an identifier, but keep scanning single underscores as a specific underscore token
2026-05-21 13:54:19 +02:00
19229db0b1 feat(parser)!: adjust AST node classes for new syntax 2026-05-21 12:25:47 +02:00
f3b6bd146f tool: add AST class generator script 2026-05-21 12:24:43 +02:00
98c3510bd4 feat(parser): update lexer with new tokens 2026-05-21 09:15:14 +02:00
429d0d98fe feat: update railroad diagrams with revised syntax 2026-05-21 07:53:56 +02:00
db8fe5d3ff feat: update EBNF with revised syntax 2026-05-21 07:53:40 +02:00
7477ec8d70 fix: change syntax definition to W3C EBNF 2026-05-20 15:47:34 +02:00
adf7f4e7a2 tests(parser): use new MidasSyntaxError 2026-05-20 15:46:25 +02:00
abf6787946 fix(parser)!: remove annotation lexer and parser 2026-05-20 15:45:55 +02:00
e282b08597 fix: tweak syntax examples
- move operation definitions outside GeoLocation type
- add nullable type
- list syntax choices for complex refinement
2026-05-20 14:14:01 +02:00
0a02b9d3d9 feat: revise syntax (example)
improve the syntax to better fit the principle of least surprise and Python syntax
2026-05-20 13:20:53 +02:00
875ca589e4 Merge pull request 'Improve testing framework' (#2) from feat/test-framework into main
Reviewed-on: #2
2026-05-20 11:17:00 +00:00
88f92d6e1f tests(parser): add simple types snapshot test 2026-05-19 14:12:12 +02:00
db4ed74365 tests(parser): add snapshot test runner
the diff printing function was suggested by Gemini

Co-authored-by: Gemini <noreply@gemini.google.com>
2026-05-19 14:11:32 +02:00
7cbf4fdece feat(tests): add AST JSON serializer 2026-05-19 14:00:32 +02:00
1fa9a09bfe feat(parser): use custom syntax error class 2026-05-19 13:57:00 +02:00
24 changed files with 4055 additions and 1409 deletions

View File

@@ -1,107 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from lexer.token import Token
T = TypeVar("T")
@dataclass(frozen=True)
class Stmt(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...
@dataclass(frozen=True)
class AnnotationStmt(Stmt):
name: Token
schema: Optional[SchemaExpr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_annotation_stmt(self)
@dataclass(frozen=True)
class Expr(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@abstractmethod
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
@abstractmethod
def visit_schema_expr(self, expr: SchemaExpr) -> T: ...
@abstractmethod
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...
@dataclass(frozen=True)
class WildcardExpr(Expr):
token: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True)
class LiteralExpr(Expr):
value: Any
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class TypeExpr(Expr):
name: Token
constraints: list[ConstraintExpr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_expr(self)
@dataclass(frozen=True)
class ConstraintExpr(Expr):
left: Expr
op: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_constraint_expr(self)
@dataclass(frozen=True)
class SchemaExpr(Expr):
left: Token
elements: list[Expr]
right: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_schema_expr(self)
@dataclass(frozen=True)
class SchemaElementExpr(Expr):
name: Optional[Token]
type: Optional[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_schema_element_expr(self)

159
core/ast/json_serializer.py Normal file
View File

@@ -0,0 +1,159 @@
from typing import Optional, Sequence
from core.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
Expr,
ExtendStmt,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
OpStmt,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"""An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
if element is None:
return None
return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
return {
"_type": "SimpleTypeStmt",
"template": self._serialize_optional(stmt.template),
"name": stmt.name.lexeme,
"base": stmt.base.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
}
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
return {
"_type": "ComplexTypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
}
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
return {
"_type": "PropertyStmt",
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
}
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
return {
"_type": "ExtendStmt",
"type": stmt.type.accept(self),
"operations": self._serialize_list(stmt.operations),
}
def visit_op_stmt(self, stmt: OpStmt) -> dict:
return {
"_type": "OpStmt",
"name": stmt.name.lexeme,
"operand": stmt.operand.accept(self),
"result": stmt.result.accept(self),
}
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
return {
"_type": "PredicateStmt",
"name": stmt.name.lexeme,
"subject": stmt.subject.lexeme,
"type": stmt.type.accept(self),
"condition": stmt.condition.accept(self),
}
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return {
"_type": "LogicalExpr",
"left": expr.left.accept(self),
"operator": expr.operator.lexeme,
"right": expr.right.accept(self),
}
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return {
"_type": "BinaryExpr",
"left": expr.left.accept(self),
"operator": expr.operator.lexeme,
"right": expr.right.accept(self),
}
def visit_unary_expr(self, expr: UnaryExpr) -> dict:
return {
"_type": "UnaryExpr",
"operator": expr.operator.lexeme,
"right": expr.right.accept(self),
}
def visit_get_expr(self, expr: GetExpr) -> dict:
return {
"_type": "GetExpr",
"expr": expr.expr.accept(self),
"name": expr.name.lexeme,
}
def visit_variable_expr(self, expr: VariableExpr) -> dict:
return {
"_type": "VariableExpr",
"name": expr.name.lexeme,
}
def visit_grouping_expr(self, expr: GroupingExpr) -> dict:
return {
"_type": "GroupingExpr",
"expr": expr.expr.accept(self),
}
def visit_literal_expr(self, expr: LiteralExpr) -> dict:
return {
"_type": "LiteralExpr",
"value": expr.value,
}
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
return {"_type": "WildcardExpr"}
def visit_template_expr(self, expr: TemplateExpr) -> dict:
return {
"_type": "TemplateExpr",
"type": expr.type.accept(self),
}
def visit_type_expr(self, expr: TypeExpr) -> dict:
return {
"_type": "TypeExpr",
"name": expr.name.lexeme,
"template": self._serialize_optional(expr.template),
"optional": expr.optional,
}

View File

@@ -1,3 +1,8 @@
"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py
"""
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -8,8 +13,9 @@ from lexer.token import Token
T = TypeVar("T") T = TypeVar("T")
##############
# Statements # Statements #
##############
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -19,42 +25,68 @@ class Stmt(ABC):
class Visitor(ABC, Generic[T]): class Visitor(ABC, Generic[T]):
@abstractmethod @abstractmethod
def visit_type_stmt(self, stmt: TypeStmt) -> T: ... def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
@abstractmethod
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
@abstractmethod @abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
@abstractmethod
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
@abstractmethod @abstractmethod
def visit_op_stmt(self, stmt: OpStmt) -> T: ... def visit_op_stmt(self, stmt: OpStmt) -> T: ...
@abstractmethod @abstractmethod
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> T: ... def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class TypeStmt(Stmt): class SimpleTypeStmt(Stmt):
name: Token name: Token
bases: list[TypeExpr] template: Optional[TemplateExpr]
body: Optional[TypeBodyExpr] base: TypeExpr
constraint: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_type_stmt(self) return visitor.visit_simple_type_stmt(self)
@dataclass(frozen=True)
class ComplexTypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_complex_type_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class PropertyStmt(Stmt): class PropertyStmt(Stmt):
name: Token name: Token
type: TypeExpr type: TypeExpr
constraint: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self) return visitor.visit_property_stmt(self)
@dataclass(frozen=True)
class ExtendStmt(Stmt):
type: TypeExpr
operations: list[OpStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_extend_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class OpStmt(Stmt): class OpStmt(Stmt):
left: TypeExpr name: Token
op: Token operand: TypeExpr
right: TypeExpr
result: TypeExpr result: TypeExpr
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -62,15 +94,19 @@ class OpStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class ConstraintStmt(Stmt): class PredicateStmt(Stmt):
name: Token name: Token
constraint: ConstraintExpr subject: Token
type: TypeExpr
condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_constraint_stmt(self) return visitor.visit_predicate_stmt(self)
# Expressions ###############
# Expressions #
###############
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -80,27 +116,100 @@ class Expr(ABC):
class Visitor(ABC, Generic[T]): class Visitor(ABC, Generic[T]):
@abstractmethod @abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ... def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@abstractmethod
def visit_variable_expr(self, expr: VariableExpr) -> T: ...
@abstractmethod
def visit_grouping_expr(self, expr: GroupingExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T: ... def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ... def visit_type_expr(self, expr: TypeExpr) -> T: ...
@abstractmethod
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
@abstractmethod
def visit_type_body_expr(self, expr: TypeBodyExpr) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class WildcardExpr(Expr): class SimpleTypeExpr(Expr):
token: Token name: Token
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_wildcard_expr(self) return visitor.visit_simple_type_expr(self)
@dataclass(frozen=True)
class LogicalExpr(Expr):
left: Expr
operator: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class BinaryExpr(Expr):
left: Expr
operator: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_binary_expr(self)
@dataclass(frozen=True)
class UnaryExpr(Expr):
operator: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
expr: Expr
name: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_get_expr(self)
@dataclass(frozen=True)
class VariableExpr(Expr):
name: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_variable_expr(self)
@dataclass(frozen=True)
class GroupingExpr(Expr):
expr: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_grouping_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -111,28 +220,27 @@ class LiteralExpr(Expr):
return visitor.visit_literal_expr(self) return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class WildcardExpr(Expr):
token: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True)
class TemplateExpr(Expr):
type: TypeExpr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_template_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class TypeExpr(Expr): class TypeExpr(Expr):
name: Token name: Token
constraints: list[ConstraintExpr] template: Optional[TemplateExpr]
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_expr(self) return visitor.visit_type_expr(self)
@dataclass(frozen=True)
class ConstraintExpr(Expr):
left: Expr
op: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_constraint_expr(self)
@dataclass(frozen=True)
class TypeBodyExpr(Expr):
properties: list[PropertyStmt]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_body_expr(self)

View File

@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import io
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum, auto from enum import Enum, auto
import io
from typing import Generator, Generic, Optional, Protocol, TypeVar from typing import Generator, Generic, Optional, Protocol, TypeVar
import core.ast.annotations as a
import core.ast.midas as m import core.ast.midas as m
@@ -39,8 +38,8 @@ class AstPrinter(Generic[T]):
return self._buf.getvalue() return self._buf.getvalue()
@contextmanager @contextmanager
def _child_level(self, last: bool = False) -> Generator[None, None, None]: def _child_level(self, single: bool = False) -> Generator[None, None, None]:
self._levels.append(_Level.LAST if last else _Level.ACTIVE) self._levels.append(_Level.LAST if single else _Level.ACTIVE)
try: try:
yield yield
finally: finally:
@@ -80,215 +79,170 @@ class AstPrinter(Generic[T]):
self._write_line(f"{label}: None") self._write_line(f"{label}: None")
else: else:
self._write_line(label) self._write_line(label)
with self._child_level(last=True): with self._child_level(single=True):
child.accept(self) child.accept(self)
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
self._write_line("AnnotationStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("schema", stmt.schema, last=True)
def visit_type_expr(self, expr: a.TypeExpr):
self._write_line("TypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line("constraints", last=True)
with self._child_level():
for i, constraint in enumerate(expr.constraints):
self._idx = i
if i == len(expr.constraints) - 1:
self._mark_last()
constraint.accept(self)
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
self._write_line("ConstraintExpr")
with self._child_level():
self._write_line("left")
with self._child_level():
self._mark_last()
expr.left.accept(self)
self._write_line(f"operator: {expr.op.lexeme}")
self._write_line("right", last=True)
with self._child_level():
self._mark_last()
expr.right.accept(self)
def visit_schema_expr(self, expr: a.SchemaExpr):
self._write_line("SchemaExpr")
with self._child_level():
for i, elmt in enumerate(expr.elements):
self._idx = i
if i == len(expr.elements) - 1:
self._mark_last()
elmt.accept(self)
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
self._write_line("SchemaElementExpr")
with self._child_level():
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
self._write_line(f"name: {name_text}")
self._write_optional_child("type", expr.type, last=True)
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level():
self._write_line(f"value: {expr.value}", last=True)
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
def print(self, expr: a.Expr | a.Stmt):
return expr.accept(self)
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
schema: str = ""
if stmt.schema is not None:
schema = stmt.schema.accept(self)
return f"{stmt.name.lexeme}{schema}"
def visit_type_expr(self, expr: a.TypeExpr) -> str:
parts: list[str] = [expr.name.lexeme]
for constraint in expr.constraints:
parts.append("(" + constraint.accept(self) + ")")
return " + ".join(parts)
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
parts: list[str] = [
expr.left.accept(self),
expr.op.lexeme,
expr.right.accept(self),
]
return " ".join(parts)
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
res: str = expr.left.lexeme
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
res += expr.right.lexeme
return res
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
parts: list[str] = []
if expr.name is not None:
parts.append(expr.name.lexeme)
if expr.type is None:
parts.append("_")
else:
parts.append(expr.type.accept(self))
return ": ".join(parts)
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
return "_"
def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
return str(expr.value)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_type_stmt(self, stmt: m.TypeStmt): #Statements
self._write_line("TypeStmt")
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
self._write_line("SimpleTypeStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("bases") self._write_optional_child("template", stmt.template)
self._write_line("base")
with self._child_level(single=True):
stmt.base.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
self._write_line("ComplexTypeStmt")
with self._child_level(): with self._child_level():
for i, base in enumerate(stmt.bases): self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("properties", last=True)
with self._child_level():
for i, prop in enumerate(stmt.properties):
self._idx = i self._idx = i
if i == len(stmt.bases) - 1: if i == len(stmt.properties) - 1:
self._mark_last() self._mark_last()
base.accept(self) prop.accept(self)
self._write_optional_child("body", stmt.body, last=True)
def visit_property_stmt(self, stmt: m.PropertyStmt): def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt") self._write_line("PropertyStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True) self._write_line("type")
with self._child_level(): with self._child_level(single=True):
self._mark_last()
stmt.type.accept(self) stmt.type.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
stmt.type.accept(self)
self._write_line("operations", last=True)
with self._child_level():
for i, op in enumerate(stmt.operations):
self._idx = i
if i == len(stmt.operations) - 1:
self._mark_last()
op.accept(self)
def visit_op_stmt(self, stmt: m.OpStmt) -> None: def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self._write_line("OpStmt") self._write_line("OpStmt")
with self._child_level(): with self._child_level():
self._write_line("left") self._write_line(f'name: "{stmt.name.lexeme}"')
with self._child_level():
self._mark_last()
stmt.left.accept(self)
self._write_line(f'op: "{stmt.op.lexeme}"') self._write_line("operand")
with self._child_level(single=True):
self._write_line("right") stmt.operand.accept(self)
with self._child_level():
self._mark_last()
stmt.right.accept(self)
self._write_line("result", last=True) self._write_line("result", last=True)
with self._child_level(): with self._child_level(single=True):
self._mark_last()
stmt.result.accept(self) stmt.result.accept(self)
def visit_constraint_stmt(self, stmt: m.ConstraintStmt): def visit_predicate_stmt(self, stmt: m.PredicateStmt):
self._write_line("ConstraintStmt") self._write_line("PredicateStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("constraint", last=True) self._write_line(f'subject: "{stmt.subject.lexeme}"')
with self._child_level(): self._write_line("type")
self._mark_last() with self._child_level(single=True):
stmt.constraint.accept(self) stmt.type.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
def visit_type_expr(self, expr: m.TypeExpr): # Expressions
self._write_line("TypeExpr")
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
self._write_line("SimpleTypeExpr")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"') self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line("constraints", last=True) self._write_line(f"optional: {expr.optional}", last=True)
with self._child_level():
for i, constraint in enumerate(expr.constraints):
self._idx = i
if i == len(expr.constraints) - 1:
self._mark_last()
constraint.accept(self)
def visit_constraint_expr(self, expr: m.ConstraintExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("ConstraintExpr") self._write_line("LogicalExpr")
with self._child_level(): with self._child_level():
self._write_line("left") self._write_line("left")
with self._child_level(): with self._child_level(single=True):
self._mark_last()
expr.left.accept(self) expr.left.accept(self)
self._write_line(f"operator: {expr.op.lexeme}") self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True) self._write_line("right", last=True)
with self._child_level(): with self._child_level(single=True):
self._mark_last()
expr.right.accept(self) expr.right.accept(self)
def visit_type_body_expr(self, expr: m.TypeBodyExpr): def visit_binary_expr(self, expr: m.BinaryExpr):
self._write_line("TypeBodyExpr") self._write_line("BinaryExpr")
with self._child_level(): with self._child_level():
self._write_line("properties", last=True) self._write_line("left")
with self._child_level(): with self._child_level(single=True):
for i, property in enumerate(expr.properties): expr.left.accept(self)
self._idx = i
if i == len(expr.properties) - 1:
self._mark_last()
property.accept(self)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("WildcardExpr")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_unary_expr(self, expr: m.UnaryExpr):
self._write_line("UnaryExpr")
with self._child_level():
self._write_line(f"operator: {expr.operator.lexeme}")
self._write_line("right", last=True)
with self._child_level(single=True):
expr.right.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_variable_expr(self, expr: m.VariableExpr):
self._write_line("VariableExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"', last=True)
def visit_grouping_expr(self, expr: m.GroupingExpr):
self._write_line("GroupingExpr")
with self._child_level():
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self._write_line("LiteralExpr") self._write_line("LiteralExpr")
with self._child_level(): with self._child_level():
self._write_line(f"value: {expr.value}", last=True) self._write_line(f"value: {expr.value}", last=True)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self._write_line("TemplateExpr")
with self._child_level(single=True):
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
def visit_type_expr(self, expr: m.TypeExpr):
self._write_line("TypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_optional_child("template", expr.template)
self._write_line(f"optional: {expr.optional}", last=True)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def __init__(self, indent: int = 4): def __init__(self, indent: int = 4):
self.indent: int = indent self.indent: int = indent
@@ -301,60 +255,94 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
self.level = 0 self.level = 0
return expr.accept(self) return expr.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt): def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
bases: list[str] = [ template: str = stmt.template.accept(self) if stmt.template is not None else ""
b.accept(self) res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
for b in stmt.bases if stmt.constraint is not None:
] res += " where " + stmt.constraint.accept(self)
return self.indented(res)
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>") def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
if stmt.body is not None: template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
res += " {\n" res += " {\n"
self.level += 1 self.level += 1
res += stmt.body.accept(self) for prop in stmt.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1 self.level -= 1
res += "\n" + self.indented("}") res += self.indented("}")
return res return res
def visit_property_stmt(self, stmt: m.PropertyStmt): def visit_property_stmt(self, stmt: m.PropertyStmt):
return f"{stmt.name.lexeme}: {stmt.type.accept(self)}" res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
res: str = self.indented(f"extend {stmt.type.accept(self)}")
res += " {\n"
self.level += 1
for op in stmt.operations:
res += op.accept(self)
self.level -= 1
res += "\n" + self.indented("}")
return res
def visit_op_stmt(self, stmt: m.OpStmt): def visit_op_stmt(self, stmt: m.OpStmt):
left: str = stmt.left.accept(self) operand: str = stmt.operand.accept(self)
op: str = stmt.op.lexeme
right: str = stmt.right.accept(self)
result: str = stmt.result.accept(self) result: str = stmt.result.accept(self)
return self.indented(f"op <{left}> {op} <{right}> = <{result}>") return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
def visit_constraint_stmt(self, stmt: m.ConstraintStmt): def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme name: str = stmt.name.lexeme
constraint: str = stmt.constraint.accept(self) subject: str = stmt.subject.lexeme
return self.indented(f"constraint {name} = {constraint}") type: str = stmt.type.accept(self)
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_type_expr(self, expr: m.TypeExpr): def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
parts: list[str] = [expr.name.lexeme] return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
for constraint in expr.constraints:
parts.append("(" + constraint.accept(self) + ")")
return " + ".join(parts)
def visit_constraint_expr(self, expr: m.ConstraintExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
parts: list[str] = [ left: str = expr.left.accept(self)
expr.left.accept(self), operator: str = expr.operator.lexeme
expr.op.lexeme, right: str = expr.right.accept(self)
expr.right.accept(self), return f"{left} {operator} {right}"
]
return " ".join(parts)
def visit_type_body_expr(self, expr: m.TypeBodyExpr): def visit_binary_expr(self, expr: m.BinaryExpr):
properties: list[str] = [ left: str = expr.left.accept(self)
self.indented(prop.accept(self)) operator: str = expr.operator.lexeme
for prop in expr.properties right: str = expr.right.accept(self)
] return f"{left} {operator} {right}"
return "\n".join(properties)
def visit_unary_expr(self, expr: m.UnaryExpr):
operator: str = expr.operator.lexeme
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
return f"{expr_}.{name}"
def visit_variable_expr(self, expr: m.VariableExpr):
return expr.name.lexeme
def visit_grouping_expr(self, expr: m.GroupingExpr):
expr_: str = expr.expr.accept(self)
return f"({expr_})"
def visit_literal_expr(self, expr: m.LiteralExpr):
return str(expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr): def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_" return "_"
def visit_literal_expr(self, expr: m.LiteralExpr): def visit_template_expr(self, expr: m.TemplateExpr):
return str(expr.value) return f"[{expr.type.accept(self)}]"
def visit_type_expr(self, expr: m.TypeExpr):
template: str = expr.template.accept(self) if expr.template is not None else ""
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"

View File

@@ -0,0 +1,73 @@
// Simple custom type derived from float
type Custom(float)
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T)
// Complex custom type, containing two values accessible through properties
type GeoLocation {
lat: Latitude
lon: Longitude
}
// Define operations on our custom type
extend GeoLocation {
// This type is compatible with the `-` operation with another GeoLocation
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
// in a Difference of GeoLocations
op __sub__(GeoLocation) -> Difference[GeoLocation]
}
// For complex generics, you need to specify how the genericity the properties
// are handled
type Difference[GeoLocation] {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
// Simple operation defined on our custom types
extend Latitude {
op __sub__(Latitude) -> Difference[Latitude]
}
extend Longitude {
op __sub__(Longitude) -> Difference[Longitude]
}
// Predefined custom predicates that can be referenced in other definitions
predicate Positive(v: float) = v >= 0
predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
// Property referencing a predicate
height: float where StrictlyPositive
home: GeoLocation
}
// Custom complex type derived from another complex type, with a constraint
// on a property
// Multiple proposed syntaxes, not yet defined
// Explicit, but new keyword
type EquatorialPerson refines Person where Equatorial(_.home)
// Explicit with existing keyword, might be confusing if expectations regarding 'is'
type EquatorialPerson is Person where Equatorial(_.home)
// Consistent and Python-friendly but can be confused with structural extension
type EquatorialPerson(Person) where Equatorial(_.home)
// Allow new properties, probably not useful
type EquatorialPerson extends Person where Equatorial(_.home)

72
gen/ast.py Normal file
View File

@@ -0,0 +1,72 @@
class SimpleTypeStmt:
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
left: Expr
operator: Token
right: Expr
class BinaryExpr:
left: Expr
operator: Token
right: Expr
class UnaryExpr:
operator: Token
right: Expr
class GetExpr:
expr: Expr
name: Token
class VariableExpr:
name: Token
class GroupingExpr:
expr: Expr
class LiteralExpr:
value: Any
class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
class TypeExpr:
name: Token
template: Optional[TemplateExpr]
optional: bool
class ComplexTypeStmt:
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
class PropertyStmt:
name: Token
type: TypeExpr
constraint: Optional[Expr]
class ExtendStmt:
type: TypeExpr
operations: list[OpStmt]
class OpStmt:
name: Token
operand: TypeExpr
result: TypeExpr
class PredicateStmt:
name: Token
subject: Token
type: TypeExpr
condition: Expr

128
gen/gen.py Normal file
View File

@@ -0,0 +1,128 @@
from pathlib import Path
import re
HEADER = '''"""
This file was generated by a script. Any manual changes might be overwritten.
Please modify gen/ast.py instead and run gen/gen.py
"""'''
TEMPLATE = """{header}
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from lexer.token import Token
T = TypeVar("T")
##############
# Statements #
##############
@dataclass(frozen=True)
class Stmt(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{stmt_visitor_methods}
{statements}
###############
# Expressions #
###############
@dataclass(frozen=True)
class Expr(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
{expr_visitor_methods}
{expressions}
"""
VISITOR_METHOD_TEMPLATE = """
@abstractmethod
def visit_{func_name}(self, {param}: {cls}) -> T: ...
"""
CLASS_TEMPLATE = """
@dataclass(frozen=True)
class {cls}({base}):
{body}
def accept(self, visitor: {base}.Visitor[T]) -> T:
return visitor.visit_{func_name}(self)
"""
def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
def make_visitor_method(cls: str, param: str):
method: str = VISITOR_METHOD_TEMPLATE.format(
func_name=snake_case(cls),
param=param,
cls=cls
)
return method.strip("\n")
def make_class(name: str, cls: str, base: str):
body: str = cls.split("\n", 1)[1]
func_name: str = snake_case(name)
cls_def: str = CLASS_TEMPLATE.format(
cls=name,
base=base,
body=body,
func_name=func_name,
)
return cls_def.strip("\n")
def generate(src: str):
classes: list[str] = src.split("\n\n")
stmt_visitor_methods: list[str] = []
expr_visitor_methods: list[str] = []
statements: list[str] = []
expressions: list[str] = []
for cls in classes:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
if name.endswith("Stmt"):
stmt_visitor_methods.append(make_visitor_method(name, "stmt"))
statements.append(make_class(name, cls, "Stmt"))
elif name.endswith("Expr"):
expr_visitor_methods.append(make_visitor_method(name, "expr"))
expressions.append(make_class(name, cls, "Expr"))
return TEMPLATE.format(
header=HEADER,
stmt_visitor_methods="\n\n".join(stmt_visitor_methods),
expr_visitor_methods="\n\n".join(expr_visitor_methods),
statements="\n\n\n".join(statements),
expressions="\n\n\n".join(expressions),
)
def main():
root: Path = Path(__file__).parent.parent
in_path: Path = root / "gen" / "ast.py"
out_path: Path = root / "core" / "ast" / "midas.py"
src: str = in_path.read_text()
generated: str = generate(src)
out_path.write_text(generated)
if __name__ == "__main__":
main()

View File

@@ -1,102 +0,0 @@
from lexer.base import Lexer
from lexer.keyword import ANNOTATION_KEYWORDS
from lexer.token import TokenType
class AnnotationLexer(Lexer):
def scan_token(self) -> None:
char: str = self.advance()
match char:
case "(":
self.add_token(TokenType.LEFT_PAREN)
case ")":
self.add_token(TokenType.RIGHT_PAREN)
case "[":
self.add_token(TokenType.LEFT_BRACKET)
case "]":
self.add_token(TokenType.RIGHT_BRACKET)
case "<":
self.add_token(
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
)
case ">":
self.add_token(
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
)
case "=":
self.add_token(
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
)
case "!":
if self.match("="):
self.add_token(TokenType.BANG_EQUAL)
else:
self.error("Unexpected single bang. Did you mean '!=' ?")
case ":":
self.add_token(TokenType.COLON)
case ",":
self.add_token(TokenType.COMMA)
case "_":
self.add_token(TokenType.UNDERSCORE)
case "+":
self.add_token(TokenType.PLUS)
case "#":
self.scan_comment()
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":
# Consume all whitespace characters until EOL or EOF
while (
self.peek().isspace()
and self.peek() != "\n"
and not self.is_at_end()
):
self.advance()
self.add_token(TokenType.WHITESPACE)
case _:
if char.isdigit():
self.scan_number()
elif char.isalpha():
self.scan_identifier()
else:
self.error("Unexpected character")
return None
def scan_number(self):
"""Scan the rest of number and add it as a token
This method handles both simple integers and floats. Scientific notation
and base prefixes (0x, 0b, 0o) are not supported
"""
while self.peek().isdigit():
self.advance()
if self.peek() == "." and self.peek_next().isdigit():
self.advance()
while self.peek().isdigit():
self.advance()
value: float = float(self.source[self.start : self.idx])
self.add_token(TokenType.NUMBER, value)
def scan_identifier(self):
"""Scan the rest of an identifier and add it as a token
An identifier starts with a letter, followed by any number of
alphanumerical characters or underscores
"""
while self.peek().isalnum() or self.peek() == "_":
self.advance()
lexeme: str = self.source[self.start : self.idx]
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
self.add_token(token_type)
def scan_comment(self):
"""Scan the rest of a comment and add it as a token
A comment starts with a `#` character and ends at the EOL/EOF
"""
while self.peek() != "\n" and not self.is_at_end():
self.advance()
self.add_token(TokenType.COMMENT)

View File

@@ -5,6 +5,13 @@ from lexer.position import Position
from lexer.token import Token, TokenType from lexer.token import Token, TokenType
class MidasSyntaxError(Exception):
def __init__(self, pos: Position, message: str):
super().__init__(f"[ERROR] Error at {pos}: {message}")
self.pos: Position = pos
self.message: str = message
class Lexer(ABC): class Lexer(ABC):
"""An abstract lexer which provides methods to easily extend it into a concrete one """An abstract lexer which provides methods to easily extend it into a concrete one
@@ -38,9 +45,9 @@ class Lexer(ABC):
msg (str): the error message msg (str): the error message
Raises: Raises:
SyntaxError MidasSyntaxError
""" """
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}") raise MidasSyntaxError(self.start_pos, msg)
def process(self) -> list[Token]: def process(self) -> list[Token]:
"""Scan tokens out of the source text """Scan tokens out of the source text
@@ -49,7 +56,7 @@ class Lexer(ABC):
list[Token]: all the tokens that could be scanned list[Token]: all the tokens that could be scanned
Raises: Raises:
SyntaxError: if a syntax error is found MidasSyntaxError: if a syntax error is found
""" """
self.scan_tokens() self.scan_tokens()
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position())) self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))

View File

@@ -1,15 +1,11 @@
from lexer.token import TokenType from lexer.token import TokenType
ANNOTATION_KEYWORDS: dict[str, TokenType] = { KEYWORDS: dict[str, TokenType] = {
"True": TokenType.TRUE,
"False": TokenType.FALSE,
"None": TokenType.NONE,
}
MIDAS_KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE, "type": TokenType.TYPE,
"op": TokenType.OP, "op": TokenType.OP,
"constraint": TokenType.CONSTRAINT, "predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,
"true": TokenType.TRUE, "true": TokenType.TRUE,
"false": TokenType.FALSE, "false": TokenType.FALSE,
"none": TokenType.NONE, "none": TokenType.NONE,

View File

@@ -1,5 +1,5 @@
from lexer.base import Lexer from lexer.base import Lexer
from lexer.keyword import MIDAS_KEYWORDS from lexer.keyword import KEYWORDS
from lexer.token import TokenType from lexer.token import TokenType
@@ -31,30 +31,32 @@ class MidasLexer(Lexer):
self.add_token( self.add_token(
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
) )
case "!": case "!" if self.match("="):
if self.match("="):
self.add_token(TokenType.BANG_EQUAL) self.add_token(TokenType.BANG_EQUAL)
else:
self.error("Unexpected single bang. Did you mean '!=' ?")
case ":": case ":":
self.add_token(TokenType.COLON) self.add_token(TokenType.COLON)
case ",": case ".":
self.add_token(TokenType.COMMA) self.add_token(TokenType.DOT)
case "_": case "&":
self.add_token(TokenType.AND)
case "?":
self.add_token(TokenType.QMARK)
# case ",":
# self.add_token(TokenType.COMMA)
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE) self.add_token(TokenType.UNDERSCORE)
case "+": case "-" if self.match(">"):
self.add_token(TokenType.PLUS) self.add_token(TokenType.ARROW)
# case "+":
# self.add_token(TokenType.PLUS)
case "-": case "-":
self.add_token(TokenType.MINUS) self.add_token(TokenType.MINUS)
case "*": # case "*":
self.add_token(TokenType.STAR) # self.add_token(TokenType.STAR)
case "/": case "/" if self.match("/"):
if self.match("/"):
self.scan_comment() self.scan_comment()
elif self.match("*"): case "/" if self.match("*"):
self.scan_comment_multiline() self.scan_comment_multiline()
else:
self.add_token(TokenType.SLASH)
case "\n": case "\n":
self.add_token(TokenType.NEWLINE) self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t": case " " | "\r" | "\t":
@@ -69,7 +71,7 @@ class MidasLexer(Lexer):
case _: case _:
if char.isdigit(): if char.isdigit():
self.scan_number() self.scan_number()
elif char.isalpha(): elif self.is_identifier_char(char, start=True):
self.scan_identifier() self.scan_identifier()
else: else:
self.error("Unexpected character") self.error("Unexpected character")
@@ -98,11 +100,11 @@ class MidasLexer(Lexer):
An identifier starts with a letter, followed by any number of An identifier starts with a letter, followed by any number of
alphanumerical characters or underscores alphanumerical characters or underscores
""" """
while self.peek().isalnum() or self.peek() == "_": while self.is_identifier_char(self.peek(), start=False):
self.advance() self.advance()
lexeme: str = self.source[self.start : self.idx] lexeme: str = self.source[self.start : self.idx]
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER) token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
self.add_token(token_type) self.add_token(token_type)
def scan_comment(self): def scan_comment(self):
@@ -129,3 +131,12 @@ class MidasLexer(Lexer):
if not self.is_at_end(): if not self.is_at_end():
self.advance() self.advance()
self.add_token(TokenType.COMMENT) self.add_token(TokenType.COMMENT)
def is_identifier_char(self, char: str, *, start: bool) -> bool:
if char == "_":
return True
if char.isalpha():
return True
if not start and char.isdigit():
return True
return False

View File

@@ -14,14 +14,18 @@ class TokenType(Enum):
LEFT_BRACE = auto() LEFT_BRACE = auto()
RIGHT_BRACE = auto() RIGHT_BRACE = auto()
COLON = auto() COLON = auto()
COMMA = auto() # COMMA = auto()
UNDERSCORE = auto() UNDERSCORE = auto()
ARROW = auto()
AND = auto()
QMARK = auto()
DOT = auto()
# Operators # Operators
PLUS = auto() # PLUS = auto()
MINUS = auto() MINUS = auto()
STAR = auto() # STAR = auto()
SLASH = auto() # SLASH = auto()
GREATER = auto() GREATER = auto()
GREATER_EQUAL = auto() GREATER_EQUAL = auto()
LESS = auto() LESS = auto()
@@ -40,7 +44,9 @@ class TokenType(Enum):
# Keywords # Keywords
TYPE = auto() TYPE = auto()
OP = auto() OP = auto()
CONSTRAINT = auto() PREDICATE = auto()
EXTEND = auto()
WHERE = auto()
# Misc # Misc
COMMENT = auto() COMMENT = auto()

View File

@@ -1,152 +0,0 @@
from typing import Optional
from core.ast.annotations import (
AnnotationStmt,
ConstraintExpr,
Expr,
LiteralExpr,
SchemaElementExpr,
SchemaExpr,
Stmt,
TypeExpr,
WildcardExpr,
)
from lexer.token import Token, TokenType
from parser.base import Parser
from parser.errors import ParsingError
class AnnotationParser(Parser):
"""A simple parser for custom type annotations"""
SYNC_BOUNDARY: set[TokenType] = set()
def parse(self) -> Optional[Stmt]:
stmt: Optional[Stmt] = None
try:
stmt = self.annotation()
except ParsingError:
self.synchronize()
if not self.is_at_end():
self.error(self.peek(), "Extra tokens")
return stmt
def synchronize(self):
"""Skip tokens until a synchronization boundary is found
This method allows gracefully recovering from a parse error
to a safe place and continue parsing
"""
self.advance()
while not self.is_at_end():
if self.peek().type in self.SYNC_BOUNDARY:
return
self.advance()
def annotation(self) -> AnnotationStmt:
"""Parse an annotation
An annotation is written as `Type` or `Type[Schema]`
Returns:
AnnotationStmt: the parsed annotation statement
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
schema: Optional[SchemaExpr] = None
if self.match(TokenType.LEFT_BRACKET):
schema = self.schema()
return AnnotationStmt(name=name, schema=schema)
def type_expr(self) -> TypeExpr:
"""Parse a type expression
Returns:
TypeExpr: the parsed type expression
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
constraints: list[ConstraintExpr] = []
while not self.is_at_end() and self.match(TokenType.PLUS):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
constraints.append(self.constraint_expr())
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
return TypeExpr(name=name, constraints=constraints)
def constraint_expr(self) -> ConstraintExpr:
"""Parse a type constraint
Returns:
ConstraintExpr: the parsed type constraint expression
"""
left: Expr = self.constraint_value()
op: Token = self.constraint_operator()
right: Expr = self.constraint_value()
return ConstraintExpr(left=left, op=op, right=right)
def constraint_value(self) -> Expr:
if self.match(TokenType.UNDERSCORE):
return WildcardExpr(self.previous())
return self.literal()
def literal(self) -> LiteralExpr:
if self.match(TokenType.FALSE):
return LiteralExpr(False)
if self.match(TokenType.TRUE):
return LiteralExpr(True)
if self.match(TokenType.NONE):
return LiteralExpr(None)
if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value)
raise self.error(self.peek(), "Expected literal")
def constraint_operator(self) -> Token:
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
return self.previous()
raise self.error(self.peek(), "Expected constraint operator")
def schema(self) -> SchemaExpr:
"""Parse a schema definition
A comma separated list of schema elements
Returns:
SchemaExpr: the parsed schema expression
"""
left: Token = self.previous()
elements: list[Expr] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
elements.append(self.schema_element())
if not self.check(TokenType.RIGHT_BRACKET):
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
return SchemaExpr(left=left, elements=elements, right=right)
def schema_element(self) -> SchemaElementExpr:
"""Parse a schema element
An anonymous element (`_`), a type, an untyped named column (`name: _`),
or a named column (`name: Type`)
Returns:
SchemaElementExpr: the parsed schema element expression
"""
if self.match(TokenType.UNDERSCORE):
return SchemaElementExpr(name=None, type=None)
if not self.check(TokenType.IDENTIFIER):
raise self.error(self.peek(), "Expected schema element")
name: Optional[Token] = None
type: Optional[TypeExpr] = None
if self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
if not self.match(TokenType.UNDERSCORE):
type = self.type_expr()
return SchemaElementExpr(name=name, type=type)

View File

@@ -1,16 +1,24 @@
from typing import Optional from typing import Optional
from core.ast.midas import ( from core.ast.midas import (
ConstraintExpr, BinaryExpr,
ConstraintStmt, ComplexTypeStmt,
Expr, Expr,
ExtendStmt,
GetExpr,
GroupingExpr,
LiteralExpr, LiteralExpr,
LogicalExpr,
OpStmt, OpStmt,
PredicateStmt,
PropertyStmt, PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt, Stmt,
TypeBodyExpr, TemplateExpr,
TypeExpr, TypeExpr,
TypeStmt, UnaryExpr,
VariableExpr,
WildcardExpr, WildcardExpr,
) )
from lexer.token import Token, TokenType from lexer.token import Token, TokenType
@@ -21,7 +29,12 @@ from parser.errors import ParsingError
class MidasParser(Parser): class MidasParser(Parser):
"""A simple parser for midas type definitions""" """A simple parser for midas type definitions"""
SYNC_BOUNDARY: set[TokenType] = {TokenType.TYPE, TokenType.OP, TokenType.CONSTRAINT} SYNC_BOUNDARY: set[TokenType] = {
TokenType.TYPE,
TokenType.OP,
TokenType.EXTEND,
TokenType.PREDICATE,
}
def parse(self) -> list[Stmt]: def parse(self) -> list[Stmt]:
statements: list[Stmt] = [] statements: list[Stmt] = []
@@ -58,16 +71,16 @@ class MidasParser(Parser):
try: try:
if self.match(TokenType.TYPE): if self.match(TokenType.TYPE):
return self.type_declaration() return self.type_declaration()
if self.match(TokenType.OP): if self.match(TokenType.EXTEND):
return self.op_declaration() return self.extend_declaration()
if self.match(TokenType.CONSTRAINT): if self.match(TokenType.PREDICATE):
return self.constraint_declaration() return self.predicate_declaration()
raise self.error(self.peek(), "Unexpected token") raise self.error(self.peek(), "Unexpected token")
except ParsingError: except ParsingError:
self.synchronize() self.synchronize()
return None return None
def type_declaration(self) -> TypeStmt: def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
"""Parse a type declaration """Parse a type declaration
A type declaration is written `type Name<TypeExpr, ...>` optionally followed by a brace-wrapped body A type declaration is written `type Name<TypeExpr, ...>` optionally followed by a brace-wrapped body
@@ -76,19 +89,28 @@ class MidasParser(Parser):
TypeStmt: the parsed type declaration statement TypeStmt: the parsed type declaration statement
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
self.consume(TokenType.LESS, "Expected '<' after type name") template: Optional[TemplateExpr] = None
bases: list[TypeExpr] = [] if self.check(TokenType.LEFT_BRACKET):
while not self.check(TokenType.GREATER) and not self.is_at_end(): template = self.template_expr()
bases.append(self.type_expr())
if not self.check(TokenType.GREATER):
self.consume(TokenType.COMMA, "Expected ',' between type bases")
self.consume(TokenType.GREATER, "Expected '>' after base type")
body: Optional[TypeBodyExpr] = None if self.match(TokenType.LEFT_PAREN):
base: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return SimpleTypeStmt(
name=name, template=template, base=base, constraint=constraint
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(name=name, template=template, properties=properties)
if self.check(TokenType.LEFT_BRACE): def template_expr(self) -> TemplateExpr:
body = self.type_body_expr() self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
return TypeStmt(name=name, bases=bases, body=body) type: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
return TemplateExpr(type=type)
def type_expr(self) -> TypeExpr: def type_expr(self) -> TypeExpr:
"""Parse a type expression """Parse a type expression
@@ -97,33 +119,66 @@ class MidasParser(Parser):
TypeExpr: the parsed type expression TypeExpr: the parsed type expression
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
constraints: list[ConstraintExpr] = [] template: Optional[TemplateExpr] = None
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
optional: bool = self.match(TokenType.QMARK)
return TypeExpr(name=name, template=template, optional=optional)
while not self.is_at_end() and self.match(TokenType.PLUS): def simple_type_expr(self) -> SimpleTypeExpr:
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
constraints.append(self.constraint_expr()) optional: bool = self.match(TokenType.QMARK)
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint") return SimpleTypeExpr(name=name, optional=optional)
return TypeExpr(name=name, constraints=constraints) def constraint(self) -> Expr:
return self.and_()
def constraint_expr(self) -> ConstraintExpr: def and_(self) -> Expr:
"""Parse a type constraint expr: Expr = self.equality()
while self.match(TokenType.AND):
operator: Token = self.previous()
right: Expr = self.equality()
expr = LogicalExpr(left=expr, operator=operator, right=right)
return expr
Returns: def equality(self) -> Expr:
ConstraintExpr: the parsed type constraint expression expr: Expr = self.comparison()
""" while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous()
right: Expr = self.comparison()
expr = BinaryExpr(left=expr, operator=operator, right=right)
return expr
left: Expr = self.constraint_value() def comparison(self) -> Expr:
op: Token = self.constraint_operator() expr: Expr = self.unary()
right: Expr = self.constraint_value() while self.match(
return ConstraintExpr(left=left, op=op, right=right) TokenType.LESS,
TokenType.LESS_EQUAL,
TokenType.GREATER,
TokenType.GREATER_EQUAL,
):
operator: Token = self.previous()
right: Expr = self.unary()
expr = BinaryExpr(left=expr, operator=operator, right=right)
return expr
def constraint_value(self) -> Expr: def unary(self) -> Expr:
if self.match(TokenType.UNDERSCORE): if self.match(TokenType.MINUS):
return WildcardExpr(self.previous()) operator: Token = self.previous()
return self.literal() right: Expr = self.unary()
return UnaryExpr(operator=operator, right=right)
return self.reference()
def literal(self) -> LiteralExpr: def reference(self) -> Expr:
expr: Expr = self.primary()
while self.match(TokenType.DOT):
name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'"
)
expr = GetExpr(expr=expr, name=name)
return expr
def primary(self) -> Expr:
if self.match(TokenType.FALSE): if self.match(TokenType.FALSE):
return LiteralExpr(False) return LiteralExpr(False)
if self.match(TokenType.TRUE): if self.match(TokenType.TRUE):
@@ -134,35 +189,34 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER): if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value) return LiteralExpr(self.previous().value)
raise self.error(self.peek(), "Expected literal") if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous())
def constraint_operator(self) -> Token: if self.match(TokenType.UNDERSCORE):
if self.match( return WildcardExpr(self.previous())
TokenType.LESS,
TokenType.LESS_EQUAL,
TokenType.GREATER,
TokenType.GREATER_EQUAL,
TokenType.EQUAL_EQUAL,
TokenType.BANG_EQUAL,
):
return self.previous()
raise self.error(self.peek(), "Expected constraint operator")
def type_body_expr(self) -> TypeBodyExpr: if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.constraint()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(expr)
raise self.error(self.peek(), "Expected expression")
def type_properties(self) -> list[PropertyStmt]:
"""Parse a type definition body """Parse a type definition body
A type definition body is a set of whitespace-separated A type definition body is a set of whitespace-separated
property statements enclosed in curly braces property statements enclosed in curly braces
Returns: Returns:
TypeBodyExpr: the parsed type body expression TypeBodyStmt: the parsed type body expression
""" """
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body") self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
properties: list[PropertyStmt] = [] properties: list[PropertyStmt] = []
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
properties.append(self.property_stmt()) properties.append(self.property_stmt())
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body") self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return TypeBodyExpr(properties=properties) return properties
def property_stmt(self) -> PropertyStmt: def property_stmt(self) -> PropertyStmt:
"""Parse a property statement """Parse a property statement
@@ -175,7 +229,19 @@ class MidasParser(Parser):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name") self.consume(TokenType.COLON, "Expected ':' after property name")
type: TypeExpr = self.type_expr() type: TypeExpr = self.type_expr()
return PropertyStmt(name=name, type=type) constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return PropertyStmt(name=name, type=type, constraint=constraint)
def extend_declaration(self) -> ExtendStmt:
type: TypeExpr = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
return ExtendStmt(type=type, operations=operations)
def op_declaration(self) -> OpStmt: def op_declaration(self) -> OpStmt:
"""Parse an operation definition """Parse an operation definition
@@ -185,25 +251,19 @@ class MidasParser(Parser):
Returns: Returns:
OpStmt: the parsed operation statement OpStmt: the parsed operation statement
""" """
self.consume(TokenType.LESS, "Expected '<' before first type") self.consume(TokenType.OP, "Expected 'op' keyword")
left: TypeExpr = self.type_expr()
self.consume(TokenType.GREATER, "Expected '>' after first type")
op: Token = self.advance() name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
operand: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
self.consume(TokenType.LESS, "Expected '<' before second type") self.consume(TokenType.ARROW, "Expected '->' before result type")
right: TypeExpr = self.type_expr()
self.consume(TokenType.GREATER, "Expected '>' after second type")
self.consume(TokenType.EQUAL, "Expected '=' after second type")
self.consume(TokenType.LESS, "Expected '<' before result type")
result: TypeExpr = self.type_expr() result: TypeExpr = self.type_expr()
self.consume(TokenType.GREATER, "Expected '>' after result type")
return OpStmt(left=left, op=op, right=right, result=result) return OpStmt(name=name, operand=operand, result=result)
def constraint_declaration(self) -> ConstraintStmt: def predicate_declaration(self) -> PredicateStmt:
"""Parse a type constraint declaration """Parse a type constraint declaration
A constraint is written `constraint Name = constraint_expression` A constraint is written `constraint Name = constraint_expression`
@@ -211,7 +271,12 @@ class MidasParser(Parser):
Returns: Returns:
ConstraintStmt: the parsed constraint declaration statement ConstraintStmt: the parsed constraint declaration statement
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected constraint name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
self.consume(TokenType.EQUAL, "Expected '=' after constraint name") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
constraint: ConstraintExpr = self.constraint_expr() subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
return ConstraintStmt(name=name, constraint=constraint) self.consume(TokenType.COLON, "Expected ':' after subject name")
type: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
return PredicateStmt(name=name, subject=subject, type=type, condition=condition)

View File

@@ -1,26 +1,35 @@
identifier ::= '[a-zA-Z][a-zA-Z_]*' // W3C EBNF syntax definition for Midas
Identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
integer ::= '\d+' Integer ::= '\d+'
number ::= integer ["." integer] Number ::= "-"? Integer ("." Integer)?
boolean ::= "False" | "True" Boolean ::= "False" | "True"
none ::= "None" None ::= "None"
value ::= number | boolean | none Value ::= Number | Boolean | None
lambda-value ::= "_" | value
lambda-operator ::= ">" | "<" | ">=" | "<=" | "==" | "!="
lambda ::= lambda-value lambda-operator lambda-value
constraint ::= identifier | "(" lambda ")" ComparisonOp ::= ">" | "<" | ">=" | "<="
base-type ::= identifier EqualityOp ::= "==" | "!="
type ::= base-type { "+" constraint }
type-property ::= 'identifier' ":" 'type' Grouping ::= "(" Constraint ")"
type-body ::= "{" { 'type-property' } "}" Primary ::= "_" | Value | Identifier | Grouping
Reference ::= Primary ("." Identifier)*
Unary ::= "-"? Unary | Reference
Comparison ::= Unary (ComparisonOp Unary)*
Equality ::= Comparison (EqualityOp Comparison)*
Constraint ::= Equality ("&" Equality)*
operation-type ::= "<" 'type' ">" SimpleType ::= Identifier "?"?
Template ::= "[" Type "]"
Type ::= Identifier Template? "?"?
type-statement ::= "type" 'identifier' "<" 'type' {"," 'type'} ">" ['type-body'] TypeProperty ::= Identifier ":" Type ("where" Constraints)?
operation-statement ::= "op" 'operation-type' 'operator' 'operation-type' "=" 'operation-type' ComplexTypeBody ::= "{" TypeProperty* "}"
constraint-statement ::= "constraint" 'identifier' "=" 'lambda' OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
ExtendBody ::= "{" OpDefinition* "}"
statement ::= type-statement | operation-statement | constraint-statement TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
ExtendStatement ::= "extend" Type ExtendBody
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
Statement ::= TypeStatement | ExtendStatement | PredicateStatement

View File

@@ -1,4 +1,11 @@
#import "@preview/fervojo:0.1.1": render #import "@preview/fervojo:0.1.1": default-css, render
#let extra-css = ```css
svg.railroad .terminal rect {
fill: #F7DCD4;
}
```
#let css = default-css() + bytes(extra-css.text)
#let value = ``` #let value = ```
{[`value` < {[`value` <
@@ -8,90 +15,157 @@
>]} >]}
``` ```
#let constraint = ``` #let grouping = ```
{[`constraint` <"_", 'value'> <">", "<", ">=", "<=", "==", "!="> <"_", 'value'>]} {[`grouping` "(" 'constraint' ")"]}
``` ```
#let type-with-constraints = ``` #let primary = ```
{[`type-with-constraints` 'identifier' <!, ["+" "(" 'constraint' ")"] * !>]} {[`primary` <"_", 'value', 'identifier', 'grouping'>]}
```
#let reference = ```
{[`reference` 'primary' <!, ["." 'identifier']*!>]}
```
#let unary = ```
{[`unary` <[<!, "-"> 'unary'], 'reference'>]}
```
#let comparison = ```
{[`comparison` 'unary'*<">", "<", ">=", "<=">]}
```
#let equality = ```
{[`equality` 'comparison'*<"==", "!=">]}
```
#let constraint = ```
{[`constraint` 'equality'*"&"]}
```
#let simple-type = ```
{[`simple-type` 'identifier' <!, "?">]}
```
#let template = ```
{[`template` "[" 'type' "]"]}
```
#let type = ```
{[`type` 'identifier' <!, 'template'> <!, "?">]}
``` ```
#let type-property = ``` #let type-property = ```
{[`type-property` 'identifier' ":" 'type-with-constraints']} {[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
``` ```
#let type-body = ``` #let type-body = ```
{[`type-body` "{" <!, 'type-property'*!> "}"]} {[`type-body` "{" <!, 'type-property'*!> "}"]}
``` ```
#let operation-type = ```
{[`operation-type` "<" 'type-with-constraints' ">"]}
```
#let type-statement = ``` #let type-statement = ```
{[`type-statement` "type" 'identifier' "<" 'type-with-constraints'*"," ">" <!, 'type-body'>]} {[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
``` ```
#let operation-statement = ``` #let op-definition = ```
{[`operation-statement` "op" 'operation-type' "operator" 'operation-type' "=" 'operation-type']} {[`op-definition` "op" 'identifier' "(" 'type' ")" "->" 'type']}
``` ```
#let constraint-statement = ``` #let extend-statement = ```
{[`constraint-statement` "constraint" 'identifier' "=" 'constraint']} {[`extend-statement` "extend" 'type' "{" <!, 'op-definition'*!> "}"]}
```
#let predicate-statement = ```
{[`predicate-statement` "predicate" 'identifier' "(" 'identifier' ":" 'type' ")" "=" 'constraint']}
``` ```
#let statement = ``` #let statement = ```
{[`statement` <'type-statement', 'operation-statement', 'constraint-statement'>]} {[`statement` <'type-statement', 'extend-statement', 'predicate-statement'>]}
``` ```
#let rules = ( #let rules = (
value, value: value,
constraint, grouping: grouping,
type-with-constraints, primary: primary,
type-property, reference: reference,
type-body, unary: unary,
operation-type, comparison: comparison,
type-statement, equality: equality,
operation-statement, constraint: constraint,
constraint-statement, simple-type: simple-type,
statement, template: template,
type: type,
type-property: type-property,
type-body: type-body,
type-statement: type-statement,
op-definition: op-definition,
extend-statement: extend-statement,
predicate-statement: predicate-statement,
statement: statement,
)
#let inline = (
"grouping",
"value",
"template",
"simple-type",
"type-property",
"type-body",
"op-definition",
"type-statement",
"extend-statement",
"predicate-statement",
) )
#set text(font: "Source Sans 3") #set text(font: "Source Sans 3")
= Midas type definition syntax #title[Midas type definition syntax]
#for rule in rules { = Outline
render(rule)
}
/* #box(
#let by-name = ( columns(
value: value, 2,
constraint: constraint, outline(title: none),
type-with-constraints: type-with-constraints, ),
type-property: type-property, height: 9cm,
type-body: type-body, stroke: 1pt,
operation-type: operation-type, inset: 1em,
type-statement: type-statement,
operation-statement: operation-statement,
constraint-statement: constraint-statement,
) )
= Statements and expressions
#for (name, rule) in rules.pairs().rev() {
[== #name]
render(rule, css: css)
}
#let substitute(base-rule) = { #let substitute(base-rule) = {
let new-rule = base-rule let new-rule = base-rule
for (key, rule) in by-name.pairs() { for name in inline {
new-rule = new-rule.replace("'" + key + "'", rule.text.slice(1, -1)) let rule = rules.at(name)
let replacement = rule.text.slice(1, -1).replace(regex("\[`.*?`"), "[")
replacement = "[" + replacement + "#`" + name + "`]"
new-rule = new-rule.replace(
"'" + name + "'",
replacement,
)
} }
if new-rule != base-rule { if new-rule != base-rule {
new-rule = substitute(new-rule) new-rule = substitute(new-rule)
} }
return new-rule.replace(regex("`.*?`"), "") return new-rule
} }
#let combined = raw(substitute(statement.text))
#set page(flipped: true) #set page(flipped: true)
#render(combined)
*/ = Combined rules
#for (name, rule) in rules.pairs() {
if not name in inline {
[== #name]
let combined = substitute(rule.text)
render(raw(combined), css: css)
//raw(block: true, combined)
}
}

29
test.py
View File

@@ -1,40 +1,21 @@
import importlib import json
from pathlib import Path from pathlib import Path
from core.ast.printer import AnnotationAstPrinter, MidasAstPrinter from core.ast.printer import MidasAstPrinter
from lexer.annotations import AnnotationLexer
from lexer.midas import MidasLexer from lexer.midas import MidasLexer
from lexer.token import Token from lexer.token import Token
from parser.annotations import AnnotationParser
from parser.midas import MidasParser from parser.midas import MidasParser
def test_annotation():
# Frame annotation
mod = importlib.import_module("examples.00_syntax_prototype.01_simple_types")
annotation: str = mod.__annotations__["df"]
lexer: AnnotationLexer = AnnotationLexer(annotation, "01_simple_types.py")
tokens: list[Token] = lexer.process()
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
parser = AnnotationParser(tokens)
parsed = parser.parse()
print(parsed)
for err in parser.errors:
print(err.get_report())
printer = AnnotationAstPrinter()
if parsed is not None:
print(printer.print(parsed))
def test_midas(): def test_midas():
# Midas type definitions # Midas type definitions
path: Path = Path("examples") / "00_syntax_prototype" / "02_custom_types.midas" path: Path = Path("examples") / "00_syntax_prototype" / "03_custom_types_v2.midas"
definitions: str = path.read_text() definitions: str = path.read_text()
midas_lexer: MidasLexer = MidasLexer(definitions, path.name) midas_lexer: MidasLexer = MidasLexer(definitions, path.name)
tokens: list[Token] = midas_lexer.process() tokens: list[Token] = midas_lexer.process()
# print([f"{t.type.name}('{t.lexeme}')" for t in tokens]) # print([f"{t.type.name}('{t.lexeme}')" for t in tokens])
with open("tokens.json", "w") as f:
json.dump([f"{t.type.name}('{t.lexeme}')" for t in tokens], f, indent=4)
parser = MidasParser(tokens) parser = MidasParser(tokens)
parsed = parser.parse() parsed = parser.parse()

204
tester.py Normal file
View File

@@ -0,0 +1,204 @@
from __future__ import annotations
import argparse
import difflib
import json
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Iterator, Optional
from core.ast.json_serializer import AstJsonSerializer
from core.ast.midas import Stmt
from lexer.base import MidasSyntaxError
from lexer.midas import MidasLexer
from lexer.token import Token
from parser.midas import MidasParser
DEFAULT_BASE_DIR: Path = Path() / "tests"
@dataclass
class CaseResult:
tokens: Optional[list[dict]] = None
stmts: Optional[list[dict]] = None
errors: list[dict] = field(default_factory=list)
def dumps(self) -> str:
return json.dumps(asdict(self), indent=2)
class Tester:
"""A test runner to check for regressions in the lexer and parser"""
def __init__(self, base_dir: Path):
self.base_dir: Path = base_dir
def _list_tests(self) -> list[Path]:
return list(self.base_dir.rglob("*.midas"))
def run_all_tests(self) -> bool:
paths: list[Path] = self._list_tests()
return self.run_tests(paths)
def run_tests(self, tests: list[Path]) -> bool:
rule: str = "-" * 80
n: int = len(tests)
successes: int = 0
failures: int = 0
print(rule)
for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test}")
success: bool = self._run_test(test)
if success:
successes += 1
else:
failures += 1
print(rule)
print(f"Success: {successes}/{n}")
print(f"Failed: {failures}/{n}")
print(rule)
return failures == 0
def _run_test(self, path: Path) -> bool:
result: CaseResult = self._exec_case(path)
result_path: Path = self._result_path(path)
expected: str = result_path.read_text()
actual: str = result.dumps()
if expected == actual:
return True
diff = difflib.unified_diff(
expected.splitlines(keepends=True),
actual.splitlines(keepends=True),
fromfile="Snapshot",
tofile="Result",
)
self._print_diff(diff)
return False
def _exec_case(self, path: Path) -> CaseResult:
if not path.exists():
raise FileNotFoundError(f"Could not find test '{path}'")
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
result: CaseResult = CaseResult()
content: str = path.read_text()
lexer: MidasLexer = MidasLexer(content)
tokens: list[Token] = []
try:
tokens = lexer.process()
result.tokens = [
{
"type": token.type.name,
"lexeme": token.lexeme,
"line": token.position.line,
"column": token.position.column,
}
for token in tokens
]
except MidasSyntaxError as e:
result.errors.append(
{
"type": "SyntaxError",
"line": e.pos.line,
"column": e.pos.column,
"message": e.message,
}
)
return result
parser: MidasParser = MidasParser(tokens)
stmts: list[Stmt] = parser.parse()
result.stmts = AstJsonSerializer().serialize(stmts)
result.errors.extend(
[
{
"line": e.token.position.line,
"column": e.token.position.column,
"message": e.message,
}
for e in parser.errors
]
)
return result
def update_all_tests(self):
paths: list[Path] = self._list_tests()
return self.update_tests(paths)
def update_tests(self, tests: list[Path]):
updated: int = 0
for test in tests:
if self._update_test(test):
updated += 1
print(f"Updated {updated}/{len(tests)} tests")
def _update_test(self, path: Path) -> bool:
result: CaseResult = self._exec_case(path)
result_path: Path = self._result_path(path)
current: str = result_path.read_text()
new: str = result.dumps()
if current == new:
return False
result_path.write_text(new)
return True
def _result_path(self, test_path: Path) -> Path:
return test_path.parent / (test_path.name + ".ref.json")
def _print_diff(self, diff: Iterator[str]):
for line in diff:
if line.startswith("+") and not line.startswith("+++"):
print(f"\033[92m{line}\033[0m", end="")
elif line.startswith("-") and not line.startswith("---"):
print(f"\033[91m{line}\033[0m", end="")
else:
print(line, end="")
print()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-D",
"--base-dir",
help="Base directory containing test files",
type=Path,
default=DEFAULT_BASE_DIR,
)
subparsers = parser.add_subparsers(dest="subcommand")
update = subparsers.add_parser("update")
update.add_argument("-a", "--all", action="store_true")
update.add_argument("FILE", type=Path, nargs="*")
run = subparsers.add_parser("run")
run.add_argument("-a", "--all", action="store_true")
run.add_argument("FILE", type=Path, nargs="*")
args = parser.parse_args()
tester: Tester = Tester(args.base_dir)
match args.subcommand:
case "update":
if args.all:
tester.update_all_tests()
else:
tester.update_tests(args.FILE)
case "run":
success: bool
if args.all:
success = tester.run_all_tests()
else:
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,57 @@
// Simple custom type derived from float
type Custom(float)
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T)
// Complex custom type, containing two values accessible through properties
type GeoLocation {
lat: Latitude
lon: Longitude
}
// Define operations on our custom type
extend GeoLocation {
// This type is compatible with the `-` operation with another GeoLocation
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
// in a Difference of GeoLocations
op __sub__(GeoLocation) -> Difference[GeoLocation]
}
// For complex generics, you need to specify how the genericity the properties
// are handled
type Difference[GeoLocation] {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
// Simple operation defined on our custom types
extend Latitude {
op __sub__(Latitude) -> Difference[Latitude]
}
extend Longitude {
op __sub__(Longitude) -> Difference[Longitude]
}
// Predefined custom predicates that can be referenced in other definitions
predicate Positive(v: float) = v >= 0
predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
// Property referencing a predicate
height: float where StrictlyPositive
home: GeoLocation
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,129 +0,0 @@
from typing import Any
import pytest
from lexer.annotations import AnnotationLexer
from lexer.token import Token, TokenType
def scan(source: str) -> list[Token]:
return AnnotationLexer(source).process()
def assert_n_tokens(tokens: list[Token], n: int):
assert len(tokens) == n + 1
assert tokens[-1].type == TokenType.EOF
@pytest.mark.parametrize(
"src,expected",
[
("(", TokenType.LEFT_PAREN),
(")", TokenType.RIGHT_PAREN),
("[", TokenType.LEFT_BRACKET),
("]", TokenType.RIGHT_BRACKET),
(":", TokenType.COLON),
(",", TokenType.COMMA),
("_", TokenType.UNDERSCORE),
],
)
def test_punctuation(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected",
[
("+", TokenType.PLUS),
(">", TokenType.GREATER),
(">=", TokenType.GREATER_EQUAL),
("<", TokenType.LESS),
("<=", TokenType.LESS_EQUAL),
("=", TokenType.EQUAL),
("==", TokenType.EQUAL_EQUAL),
("!=", TokenType.BANG_EQUAL),
],
)
def test_operators(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected",
[
("a", TokenType.IDENTIFIER),
("foo", TokenType.IDENTIFIER),
("foo1", TokenType.IDENTIFIER),
("foo_", TokenType.IDENTIFIER),
("foo_bar1_baz2", TokenType.IDENTIFIER),
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
("True", TokenType.TRUE),
("False", TokenType.FALSE),
("None", TokenType.NONE),
],
)
def test_identifiers_keywords(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected",
[
("#", TokenType.COMMENT),
("# This is a comment", TokenType.COMMENT),
(" ", TokenType.WHITESPACE),
("\t", TokenType.WHITESPACE),
("\r", TokenType.WHITESPACE),
(" \t \t", TokenType.WHITESPACE),
("\n", TokenType.NEWLINE),
],
)
def test_misc(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected_type,expected_value",
[
("0", TokenType.NUMBER, 0),
("0.0", TokenType.NUMBER, 0),
("1234.56", TokenType.NUMBER, 1234.56),
],
)
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected_type
assert tokens[0].value == expected_value
def test_single_bang_error():
with pytest.raises(SyntaxError):
scan("!")
@pytest.mark.parametrize(
"src",
[
"-",
"*",
"/",
"{",
"}",
"@",
'"',
"'",
".",
],
)
def test_unexpected_character(src: str):
with pytest.raises(SyntaxError):
scan(src)

View File

@@ -1,129 +0,0 @@
from typing import Any
import pytest
from lexer.midas import MidasLexer
from lexer.token import Token, TokenType
def scan(source: str) -> list[Token]:
return MidasLexer(source).process()
def assert_n_tokens(tokens: list[Token], n: int):
assert len(tokens) == n + 1
assert tokens[-1].type == TokenType.EOF
@pytest.mark.parametrize(
"src,expected",
[
("(", TokenType.LEFT_PAREN),
(")", TokenType.RIGHT_PAREN),
("[", TokenType.LEFT_BRACKET),
("]", TokenType.RIGHT_BRACKET),
("{", TokenType.LEFT_BRACE),
("}", TokenType.RIGHT_BRACE),
(":", TokenType.COLON),
(",", TokenType.COMMA),
("_", TokenType.UNDERSCORE),
],
)
def test_punctuation(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected",
[
("+", TokenType.PLUS),
("-", TokenType.MINUS),
("*", TokenType.STAR),
("/", TokenType.SLASH),
(">", TokenType.GREATER),
(">=", TokenType.GREATER_EQUAL),
("<", TokenType.LESS),
("<=", TokenType.LESS_EQUAL),
("=", TokenType.EQUAL),
("==", TokenType.EQUAL_EQUAL),
("!=", TokenType.BANG_EQUAL),
],
)
def test_operators(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected",
[
("a", TokenType.IDENTIFIER),
("foo", TokenType.IDENTIFIER),
("foo1", TokenType.IDENTIFIER),
("foo_", TokenType.IDENTIFIER),
("foo_bar1_baz2", TokenType.IDENTIFIER),
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
("true", TokenType.TRUE),
("false", TokenType.FALSE),
("none", TokenType.NONE),
],
)
def test_identifiers_keywords(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected",
[
("// This is a comment", TokenType.COMMENT),
("/* This is a comment */", TokenType.COMMENT),
(" ", TokenType.WHITESPACE),
("\t", TokenType.WHITESPACE),
("\r", TokenType.WHITESPACE),
(" \t \t", TokenType.WHITESPACE),
("\n", TokenType.NEWLINE),
],
)
def test_misc(src: str, expected: TokenType):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected
@pytest.mark.parametrize(
"src,expected_type,expected_value",
[
("0", TokenType.NUMBER, 0),
("0.0", TokenType.NUMBER, 0),
("1234.56", TokenType.NUMBER, 1234.56),
],
)
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
tokens: list[Token] = scan(src)
assert_n_tokens(tokens, 1)
assert tokens[0].type == expected_type
assert tokens[0].value == expected_value
def test_single_bang_error():
with pytest.raises(SyntaxError):
scan("!")
@pytest.mark.parametrize(
"src",
[
"@",
'"',
"'",
".",
],
)
def test_unexpected_character(src: str):
with pytest.raises(SyntaxError):
scan(src)

View File

@@ -1,130 +0,0 @@
from typing import Optional
import pytest
from core.ast.annotations import (
AnnotationStmt,
ConstraintExpr,
Expr,
LiteralExpr,
SchemaElementExpr,
SchemaExpr,
Stmt,
TypeExpr,
WildcardExpr,
)
from lexer.annotations import AnnotationLexer
from lexer.position import Position
from lexer.token import Token
from parser.annotations import AnnotationParser
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
def serialize(self, stmt: Stmt):
return stmt.accept(self)
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str:
schema: str = ""
if stmt.schema is not None:
schema = " " + stmt.schema.accept(self)
return f"(annotation {stmt.name.lexeme}{schema})"
def visit_schema_expr(self, expr: SchemaExpr) -> str:
elements: list[str] = [elmt.accept(self) for elmt in expr.elements]
return f"(schema {' '.join(elements)})"
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
name: str = expr.name.lexeme if expr.name is not None else "_"
type: str = expr.type.accept(self) if expr.type is not None else "_"
return f"({name} {type})"
def visit_type_expr(self, expr: TypeExpr) -> str:
res: str = f"({expr.name.lexeme}"
for constraint in expr.constraints:
res += " " + constraint.accept(self)
res += ")"
return res
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
return "(_)"
def visit_literal_expr(self, expr: LiteralExpr) -> str:
return f"({expr.value})"
def parse(source: str) -> Optional[Stmt]:
tokens: list[Token] = AnnotationLexer(source).process()
return AnnotationParser(tokens).parse()
def must_parse(source: str) -> Stmt:
stmt: Optional[Stmt] = parse(source)
assert stmt is not None
return stmt
def ast_str(source: str) -> str:
stmt: Stmt = must_parse(source)
return AstSerializer().serialize(stmt)
@pytest.mark.parametrize(
"src,expected",
[
("Type", "(annotation Type)"),
("Type[]", "(annotation Type (schema ))"),
(
"""
Frame[
verified: bool,
birth_year: int,
height: float + ( _ > 0 ) + ( _ < 250 ),
name: str,
date: datetime,
float, # unnamed
unknown: _, # untyped
_ # unnamed and untyped
]
""",
"(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))",
),
],
)
def test_expressions(src: str, expected: str):
assert ast_str(src) == expected
@pytest.mark.parametrize(
"src,pos,should_fail",
[
("", (1, 1), True),
("42", (1, 1), True),
("True", (1, 1), True),
("Type[", (1, 6), True),
("Type[] Type2", (1, 8), False),
("Type[bool:]", (1, 11), True),
("Type[3]", (1, 6), True),
("Type[bool float]", (1, 11), True),
("Type[bool (_ < 2)]", (1, 11), True),
("Type[bool + _ < 2)]", (1, 13), True),
("Type[bool + (_ < 2]", (1, 19), True),
("Type[bool + (< 2)]", (1, 14), True),
("Type[bool + (_ + 2)]", (1, 16), True),
("Type[bool + (Foo + Bar)]", (1, 14), True),
# ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF
("Type[bool, Type[]]", (1, 16), True),
("Type[foo: 3]", (1, 11), True),
],
)
def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool):
tokens: list[Token] = AnnotationLexer(src).process()
parser: AnnotationParser = AnnotationParser(tokens)
stmt: Optional[Stmt] = parser.parse()
if should_fail:
assert stmt is None
assert len(parser.errors) != 0
error_pos: Position = parser.errors[0].token.position
assert (error_pos.line, error_pos.column) == pos

View File

@@ -1,202 +0,0 @@
import textwrap
import pytest
from core.ast.midas import (
ConstraintExpr,
ConstraintStmt,
Expr,
LiteralExpr,
OpStmt,
PropertyStmt,
Stmt,
TypeBodyExpr,
TypeExpr,
TypeStmt,
WildcardExpr,
)
from lexer.midas import MidasLexer
from lexer.position import Position
from lexer.token import Token
from parser.midas import MidasParser
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
def serialize(self, stmt: Stmt):
return stmt.accept(self)
def visit_type_stmt(self, stmt: TypeStmt) -> str:
res: str = f"(type_def {stmt.name.lexeme}"
for base in stmt.bases:
res += " " + base.accept(self)
if stmt.body is not None:
res += " " + stmt.body.accept(self)
res += ")"
return res
def visit_type_expr(self, expr: TypeExpr) -> str:
res: str = f"({expr.name.lexeme}"
for constraint in expr.constraints:
res += " " + constraint.accept(self)
res += ")"
return res
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
return "(_)"
def visit_literal_expr(self, expr: LiteralExpr) -> str:
return f"({expr.value})"
def visit_type_body_expr(self, expr: TypeBodyExpr) -> str:
res: str = "(body"
for prop in expr.properties:
res += " " + prop.accept(self)
res += ")"
return res
def visit_property_stmt(self, stmt: PropertyStmt) -> str:
return f"(property {stmt.name.lexeme} {stmt.type.accept(self)})"
def visit_op_stmt(self, stmt: OpStmt) -> str:
left: str = stmt.left.accept(self)
right: str = stmt.right.accept(self)
result: str = stmt.result.accept(self)
return f"(op_def {left} {stmt.op.lexeme} {right} {result})"
def visit_constraint_stmt(self, stmt: ConstraintStmt) -> str:
return f"(constraint_def {stmt.name.lexeme} {stmt.constraint.accept(self)})"
def parse(source: str) -> list[Stmt]:
tokens: list[Token] = MidasLexer(source).process()
return MidasParser(tokens).parse()
def ast_str(source: str) -> list[str]:
stmts: list[Stmt] = parse(source)
return [AstSerializer().serialize(stmt) for stmt in stmts]
@pytest.mark.parametrize(
"src,expected",
[
("type Foo<>", "(type_def Foo)"),
("type Foo<Bar>", "(type_def Foo (Bar))"),
("type Foo<Bar, Baz>", "(type_def Foo (Bar) (Baz))"),
(
"type Foo<Bar + (_ < 2), Baz>",
"(type_def Foo (Bar (constraint (_) < (2.0))) (Baz))",
),
(
"""
type Foo<> {
foo: Bar
}
""",
"(type_def Foo (body (property foo (Bar))))",
),
(
"""
type Foo<> {
foo: Bar + (_ != none)
foo2: Bar2 + (0 <= _) + (_ <= 100)
}
""",
"(type_def Foo (body (property foo (Bar (constraint (_) != (None)))) (property foo2 (Bar2 (constraint (0.0) <= (_)) (constraint (_) <= (100.0))))))",
),
("op <A> + <B> = <C>", "(op_def (A) + (B) (C))"),
(
"op <A + (_ < 100)> + <B + (_ < 100)> = <C + (_ < 200)>",
"(op_def (A (constraint (_) < (100.0))) + (B (constraint (_) < (100.0))) (C (constraint (_) < (200.0))))",
),
(
"constraint Positive = _ >= 0",
"(constraint_def Positive (constraint (_) >= (0.0)))",
),
],
)
def test_expressions(src: str, expected: str | list[str]):
if isinstance(expected, str):
expected = [expected]
assert ast_str(src) == expected
@pytest.mark.parametrize(
"src,pos",
[
###
# Misc
###
("42", (1, 1)),
("true", (1, 1)),
("foo", (1, 1)),
###
# Type statements
###
("type", (1, 5)),
("type true", (1, 6)),
("type Foo", (1, 9)),
("type Foo<1>", (1, 10)),
# ("type Foo<float,>", (1, 16)), # trailing comma is accepted, TODO: update parser or EBNF
("type Foo<float, 1>", (1, 17)),
("type Foo<float", (1, 15)),
("type Foo<float> { 3 }", (1, 19)),
(
"""
type Foo<float> {
foo
}
""",
(4, 1),
),
(
"""
type Foo<float> {
foo: 3
}
""",
(3, 10),
),
###
# Operation statements
###
("op", (1, 3)),
("op float", (1, 4)),
("op <", (1, 5)),
("op <float", (1, 10)),
("op <float>", (1, 11)),
("op <float> +", (1, 13)),
("op <float> + float", (1, 14)),
("op <float> + <", (1, 15)),
("op <float> + <float", (1, 20)),
("op <float> + <float>", (1, 21)),
("op <float> + <float> =", (1, 23)),
("op <float> + <float> = float", (1, 24)),
("op <float> + <float> = <", (1, 25)),
("op <float> + <float> = <float", (1, 30)),
("op <float + 3> + <float> = <float>", (1, 13)),
("op <float> + <float + 3> = <float>", (1, 23)),
("op <float> + <float> = <float + 3>", (1, 33)),
###
# Constraint statements
###
("constraint", (1, 11)),
("constraint 3", (1, 12)),
("constraint Foo", (1, 15)),
("constraint Foo =", (1, 17)),
("constraint Foo = 3", (1, 19)),
("constraint Foo = 3 <", (1, 21)),
],
)
def test_parsing_error(src: str, pos: tuple[int, int]):
src = textwrap.dedent(src)
tokens: list[Token] = MidasLexer(src).process()
parser: MidasParser = MidasParser(tokens)
stmt: list[Stmt] = parser.parse()
assert len(stmt) == 0
assert len(parser.errors) != 0
error_pos: Position = parser.errors[0].token.position
assert (error_pos.line, error_pos.column) == pos