12 Commits

Author SHA1 Message Date
429d0d98fe feat: update railroad diagrams with revised syntax 2026-05-21 07:53:56 +02:00
db8fe5d3ff feat: update EBNF with revised syntax 2026-05-21 07:53:40 +02:00
7477ec8d70 fix: change syntax definition to W3C EBNF 2026-05-20 15:47:34 +02:00
adf7f4e7a2 tests(parser): use new MidasSyntaxError 2026-05-20 15:46:25 +02:00
abf6787946 fix(parser)!: remove annotation lexer and parser 2026-05-20 15:45:55 +02:00
e282b08597 fix: tweak syntax examples
- move operation definitions outside GeoLocation type
- add nullable type
- list syntax choices for complex refinement
2026-05-20 14:14:01 +02:00
0a02b9d3d9 feat: revise syntax (example)
improve the syntax to better fit the principle of least surprise and Python syntax
2026-05-20 13:20:53 +02:00
875ca589e4 Merge pull request 'Improve testing framework' (#2) from feat/test-framework into main
Reviewed-on: #2
2026-05-20 11:17:00 +00:00
88f92d6e1f tests(parser): add simple types snapshot test 2026-05-19 14:12:12 +02:00
db4ed74365 tests(parser): add snapshot test runner
the diff printing function was suggested by Gemini

Co-authored-by: Gemini <noreply@gemini.google.com>
2026-05-19 14:11:32 +02:00
7cbf4fdece feat(tests): add AST JSON serializer 2026-05-19 14:00:32 +02:00
1fa9a09bfe feat(parser): use custom syntax error class 2026-05-19 13:57:00 +02:00
17 changed files with 1702 additions and 819 deletions

View File

@@ -1,107 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from lexer.token import Token
T = TypeVar("T")
@dataclass(frozen=True)
class Stmt(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...
@dataclass(frozen=True)
class AnnotationStmt(Stmt):
name: Token
schema: Optional[SchemaExpr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_annotation_stmt(self)
@dataclass(frozen=True)
class Expr(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@abstractmethod
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
@abstractmethod
def visit_schema_expr(self, expr: SchemaExpr) -> T: ...
@abstractmethod
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...
@dataclass(frozen=True)
class WildcardExpr(Expr):
token: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True)
class LiteralExpr(Expr):
value: Any
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class TypeExpr(Expr):
name: Token
constraints: list[ConstraintExpr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_expr(self)
@dataclass(frozen=True)
class ConstraintExpr(Expr):
left: Expr
op: Token
right: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_constraint_expr(self)
@dataclass(frozen=True)
class SchemaExpr(Expr):
left: Token
elements: list[Expr]
right: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_schema_expr(self)
@dataclass(frozen=True)
class SchemaElementExpr(Expr):
name: Optional[Token]
type: Optional[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_schema_element_expr(self)

View File

@@ -0,0 +1,81 @@
from core.ast.midas import (
ConstraintExpr,
ConstraintStmt,
Expr,
LiteralExpr,
OpStmt,
PropertyStmt,
Stmt,
TypeBodyExpr,
TypeExpr,
TypeStmt,
WildcardExpr,
)
class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
    """Visitor which converts Midas AST nodes into JSON-compatible dicts

    Each node is rendered as a dict whose "_type" entry holds the name of the
    node class. Child nodes are serialized recursively through `accept`, and
    tokens are reduced to their lexeme.
    """

    def serialize(self, stmts: list[Stmt]) -> list[dict]:
        """Serialize a sequence of statements

        Args:
            stmts (list[Stmt]): the statements to serialize

        Returns:
            list[dict]: one dict per statement, in the same order
        """
        serialized: list[dict] = []
        for stmt in stmts:
            serialized.append(stmt.accept(self))
        return serialized

    def visit_type_stmt(self, stmt: TypeStmt) -> dict:
        """Serialize a type statement; "body" is None when there is no body"""
        name = stmt.name.lexeme
        bases = [base_type.accept(self) for base_type in stmt.bases]
        body = None if stmt.body is None else stmt.body.accept(self)
        return {
            "_type": "TypeStmt",
            "name": name,
            "bases": bases,
            "body": body,
        }

    def visit_type_expr(self, expr: TypeExpr) -> dict:
        """Serialize a type expression along with its constraints"""
        name = expr.name.lexeme
        constraints = [item.accept(self) for item in expr.constraints]
        return {
            "_type": "TypeExpr",
            "name": name,
            "constraints": constraints,
        }

    def visit_constraint_expr(self, expr: ConstraintExpr) -> dict:
        """Serialize a constraint as its two operands and operator lexeme"""
        lhs = expr.left.accept(self)
        operator = expr.op.lexeme
        rhs = expr.right.accept(self)
        return {
            "_type": "ConstraintExpr",
            "left": lhs,
            "op": operator,
            "right": rhs,
        }

    def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
        """Serialize a wildcard; only the node type is recorded"""
        node: dict = {"_type": "WildcardExpr"}
        return node

    def visit_literal_expr(self, expr: LiteralExpr) -> dict:
        """Serialize a literal; its value is stored as is"""
        node: dict = {"_type": "LiteralExpr"}
        node["value"] = expr.value
        return node

    def visit_type_body_expr(self, expr: TypeBodyExpr) -> dict:
        """Serialize a type body as the list of its properties"""
        properties = [member.accept(self) for member in expr.properties]
        return {
            "_type": "TypeBodyExpr",
            "properties": properties,
        }

    def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
        """Serialize a property as its name and its type"""
        name = stmt.name.lexeme
        prop_type = stmt.type.accept(self)
        return {
            "_type": "PropertyStmt",
            "name": name,
            "type": prop_type,
        }

    def visit_op_stmt(self, stmt: OpStmt) -> dict:
        """Serialize an operation: operands, operator lexeme and result"""
        lhs = stmt.left.accept(self)
        operator = stmt.op.lexeme
        rhs = stmt.right.accept(self)
        outcome = stmt.result.accept(self)
        return {
            "_type": "OpStmt",
            "left": lhs,
            "op": operator,
            "right": rhs,
            "result": outcome,
        }

    def visit_constraint_stmt(self, stmt: ConstraintStmt) -> dict:
        """Serialize a named constraint statement"""
        name = stmt.name.lexeme
        constraint = stmt.constraint.accept(self)
        return {
            "_type": "ConstraintStmt",
            "name": name,
            "constraint": constraint,
        }

View File

@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import io
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum, auto from enum import Enum, auto
import io
from typing import Generator, Generic, Optional, Protocol, TypeVar from typing import Generator, Generic, Optional, Protocol, TypeVar
import core.ast.annotations as a
import core.ast.midas as m import core.ast.midas as m
@@ -84,113 +83,6 @@ class AstPrinter(Generic[T]):
child.accept(self) child.accept(self)
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
self._write_line("AnnotationStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("schema", stmt.schema, last=True)
def visit_type_expr(self, expr: a.TypeExpr):
self._write_line("TypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line("constraints", last=True)
with self._child_level():
for i, constraint in enumerate(expr.constraints):
self._idx = i
if i == len(expr.constraints) - 1:
self._mark_last()
constraint.accept(self)
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
self._write_line("ConstraintExpr")
with self._child_level():
self._write_line("left")
with self._child_level():
self._mark_last()
expr.left.accept(self)
self._write_line(f"operator: {expr.op.lexeme}")
self._write_line("right", last=True)
with self._child_level():
self._mark_last()
expr.right.accept(self)
def visit_schema_expr(self, expr: a.SchemaExpr):
self._write_line("SchemaExpr")
with self._child_level():
for i, elmt in enumerate(expr.elements):
self._idx = i
if i == len(expr.elements) - 1:
self._mark_last()
elmt.accept(self)
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
self._write_line("SchemaElementExpr")
with self._child_level():
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
self._write_line(f"name: {name_text}")
self._write_optional_child("type", expr.type, last=True)
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
self._write_line("LiteralExpr")
with self._child_level():
self._write_line(f"value: {expr.value}", last=True)
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
def print(self, expr: a.Expr | a.Stmt):
return expr.accept(self)
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
schema: str = ""
if stmt.schema is not None:
schema = stmt.schema.accept(self)
return f"{stmt.name.lexeme}{schema}"
def visit_type_expr(self, expr: a.TypeExpr) -> str:
parts: list[str] = [expr.name.lexeme]
for constraint in expr.constraints:
parts.append("(" + constraint.accept(self) + ")")
return " + ".join(parts)
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
parts: list[str] = [
expr.left.accept(self),
expr.op.lexeme,
expr.right.accept(self),
]
return " ".join(parts)
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
res: str = expr.left.lexeme
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
res += expr.right.lexeme
return res
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
parts: list[str] = []
if expr.name is not None:
parts.append(expr.name.lexeme)
if expr.type is None:
parts.append("_")
else:
parts.append(expr.type.accept(self))
return ": ".join(parts)
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
return "_"
def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
return str(expr.value)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_type_stmt(self, stmt: m.TypeStmt): def visit_type_stmt(self, stmt: m.TypeStmt):
self._write_line("TypeStmt") self._write_line("TypeStmt")
@@ -289,6 +181,7 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
with self._child_level(): with self._child_level():
self._write_line(f"value: {expr.value}", last=True) self._write_line(f"value: {expr.value}", last=True)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def __init__(self, indent: int = 4): def __init__(self, indent: int = 4):
self.indent: int = indent self.indent: int = indent
@@ -302,11 +195,8 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
return expr.accept(self) return expr.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt): def visit_type_stmt(self, stmt: m.TypeStmt):
bases: list[str] = [ bases: list[str] = [b.accept(self) for b in stmt.bases]
b.accept(self)
for b in stmt.bases
]
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>") res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
if stmt.body is not None: if stmt.body is not None:
res += " {\n" res += " {\n"
@@ -348,8 +238,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_type_body_expr(self, expr: m.TypeBodyExpr): def visit_type_body_expr(self, expr: m.TypeBodyExpr):
properties: list[str] = [ properties: list[str] = [
self.indented(prop.accept(self)) self.indented(prop.accept(self)) for prop in expr.properties
for prop in expr.properties
] ]
return "\n".join(properties) return "\n".join(properties)
@@ -357,4 +246,4 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
return "_" return "_"
def visit_literal_expr(self, expr: m.LiteralExpr): def visit_literal_expr(self, expr: m.LiteralExpr):
return str(expr.value) return str(expr.value)

View File

@@ -0,0 +1,73 @@
// Simple custom type derived from float
type Custom(float)
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float)
type Difference[T](T)
// Complex custom type, containing two values accessible through properties
type GeoLocation {
lat: Latitude
lon: Longitude
}
// Define operations on our custom type
extend GeoLocation {
// This type is compatible with the `-` operation with another GeoLocation
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
// in a Difference of GeoLocations
op __sub__(GeoLocation) -> Difference[GeoLocation]
}
// For complex generics, you need to specify how the genericity of the
// properties is handled
type Difference[GeoLocation] {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
// Simple operation defined on our custom types
extend Latitude {
op __sub__(Latitude) -> Difference[Latitude]
}
extend Longitude {
op __sub__(Longitude) -> Difference[Longitude]
}
// Predefined custom predicates that can be referenced in other definitions
predicate Positive(v: float) = v >= 0
predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
// Property referencing a predicate
height: float where StrictlyPositive
home: GeoLocation
}
// Custom complex type derived from another complex type, with a constraint
// on a property
// Multiple proposed syntaxes, not yet defined
// Explicit, but new keyword
type EquatorialPerson refines Person where Equatorial(_.home)
// Explicit with an existing keyword, but might be confusing given existing expectations regarding 'is'
type EquatorialPerson is Person where Equatorial(_.home)
// Consistent and Python-friendly but can be confused with structural extension
type EquatorialPerson(Person) where Equatorial(_.home)
// Allow new properties, probably not useful
type EquatorialPerson extends Person where Equatorial(_.home)

View File

@@ -1,102 +0,0 @@
from lexer.base import Lexer
from lexer.keyword import ANNOTATION_KEYWORDS
from lexer.token import TokenType
class AnnotationLexer(Lexer):
def scan_token(self) -> None:
char: str = self.advance()
match char:
case "(":
self.add_token(TokenType.LEFT_PAREN)
case ")":
self.add_token(TokenType.RIGHT_PAREN)
case "[":
self.add_token(TokenType.LEFT_BRACKET)
case "]":
self.add_token(TokenType.RIGHT_BRACKET)
case "<":
self.add_token(
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
)
case ">":
self.add_token(
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
)
case "=":
self.add_token(
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
)
case "!":
if self.match("="):
self.add_token(TokenType.BANG_EQUAL)
else:
self.error("Unexpected single bang. Did you mean '!=' ?")
case ":":
self.add_token(TokenType.COLON)
case ",":
self.add_token(TokenType.COMMA)
case "_":
self.add_token(TokenType.UNDERSCORE)
case "+":
self.add_token(TokenType.PLUS)
case "#":
self.scan_comment()
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":
# Consume all whitespace characters until EOL or EOF
while (
self.peek().isspace()
and self.peek() != "\n"
and not self.is_at_end()
):
self.advance()
self.add_token(TokenType.WHITESPACE)
case _:
if char.isdigit():
self.scan_number()
elif char.isalpha():
self.scan_identifier()
else:
self.error("Unexpected character")
return None
def scan_number(self):
"""Scan the rest of number and add it as a token
This method handles both simple integers and floats. Scientific notation
and base prefixes (0x, 0b, 0o) are not supported
"""
while self.peek().isdigit():
self.advance()
if self.peek() == "." and self.peek_next().isdigit():
self.advance()
while self.peek().isdigit():
self.advance()
value: float = float(self.source[self.start : self.idx])
self.add_token(TokenType.NUMBER, value)
def scan_identifier(self):
"""Scan the rest of an identifier and add it as a token
An identifier starts with a letter, followed by any number of
alphanumerical characters or underscores
"""
while self.peek().isalnum() or self.peek() == "_":
self.advance()
lexeme: str = self.source[self.start : self.idx]
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
self.add_token(token_type)
def scan_comment(self):
"""Scan the rest of a comment and add it as a token
A comment starts with a `#` character and ends at the EOL/EOF
"""
while self.peek() != "\n" and not self.is_at_end():
self.advance()
self.add_token(TokenType.COMMENT)

View File

@@ -5,6 +5,13 @@ from lexer.position import Position
from lexer.token import Token, TokenType from lexer.token import Token, TokenType
class MidasSyntaxError(Exception):
    """A syntax error which keeps track of where it was found

    The position and the raw message stay available as attributes, while the
    exception text itself is the formatted, human-readable error line.
    """

    def __init__(self, pos: Position, message: str):
        """
        Args:
            pos (Position): the position the error is reported at
            message (str): the error message, without any prefix
        """
        self.pos: Position = pos
        self.message: str = message
        formatted: str = f"[ERROR] Error at {pos}: {message}"
        super().__init__(formatted)
class Lexer(ABC): class Lexer(ABC):
"""An abstract lexer which provides methods to easily extend it into a concrete one """An abstract lexer which provides methods to easily extend it into a concrete one
@@ -38,9 +45,9 @@ class Lexer(ABC):
msg (str): the error message msg (str): the error message
Raises: Raises:
SyntaxError MidasSyntaxError
""" """
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}") raise MidasSyntaxError(self.start_pos, msg)
def process(self) -> list[Token]: def process(self) -> list[Token]:
"""Scan tokens out of the source text """Scan tokens out of the source text
@@ -49,7 +56,7 @@ class Lexer(ABC):
list[Token]: all the tokens that could be scanned list[Token]: all the tokens that could be scanned
Raises: Raises:
SyntaxError: if a syntax error is found MidasSyntaxError: if a syntax error is found
""" """
self.scan_tokens() self.scan_tokens()
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position())) self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))

View File

@@ -1,12 +1,6 @@
from lexer.token import TokenType from lexer.token import TokenType
ANNOTATION_KEYWORDS: dict[str, TokenType] = { KEYWORDS: dict[str, TokenType] = {
"True": TokenType.TRUE,
"False": TokenType.FALSE,
"None": TokenType.NONE,
}
MIDAS_KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE, "type": TokenType.TYPE,
"op": TokenType.OP, "op": TokenType.OP,
"constraint": TokenType.CONSTRAINT, "constraint": TokenType.CONSTRAINT,

View File

@@ -1,5 +1,5 @@
from lexer.base import Lexer from lexer.base import Lexer
from lexer.keyword import MIDAS_KEYWORDS from lexer.keyword import KEYWORDS
from lexer.token import TokenType from lexer.token import TokenType
@@ -102,7 +102,7 @@ class MidasLexer(Lexer):
self.advance() self.advance()
lexeme: str = self.source[self.start : self.idx] lexeme: str = self.source[self.start : self.idx]
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER) token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
self.add_token(token_type) self.add_token(token_type)
def scan_comment(self): def scan_comment(self):

View File

@@ -1,152 +0,0 @@
from typing import Optional
from core.ast.annotations import (
AnnotationStmt,
ConstraintExpr,
Expr,
LiteralExpr,
SchemaElementExpr,
SchemaExpr,
Stmt,
TypeExpr,
WildcardExpr,
)
from lexer.token import Token, TokenType
from parser.base import Parser
from parser.errors import ParsingError
class AnnotationParser(Parser):
"""A simple parser for custom type annotations"""
SYNC_BOUNDARY: set[TokenType] = set()
def parse(self) -> Optional[Stmt]:
stmt: Optional[Stmt] = None
try:
stmt = self.annotation()
except ParsingError:
self.synchronize()
if not self.is_at_end():
self.error(self.peek(), "Extra tokens")
return stmt
def synchronize(self):
"""Skip tokens until a synchronization boundary is found
This method allows gracefully recovering from a parse error
to a safe place and continue parsing
"""
self.advance()
while not self.is_at_end():
if self.peek().type in self.SYNC_BOUNDARY:
return
self.advance()
def annotation(self) -> AnnotationStmt:
"""Parse an annotation
An annotation is written as `Type` or `Type[Schema]`
Returns:
AnnotationStmt: the parsed annotation statement
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
schema: Optional[SchemaExpr] = None
if self.match(TokenType.LEFT_BRACKET):
schema = self.schema()
return AnnotationStmt(name=name, schema=schema)
def type_expr(self) -> TypeExpr:
"""Parse a type expression
Returns:
TypeExpr: the parsed type expression
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
constraints: list[ConstraintExpr] = []
while not self.is_at_end() and self.match(TokenType.PLUS):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
constraints.append(self.constraint_expr())
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
return TypeExpr(name=name, constraints=constraints)
def constraint_expr(self) -> ConstraintExpr:
"""Parse a type constraint
Returns:
ConstraintExpr: the parsed type constraint expression
"""
left: Expr = self.constraint_value()
op: Token = self.constraint_operator()
right: Expr = self.constraint_value()
return ConstraintExpr(left=left, op=op, right=right)
def constraint_value(self) -> Expr:
if self.match(TokenType.UNDERSCORE):
return WildcardExpr(self.previous())
return self.literal()
def literal(self) -> LiteralExpr:
if self.match(TokenType.FALSE):
return LiteralExpr(False)
if self.match(TokenType.TRUE):
return LiteralExpr(True)
if self.match(TokenType.NONE):
return LiteralExpr(None)
if self.match(TokenType.NUMBER):
return LiteralExpr(self.previous().value)
raise self.error(self.peek(), "Expected literal")
def constraint_operator(self) -> Token:
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
return self.previous()
raise self.error(self.peek(), "Expected constraint operator")
def schema(self) -> SchemaExpr:
"""Parse a schema definition
A comma separated list of schema elements
Returns:
SchemaExpr: the parsed schema expression
"""
left: Token = self.previous()
elements: list[Expr] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
elements.append(self.schema_element())
if not self.check(TokenType.RIGHT_BRACKET):
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
return SchemaExpr(left=left, elements=elements, right=right)
def schema_element(self) -> SchemaElementExpr:
"""Parse a schema element
An anonymous element (`_`), a type, an untyped named column (`name: _`),
or a named column (`name: Type`)
Returns:
SchemaElementExpr: the parsed schema element expression
"""
if self.match(TokenType.UNDERSCORE):
return SchemaElementExpr(name=None, type=None)
if not self.check(TokenType.IDENTIFIER):
raise self.error(self.peek(), "Expected schema element")
name: Optional[Token] = None
type: Optional[TypeExpr] = None
if self.check_next(TokenType.COLON):
name = self.advance()
self.advance()
if not self.match(TokenType.UNDERSCORE):
type = self.type_expr()
return SchemaElementExpr(name=name, type=type)

View File

@@ -1,26 +1,34 @@
identifier ::= '[a-zA-Z][a-zA-Z_]*' // W3C EBNF syntax definition for Midas
Identifier ::= [a-zA-Z] [a-zA-Z_]*
OpIdentifier ::= Identifier | "__" Identifier "__"
integer ::= '\d+' Integer ::= [0-9]+
number ::= integer ["." integer] Number ::= "-"? Integer ("." Integer)?
boolean ::= "False" | "True" Boolean ::= "False" | "True"
none ::= "None" None ::= "None"
value ::= number | boolean | none Variable ::= Identifier ("." Identifier)*
lambda-value ::= "_" | value Value ::= Number | Boolean | None
lambda-operator ::= ">" | "<" | ">=" | "<=" | "==" | "!=" LambdaValue ::= "_" | Value | Variable
lambda ::= lambda-value lambda-operator lambda-value LambdaOperator ::= ">" | "<" | ">=" | "<=" | "==" | "!="
Lambda ::= LambdaValue (LambdaOperator LambdaValue)+
constraint ::= identifier | "(" lambda ")" SimpleType ::= Identifier "?"?
base-type ::= identifier Template ::= "[" SimpleType "]"
type ::= base-type { "+" constraint } Type ::= Identifier Template? "?"?
Constraint ::= Identifier | Lambda
type-property ::= 'identifier' ":" 'type' SimpleTypeBase ::= "(" Type ")"
type-body ::= "{" { 'type-property' } "}" WrappedConstraint ::= Constraint | "(" Constraint ")"
Constraints ::= WrappedConstraint ("&" WrappedConstraint)*
operation-type ::= "<" 'type' ">" TypeProperty ::= Identifier ":" Type ("where" Constraints)?
ComplexTypeBody ::= "{" TypeProperty* "}"
OpDefinition ::= "op" OpIdentifier "(" Type ")" "->" Type
ExtendBody ::= "{" OpDefinition* "}"
type-statement ::= "type" 'identifier' "<" 'type' {"," 'type'} ">" ['type-body'] TypeStatement ::= "type" Identifier Template? (SimpleTypeBase ("where" Constraints)? | ComplexTypeBody)
operation-statement ::= "op" 'operation-type' 'operator' 'operation-type' "=" 'operation-type' ExtendStatement ::= "extend" Type ExtendBody
constraint-statement ::= "constraint" 'identifier' "=" 'lambda' PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraints
statement ::= type-statement | operation-statement | constraint-statement Statement ::= TypeStatement | ExtendStatement | PredicateStatement

View File

@@ -1,4 +1,15 @@
#import "@preview/fervojo:0.1.1": render #import "@preview/fervojo:0.1.1": default-css, render
#let extra-css = ```css
svg.railroad .terminal rect {
fill: #F7DCD4;
}
```
#let css = default-css() + bytes(extra-css.text)
#let variable = ```
{[`variable` 'identifier'*"."]}
```
#let value = ``` #let value = ```
{[`value` < {[`value` <
@@ -8,90 +19,156 @@
>]} >]}
``` ```
#let constraint = ``` #let lambda-value = ```
{[`constraint` <"_", 'value'> <">", "<", ">=", "<=", "==", "!="> <"_", 'value'>]} {[`lambda-value` <"_", 'value', 'variable'>]}
``` ```
#let type-with-constraints = ``` #let lambda-operator = ```
{[`type-with-constraints` 'identifier' <!, ["+" "(" 'constraint' ")"] * !>]} {[`lambda-operator` <">", "<", ">=", "<=", "==", "!=">]}
```
#let lambda = ```
{[`lambda` 'lambda-value' ['lambda-operator' 'lambda-value']*!]}
```
#let simple-type = ```
{[`simple-type` 'identifier' <!, "?">]}
```
#let template = ```
{[`template` "[" 'simple-type' "]"]}
```
#let type = ```
{[`type` 'identifier' <!, 'template'> <!, "?">]}
```
#let constraint = ```
{[`constraint` <'identifier', 'lambda'>]}
```
#let wrapped-constraint = ```
{[`wrapped-constraint` <'constraint', ["(" 'constraint' ")"]>]}
```
#let constraints = ```
{[`constraints` 'wrapped-constraint'*"&"]}
``` ```
#let type-property = ``` #let type-property = ```
{[`type-property` 'identifier' ":" 'type-with-constraints']} {[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraints']>]}
``` ```
#let type-body = ``` #let type-body = ```
{[`type-body` "{" <!, 'type-property'*!> "}"]} {[`type-body` "{" <!, 'type-property'*!> "}"]}
``` ```
#let operation-type = ```
{[`operation-type` "<" 'type-with-constraints' ">"]}
```
#let type-statement = ``` #let type-statement = ```
{[`type-statement` "type" 'identifier' "<" 'type-with-constraints'*"," ">" <!, 'type-body'>]} {[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraints']>], 'type-body'>]}
``` ```
#let operation-statement = ``` #let op-definition = ```
{[`operation-statement` "op" 'operation-type' "operator" 'operation-type' "=" 'operation-type']} {[`op-definition` "op" <'identifier', ["__" 'identifier' "__"]> "(" 'type' ")" "->" 'type']}
``` ```
#let constraint-statement = ``` #let extend-statement = ```
{[`constraint-statement` "constraint" 'identifier' "=" 'constraint']} {[`extend-statement` "extend" 'type' "{" <!, 'op-definition'*!> "}"]}
```
#let predicate-statement = ```
{[`predicate-statement` "predicate" 'identifier' "(" 'identifier' ":" 'type' ")" "=" 'constraints']}
``` ```
#let statement = ``` #let statement = ```
{[`statement` <'type-statement', 'operation-statement', 'constraint-statement'>]} {[`statement` <'type-statement', 'extend-statement', 'predicate-statement'>]}
``` ```
#let rules = ( #let rules = (
value, variable: variable,
constraint, value: value,
type-with-constraints, lambda-value: lambda-value,
type-property, lambda-operator: lambda-operator,
type-body, lambda: lambda,
operation-type, simple-type: simple-type,
type-statement, template: template,
operation-statement, type: type,
constraint-statement, constraint: constraint,
statement, wrapped-constraint: wrapped-constraint,
constraints: constraints,
type-property: type-property,
type-body: type-body,
type-statement: type-statement,
op-definition: op-definition,
extend-statement: extend-statement,
predicate-statement: predicate-statement,
statement: statement,
)
#let inline = (
"value",
"variable",
"lambda-operator",
"template",
"lambda",
"simple-type",
"wrapped-constraint",
"type-property",
"type-body",
"op-definition",
"type-statement",
"extend-statement",
"predicate-statement",
) )
#set text(font: "Source Sans 3") #set text(font: "Source Sans 3")
= Midas type definition syntax #title[Midas type definition syntax]
#for rule in rules { = Outline
render(rule)
}
/* #box(
#let by-name = ( columns(
value: value, 2,
constraint: constraint, outline(title: none),
type-with-constraints: type-with-constraints, ),
type-property: type-property, height: 8cm,
type-body: type-body, stroke: 1pt,
operation-type: operation-type, inset: 1em,
type-statement: type-statement,
operation-statement: operation-statement,
constraint-statement: constraint-statement,
) )
= Statements and expressions
#for (name, rule) in rules.pairs().rev() {
[== #name]
render(rule, css: css)
}
#let substitute(base-rule) = { #let substitute(base-rule) = {
let new-rule = base-rule let new-rule = base-rule
for (key, rule) in by-name.pairs() { for name in inline {
new-rule = new-rule.replace("'" + key + "'", rule.text.slice(1, -1)) let rule = rules.at(name)
let replacement = rule.text.slice(1, -1).replace(regex("\[`.*?`"), "[")
replacement = "[" + replacement + "#`" + name + "`]"
new-rule = new-rule.replace(
"'" + name + "'",
replacement,
)
} }
if new-rule != base-rule { if new-rule != base-rule {
new-rule = substitute(new-rule) new-rule = substitute(new-rule)
} }
return new-rule.replace(regex("`.*?`"), "") return new-rule
} }
#let combined = raw(substitute(statement.text))
#set page(flipped: true) #set page(flipped: true)
#render(combined)
*/ = Combined rules
#for (name, rule) in rules.pairs() {
if not name in inline {
[== #name]
let combined = substitute(rule.text)
render(raw(combined), css: css)
//raw(block: true, combined)
}
}

204
tester.py Normal file
View File

@@ -0,0 +1,204 @@
from __future__ import annotations
import argparse
import difflib
import json
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Iterator, Optional
from core.ast.json_serializer import AstJsonSerializer
from core.ast.midas import Stmt
from lexer.base import MidasSyntaxError
from lexer.midas import MidasLexer
from lexer.token import Token
from parser.midas import MidasParser
# Directory searched for test cases by default, relative to the working directory
DEFAULT_BASE_DIR: Path = Path("tests")
@dataclass
class CaseResult:
    """The outcome of a test case, made of JSON-compatible values only"""

    # Descriptions of the scanned tokens (None by default)
    tokens: Optional[list[dict]] = None
    # Serialized statements (None by default)
    stmts: Optional[list[dict]] = None
    # Descriptions of the errors that were collected
    errors: list[dict] = field(default_factory=list)

    def dumps(self) -> str:
        """Dump the result as JSON text

        Returns:
            str: the fields of this result as a JSON object, indented by 2 spaces
        """
        payload: dict = asdict(self)
        return json.dumps(payload, indent=2)
class Tester:
    """A test runner to check for regressions in the lexer and parser

    Every `*.midas` file found under the base directory is a test case. A case
    is lexed and parsed, and the JSON dump of the outcome is compared with the
    snapshot stored next to the case (see `_result_path`).
    """

    def __init__(self, base_dir: Path):
        """
        Args:
            base_dir (Path): the directory searched recursively for test cases
        """
        self.base_dir: Path = base_dir

    def _list_tests(self) -> list[Path]:
        """List the test cases found under the base directory

        Returns:
            list[Path]: all `*.midas` files, sorted
        """
        # rglob yields paths in a filesystem-dependent order; sorting keeps
        # the case numbering identical from one run/machine to the other
        return sorted(self.base_dir.rglob("*.midas"))

    def run_all_tests(self) -> bool:
        """Run every test case found under the base directory

        Returns:
            bool: True if no case failed
        """
        paths: list[Path] = self._list_tests()
        return self.run_tests(paths)

    def run_tests(self, tests: list[Path]) -> bool:
        """Run the given test cases and print a summary

        Args:
            tests (list[Path]): the test cases to run

        Returns:
            bool: True if no case failed
        """
        rule: str = "-" * 80
        n: int = len(tests)
        successes: int = 0
        failures: int = 0
        print(rule)
        for i, test in enumerate(tests, start=1):
            print(f"Case {i}/{n}: {test}")
            if self._run_test(test):
                successes += 1
            else:
                failures += 1
        print(rule)
        print(f"Success: {successes}/{n}")
        print(f"Failed: {failures}/{n}")
        print(rule)
        return failures == 0

    def _run_test(self, path: Path) -> bool:
        """Run a single test case against its snapshot

        A case without a snapshot is reported as failed instead of aborting
        the whole run, so that the remaining cases and the summary still show.

        Args:
            path (Path): the test case to run

        Returns:
            bool: True if the result is identical to the snapshot
        """
        result: CaseResult = self._exec_case(path)
        result_path: Path = self._result_path(path)
        actual: str = result.dumps()
        try:
            expected: str = result_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            print(f"Missing snapshot '{result_path}'")
            return False

        if expected == actual:
            return True
        diff = difflib.unified_diff(
            expected.splitlines(keepends=True),
            actual.splitlines(keepends=True),
            fromfile="Snapshot",
            tofile="Result",
        )
        self._print_diff(diff)
        return False

    def _exec_case(self, path: Path) -> "CaseResult":
        """Lex and parse a test case

        If the lexer raises a syntax error, the error is recorded and the
        result is returned without tokens nor statements, since there is
        nothing to parse. Otherwise, the serialized statements and the errors
        collected by the parser are added to the result.

        Args:
            path (Path): the test case to execute

        Returns:
            CaseResult: the tokens, statements and errors of the case

        Raises:
            FileNotFoundError: if the test case does not exist
            TypeError: if the test case is not a file
        """
        if not path.exists():
            raise FileNotFoundError(f"Could not find test '{path}'")
        if not path.is_file():
            raise TypeError(f"Test '{path}' is not a file")

        result: CaseResult = CaseResult()
        content: str = path.read_text(encoding="utf-8")
        lexer: MidasLexer = MidasLexer(content)
        tokens: list[Token] = []
        try:
            tokens = lexer.process()
            result.tokens = [
                {
                    "type": token.type.name,
                    "lexeme": token.lexeme,
                    "line": token.position.line,
                    "column": token.position.column,
                }
                for token in tokens
            ]
        except MidasSyntaxError as e:
            result.errors.append(
                {
                    "type": "SyntaxError",
                    "line": e.pos.line,
                    "column": e.pos.column,
                    "message": e.message,
                }
            )
            return result

        parser: MidasParser = MidasParser(tokens)
        stmts: list[Stmt] = parser.parse()
        result.stmts = AstJsonSerializer().serialize(stmts)
        result.errors.extend(
            [
                {
                    "line": e.token.position.line,
                    "column": e.token.position.column,
                    "message": e.message,
                }
                for e in parser.errors
            ]
        )
        return result

    def update_all_tests(self):
        """Update the snapshot of every test case found under the base directory"""
        paths: list[Path] = self._list_tests()
        return self.update_tests(paths)

    def update_tests(self, tests: list[Path]):
        """Update the snapshots of the given test cases and print how many changed

        Args:
            tests (list[Path]): the test cases to update
        """
        updated: int = sum(1 for test in tests if self._update_test(test))
        print(f"Updated {updated}/{len(tests)} tests")

    def _update_test(self, path: Path) -> bool:
        """Write the snapshot of a test case if it is missing or outdated

        Args:
            path (Path): the test case to update

        Returns:
            bool: True if the snapshot was written, False if it was up to date
        """
        result: CaseResult = self._exec_case(path)
        result_path: Path = self._result_path(path)
        new: str = result.dumps()
        try:
            current: Optional[str] = result_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            # First snapshot of a new case: nothing to compare with yet
            current = None

        if current == new:
            return False
        result_path.write_text(new, encoding="utf-8")
        return True

    def _result_path(self, test_path: Path) -> Path:
        """Get the snapshot path of a test case

        Args:
            test_path (Path): the test case

        Returns:
            Path: a sibling of the case, named after it with `.ref.json` appended
        """
        return test_path.with_name(test_path.name + ".ref.json")

    def _print_diff(self, diff: Iterator[str]):
        """Print a unified diff, with additions in green and deletions in red

        Args:
            diff (Iterator[str]): the lines of the diff, with their line endings
        """
        for line in diff:
            # The `+++` / `---` file headers are left uncolored
            if line.startswith("+") and not line.startswith("+++"):
                print(f"\033[92m{line}\033[0m", end="")
            elif line.startswith("-") and not line.startswith("---"):
                print(f"\033[91m{line}\033[0m", end="")
            else:
                print(line, end="")
        print()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-D",
"--base-dir",
help="Base directory containing test files",
type=Path,
default=DEFAULT_BASE_DIR,
)
subparsers = parser.add_subparsers(dest="subcommand")
update = subparsers.add_parser("update")
update.add_argument("-a", "--all", action="store_true")
update.add_argument("FILE", type=Path, nargs="*")
run = subparsers.add_parser("run")
run.add_argument("-a", "--all", action="store_true")
run.add_argument("FILE", type=Path, nargs="*")
args = parser.parse_args()
tester: Tester = Tester(args.base_dir)
match args.subcommand:
case "update":
if args.all:
tester.update_all_tests()
else:
tester.update_tests(args.FILE)
case "run":
success: bool
if args.all:
success = tester.run_all_tests()
else:
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,24 @@
// Simple custom type derived from floats
type Latitude<float>
type Longitude<float>
// Complex custom type, containing two values accessible through properties
type GeoLocation<Latitude, Longitude> {
lat: Latitude
lon: Longitude
}
type LatitudeDiff<float>
type LongitudeDiff<float>
// Simple operation defined on our custom types
op <Latitude> - <Latitude> = <LatitudeDiff>
op <Longitude> - <Longitude> = <LongitudeDiff>
// Simple custom type with a constraint
type Age<int + (0 <= _) + (_ < 150)>
// Predefined custom constraints that can be referenced in other definitions
constraint Positive = _ >= 0
constraint StrictlyPositive = _ > 0
//constraint Even = _ % 2 == 0

File diff suppressed because it is too large Load Diff

View File

@@ -1,129 +0,0 @@
from typing import Any
import pytest
from lexer.annotations import AnnotationLexer
from lexer.token import Token, TokenType
def scan(source: str) -> list[Token]:
    """Run the annotation lexer over *source* and return the tokens it yields."""
    lexer = AnnotationLexer(source)
    return lexer.process()
def assert_n_tokens(tokens: list[Token], n: int):
    """Assert that *tokens* holds *n* tokens followed by one trailing EOF token."""
    expected_length = n + 1  # the extra token is the EOF marker
    assert len(tokens) == expected_length
    last = tokens[-1]
    assert last.type == TokenType.EOF
@pytest.mark.parametrize(
    ("src", "expected"),
    [
        ("(", TokenType.LEFT_PAREN),
        (")", TokenType.RIGHT_PAREN),
        ("[", TokenType.LEFT_BRACKET),
        ("]", TokenType.RIGHT_BRACKET),
        (":", TokenType.COLON),
        (",", TokenType.COMMA),
        ("_", TokenType.UNDERSCORE),
    ],
)
def test_punctuation(src: str, expected: TokenType):
    """Each punctuation character lexes to a single token of the expected type."""
    lexed = scan(src)
    assert_n_tokens(lexed, 1)
    first = lexed[0]
    assert first.type == expected
@pytest.mark.parametrize(
    ("src", "expected"),
    [
        ("+", TokenType.PLUS),
        (">", TokenType.GREATER),
        (">=", TokenType.GREATER_EQUAL),
        ("<", TokenType.LESS),
        ("<=", TokenType.LESS_EQUAL),
        ("=", TokenType.EQUAL),
        ("==", TokenType.EQUAL_EQUAL),
        ("!=", TokenType.BANG_EQUAL),
    ],
)
def test_operators(src: str, expected: TokenType):
    """One- and two-character operators lex to a single token of the expected type."""
    lexed = scan(src)
    assert_n_tokens(lexed, 1)
    first = lexed[0]
    assert first.type == expected
@pytest.mark.parametrize(
    ("src", "expected"),
    [
        ("a", TokenType.IDENTIFIER),
        ("foo", TokenType.IDENTIFIER),
        ("foo1", TokenType.IDENTIFIER),
        ("foo_", TokenType.IDENTIFIER),
        ("foo_bar1_baz2", TokenType.IDENTIFIER),
        ("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
        ("True", TokenType.TRUE),
        ("False", TokenType.FALSE),
        ("None", TokenType.NONE),
    ],
)
def test_identifiers_keywords(src: str, expected: TokenType):
    """Plain names lex as identifiers; True, False and None get their own types."""
    lexed = scan(src)
    assert_n_tokens(lexed, 1)
    first = lexed[0]
    assert first.type == expected
@pytest.mark.parametrize(
    ("src", "expected"),
    [
        ("#", TokenType.COMMENT),
        ("# This is a comment", TokenType.COMMENT),
        (" ", TokenType.WHITESPACE),
        ("\t", TokenType.WHITESPACE),
        ("\r", TokenType.WHITESPACE),
        ("  \t  \t", TokenType.WHITESPACE),
        ("\n", TokenType.NEWLINE),
    ],
)
def test_misc(src: str, expected: TokenType):
    """Comments, whitespace runs and newlines each lex to one token of the expected type."""
    lexed = scan(src)
    assert_n_tokens(lexed, 1)
    first = lexed[0]
    assert first.type == expected
@pytest.mark.parametrize(
    ("src", "expected_type", "expected_value"),
    [
        ("0", TokenType.NUMBER, 0),
        ("0.0", TokenType.NUMBER, 0),
        ("1234.56", TokenType.NUMBER, 1234.56),
    ],
)
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
    """Numeric literals lex to one token with the expected type and value."""
    lexed = scan(src)
    assert_n_tokens(lexed, 1)
    literal = lexed[0]
    assert literal.type == expected_type
    assert literal.value == expected_value
def test_single_bang_error():
    """A '!' that is not followed by '=' must be rejected with a SyntaxError."""
    lone_bang = "!"
    with pytest.raises(SyntaxError):
        scan(lone_bang)
@pytest.mark.parametrize(
    "src",
    ["-", "*", "/", "{", "}", "@", '"', "'", "."],
)
def test_unexpected_character(src: str):
    """Each of the listed characters must make the lexer raise a SyntaxError."""
    with pytest.raises(SyntaxError):
        scan(src)

View File

@@ -2,6 +2,7 @@ from typing import Any
import pytest import pytest
from lexer.base import MidasSyntaxError
from lexer.midas import MidasLexer from lexer.midas import MidasLexer
from lexer.token import Token, TokenType from lexer.token import Token, TokenType
@@ -111,7 +112,7 @@ def test_literals(src: str, expected_type: TokenType, expected_value: Any):
def test_single_bang_error(): def test_single_bang_error():
with pytest.raises(SyntaxError): with pytest.raises(MidasSyntaxError):
scan("!") scan("!")
@@ -125,5 +126,5 @@ def test_single_bang_error():
], ],
) )
def test_unexpected_character(src: str): def test_unexpected_character(src: str):
with pytest.raises(SyntaxError): with pytest.raises(MidasSyntaxError):
scan(src) scan(src)

View File

@@ -1,130 +0,0 @@
from typing import Optional
import pytest
from core.ast.annotations import (
AnnotationStmt,
ConstraintExpr,
Expr,
LiteralExpr,
SchemaElementExpr,
SchemaExpr,
Stmt,
TypeExpr,
WildcardExpr,
)
from lexer.annotations import AnnotationLexer
from lexer.position import Position
from lexer.token import Token
from parser.annotations import AnnotationParser
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
    """Visitor rendering an annotation AST as a compact, parenthesized string."""

    def serialize(self, stmt: Stmt):
        """Return the string form of *stmt* by dispatching to the matching visit method."""
        return stmt.accept(self)

    def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str:
        """Render ``(annotation NAME)`` or ``(annotation NAME SCHEMA)``."""
        if stmt.schema is None:
            suffix = ""
        else:
            suffix = " " + stmt.schema.accept(self)
        return f"(annotation {stmt.name.lexeme}{suffix})"

    def visit_schema_expr(self, expr: SchemaExpr) -> str:
        """Render ``(schema E1 E2 ...)``; an empty schema gives ``(schema )``."""
        rendered = " ".join([element.accept(self) for element in expr.elements])
        return f"(schema {rendered})"

    def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
        """Render ``(NAME TYPE)``, using ``_`` for a missing name or type."""
        if expr.name is not None:
            name_repr = expr.name.lexeme
        else:
            name_repr = "_"
        if expr.type is not None:
            type_repr = expr.type.accept(self)
        else:
            type_repr = "_"
        return f"({name_repr} {type_repr})"

    def visit_type_expr(self, expr: TypeExpr) -> str:
        """Render ``(NAME C1 C2 ...)`` with one entry per constraint."""
        parts = [expr.name.lexeme]
        parts.extend(constraint.accept(self) for constraint in expr.constraints)
        return "(" + " ".join(parts) + ")"

    def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
        """Render ``(constraint LEFT OP RIGHT)``."""
        lhs = expr.left.accept(self)
        operator = expr.op.lexeme
        rhs = expr.right.accept(self)
        return f"(constraint {lhs} {operator} {rhs})"

    def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
        """Render the wildcard as ``(_)``."""
        return "(_)"

    def visit_literal_expr(self, expr: LiteralExpr) -> str:
        """Render a literal as its value wrapped in parentheses."""
        return "(" + f"{expr.value}" + ")"
def parse(source: str) -> Optional[Stmt]:
    """Lex then parse *source*; return whatever ``AnnotationParser.parse()`` returns."""
    lexer = AnnotationLexer(source)
    token_stream: list[Token] = lexer.process()
    annotation_parser = AnnotationParser(token_stream)
    return annotation_parser.parse()
def must_parse(source: str) -> Stmt:
    """Parse *source* and assert that a statement was produced."""
    parsed: Optional[Stmt] = parse(source)
    assert parsed is not None
    return parsed
def ast_str(source: str) -> str:
    """Parse *source* (which must succeed) and return its serialized AST string."""
    serializer = AstSerializer()
    return serializer.serialize(must_parse(source))
@pytest.mark.parametrize(
    ("src", "expected"),
    [
        ("Type", "(annotation Type)"),
        ("Type[]", "(annotation Type (schema ))"),
        (
            """
Frame[
verified: bool,
birth_year: int,
height: float + ( _ > 0 ) + ( _ < 250 ),
name: str,
date: datetime,
float, # unnamed
unknown: _, # untyped
_ # unnamed and untyped
]
""",
            "(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))",
        ),
    ],
)
def test_expressions(src: str, expected: str):
    """Valid annotations serialize to the expected parenthesized AST string."""
    rendered = ast_str(src)
    assert rendered == expected
@pytest.mark.parametrize(
"src,pos,should_fail",
[
("", (1, 1), True),
("42", (1, 1), True),
("True", (1, 1), True),
("Type[", (1, 6), True),
("Type[] Type2", (1, 8), False),
("Type[bool:]", (1, 11), True),
("Type[3]", (1, 6), True),
("Type[bool float]", (1, 11), True),
("Type[bool (_ < 2)]", (1, 11), True),
("Type[bool + _ < 2)]", (1, 13), True),
("Type[bool + (_ < 2]", (1, 19), True),
("Type[bool + (< 2)]", (1, 14), True),
("Type[bool + (_ + 2)]", (1, 16), True),
("Type[bool + (Foo + Bar)]", (1, 14), True),
# ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF
("Type[bool, Type[]]", (1, 16), True),
("Type[foo: 3]", (1, 11), True),
],
)
def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool):
tokens: list[Token] = AnnotationLexer(src).process()
parser: AnnotationParser = AnnotationParser(tokens)
stmt: Optional[Stmt] = parser.parse()
if should_fail:
assert stmt is None
assert len(parser.errors) != 0
error_pos: Position = parser.errors[0].token.position
assert (error_pos.line, error_pos.column) == pos