4 Commits

Author SHA1 Message Date
88f92d6e1f tests(parser): add simple types snapshot test 2026-05-19 14:12:12 +02:00
db4ed74365 tests(parser): add snapshot test runner
the diff printing function was suggested by Gemini

Co-authored-by: Gemini <noreply@gemini.google.com>
2026-05-19 14:11:32 +02:00
7cbf4fdece feat(tests): add AST JSON serializer 2026-05-19 14:00:32 +02:00
1fa9a09bfe feat(parser): use custom syntax error class 2026-05-19 13:57:00 +02:00
5 changed files with 1464 additions and 3 deletions

View File

@@ -0,0 +1,81 @@
from core.ast.midas import (
ConstraintExpr,
ConstraintStmt,
Expr,
LiteralExpr,
OpStmt,
PropertyStmt,
Stmt,
TypeBodyExpr,
TypeExpr,
TypeStmt,
WildcardExpr,
)
class AstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
    """Visitor turning Midas AST nodes into nested dicts and lists

    Every node becomes a dict whose first entry, ``"_type"``, holds the name
    of the node class; the remaining entries hold the node's fields. Child
    nodes are serialized by dispatching back through their ``accept`` method.
    """

    def serialize(self, stmts: list[Stmt]) -> list[dict]:
        """Serialize a sequence of statements

        Args:
            stmts (list[Stmt]): the statements to serialize

        Returns:
            list[dict]: one serialized dict per statement, in the same order
        """
        serialized: list[dict] = []
        for stmt in stmts:
            serialized.append(stmt.accept(self))
        return serialized

    def visit_type_stmt(self, stmt: TypeStmt) -> dict:
        """Serialize a type statement; a missing body is emitted as ``None``"""
        name = stmt.name.lexeme
        bases = [base.accept(self) for base in stmt.bases]
        body = None if stmt.body is None else stmt.body.accept(self)
        return {"_type": "TypeStmt", "name": name, "bases": bases, "body": body}

    def visit_type_expr(self, expr: TypeExpr) -> dict:
        """Serialize a type expression along with its constraints"""
        name = expr.name.lexeme
        constraints = [item.accept(self) for item in expr.constraints]
        return {"_type": "TypeExpr", "name": name, "constraints": constraints}

    def visit_constraint_expr(self, expr: ConstraintExpr) -> dict:
        """Serialize a binary constraint expression (left, operator, right)"""
        lhs = expr.left.accept(self)
        operator = expr.op.lexeme
        rhs = expr.right.accept(self)
        return {"_type": "ConstraintExpr", "left": lhs, "op": operator, "right": rhs}

    def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
        """Serialize a wildcard expression, which carries no fields"""
        return {"_type": "WildcardExpr"}

    def visit_literal_expr(self, expr: LiteralExpr) -> dict:
        """Serialize a literal expression; its value is passed through as is"""
        return {"_type": "LiteralExpr", "value": expr.value}

    def visit_type_body_expr(self, expr: TypeBodyExpr) -> dict:
        """Serialize a type body as the list of its serialized properties"""
        properties = [member.accept(self) for member in expr.properties]
        return {"_type": "TypeBodyExpr", "properties": properties}

    def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
        """Serialize a property statement (name and type)"""
        name = stmt.name.lexeme
        prop_type = stmt.type.accept(self)
        return {"_type": "PropertyStmt", "name": name, "type": prop_type}

    def visit_op_stmt(self, stmt: OpStmt) -> dict:
        """Serialize an operation statement (operands, operator and result)"""
        lhs = stmt.left.accept(self)
        operator = stmt.op.lexeme
        rhs = stmt.right.accept(self)
        outcome = stmt.result.accept(self)
        return {
            "_type": "OpStmt",
            "left": lhs,
            "op": operator,
            "right": rhs,
            "result": outcome,
        }

    def visit_constraint_stmt(self, stmt: ConstraintStmt) -> dict:
        """Serialize a named constraint statement"""
        name = stmt.name.lexeme
        constraint = stmt.constraint.accept(self)
        return {"_type": "ConstraintStmt", "name": name, "constraint": constraint}

View File

@@ -5,6 +5,13 @@ from lexer.position import Position
from lexer.token import Token, TokenType from lexer.token import Token, TokenType
class MidasSyntaxError(Exception):
    """Syntax error carrying the position at which it was detected

    The position and the raw message are stored on the exception in addition
    to the formatted text, so that handlers can read them individually.

    Args:
        pos (Position): where the error was detected
        message (str): the error message, without any position prefix
    """

    def __init__(self, pos: Position, message: str):
        # The exception text keeps the "[ERROR] Error at <pos>: <msg>" layout
        formatted = f"[ERROR] Error at {pos}: {message}"
        super().__init__(formatted)
        self.pos: Position = pos
        self.message: str = message
class Lexer(ABC): class Lexer(ABC):
"""An abstract lexer which provides methods to easily extend it into a concrete one """An abstract lexer which provides methods to easily extend it into a concrete one
@@ -38,9 +45,9 @@ class Lexer(ABC):
msg (str): the error message msg (str): the error message
Raises: Raises:
SyntaxError MidasSyntaxError
""" """
raise SyntaxError(f"[ERROR] Error at {self.start_pos}: {msg}") raise MidasSyntaxError(self.start_pos, msg)
def process(self) -> list[Token]: def process(self) -> list[Token]:
"""Scan tokens out of the source text """Scan tokens out of the source text
@@ -49,7 +56,7 @@ class Lexer(ABC):
list[Token]: all the tokens that could be scanned list[Token]: all the tokens that could be scanned
Raises: Raises:
SyntaxError: if a syntax error is found MidasSyntaxError: if a syntax error is found
""" """
self.scan_tokens() self.scan_tokens()
self.tokens.append(Token(TokenType.EOF, "", None, self.get_position())) self.tokens.append(Token(TokenType.EOF, "", None, self.get_position()))

204
tester.py Normal file
View File

@@ -0,0 +1,204 @@
from __future__ import annotations
import argparse
import difflib
import json
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Iterator, Optional
from core.ast.json_serializer import AstJsonSerializer
from core.ast.midas import Stmt
from lexer.base import MidasSyntaxError
from lexer.midas import MidasLexer
from lexer.token import Token
from parser.midas import MidasParser
# Directory searched for test cases when no base directory is given on the
# command line; relative, so it is resolved against the working directory.
DEFAULT_BASE_DIR: Path = Path("tests")
@dataclass
class CaseResult:
    """Outcome of executing one test case, in the shape stored in snapshots"""

    # Serialized tokens, or None when they were never produced
    tokens: Optional[list[dict]] = None
    # Serialized statements, or None when they were never produced
    stmts: Optional[list[dict]] = None
    # Errors collected while executing the case
    errors: list[dict] = field(default_factory=list)

    def dumps(self) -> str:
        """Render the result as JSON text

        Returns:
            str: the fields as a JSON object indented by two spaces
        """
        as_dict = asdict(self)
        return json.dumps(as_dict, indent=2)
class Tester:
    """A snapshot test runner to check for regressions in the lexer and parser

    Each test case is lexed and parsed, and the JSON dump of the outcome is
    compared with (or written to) the snapshot stored next to the case, named
    after it with an added ``.ref.json`` suffix.
    """

    def __init__(self, base_dir: Path):
        """
        Args:
            base_dir (Path): directory searched recursively for test cases
        """
        self.base_dir: Path = base_dir

    def _list_tests(self) -> list[Path]:
        """List every ``*.midas`` file found under the base directory

        Returns:
            list[Path]: the test cases, sorted so that the run order does not
                depend on the file system's listing order
        """
        return sorted(self.base_dir.rglob("*.midas"))

    def run_all_tests(self) -> bool:
        """Run every test case found under the base directory

        Returns:
            bool: True if no test case failed
        """
        return self.run_tests(self._list_tests())

    def run_tests(self, tests: list[Path]) -> bool:
        """Run the given test cases and print a summary

        Args:
            tests (list[Path]): paths of the test cases to run

        Returns:
            bool: True if no test case failed
        """
        rule: str = "-" * 80
        n: int = len(tests)
        failures: int = 0
        print(rule)
        for i, test in enumerate(tests, start=1):
            print(f"Case {i}/{n}: {test}")
            if not self._run_test(test):
                failures += 1

        print(rule)
        print(f"Success: {n - failures}/{n}")
        print(f"Failed: {failures}/{n}")
        print(rule)
        return failures == 0

    def _run_test(self, path: Path) -> bool:
        """Execute one test case and compare its output with the snapshot

        A missing snapshot is reported and counted as a failure, so that the
        remaining cases still run.

        Args:
            path (Path): path of the test case

        Returns:
            bool: True if the output is identical to the snapshot
        """
        result: CaseResult = self._exec_case(path)
        result_path: Path = self._result_path(path)
        actual: str = result.dumps()
        expected: Optional[str] = self._read_snapshot(result_path)
        if expected is None:
            print(f"Missing snapshot '{result_path}' (create it with 'update')")
            return False
        if expected == actual:
            return True

        diff = difflib.unified_diff(
            expected.splitlines(keepends=True),
            actual.splitlines(keepends=True),
            fromfile="Snapshot",
            tofile="Result",
        )
        self._print_diff(diff)
        return False

    def _exec_case(self, path: Path) -> CaseResult:
        """Lex and parse a test case, collecting tokens, statements and errors

        Args:
            path (Path): path of the test case

        Returns:
            CaseResult: the collected outcome; when the lexer raises, only the
                lexer error is recorded and the parser is not run

        Raises:
            FileNotFoundError: if the test case does not exist
            TypeError: if the test case is not a regular file
        """
        if not path.exists():
            raise FileNotFoundError(f"Could not find test '{path}'")
        if not path.is_file():
            raise TypeError(f"Test '{path}' is not a file")

        result: CaseResult = CaseResult()
        content: str = path.read_text(encoding="utf-8")

        lexer: MidasLexer = MidasLexer(content)
        try:
            tokens: list[Token] = lexer.process()
        except MidasSyntaxError as e:
            result.errors.append(
                {
                    "type": "SyntaxError",
                    "line": e.pos.line,
                    "column": e.pos.column,
                    "message": e.message,
                }
            )
            return result

        result.tokens = [
            {
                "type": token.type.name,
                "lexeme": token.lexeme,
                "line": token.position.line,
                "column": token.position.column,
            }
            for token in tokens
        ]

        parser: MidasParser = MidasParser(tokens)
        stmts: list[Stmt] = parser.parse()
        result.stmts = AstJsonSerializer().serialize(stmts)
        # NOTE(review): unlike lexer errors, these entries carry no "type"
        # key; kept as is because adding one would change stored snapshots.
        result.errors.extend(
            {
                "line": e.token.position.line,
                "column": e.token.position.column,
                "message": e.message,
            }
            for e in parser.errors
        )
        return result

    def update_all_tests(self):
        """Update the snapshot of every test case under the base directory"""
        return self.update_tests(self._list_tests())

    def update_tests(self, tests: list[Path]):
        """Update the snapshots of the given test cases and print a summary

        Args:
            tests (list[Path]): paths of the test cases to update
        """
        updated: int = sum(1 for test in tests if self._update_test(test))
        print(f"Updated {updated}/{len(tests)} tests")

    def _update_test(self, path: Path) -> bool:
        """Write the snapshot of a test case if it is missing or outdated

        Args:
            path (Path): path of the test case

        Returns:
            bool: True if the snapshot file was written
        """
        result: CaseResult = self._exec_case(path)
        result_path: Path = self._result_path(path)
        # None (no snapshot yet) never equals the new dump, so the first
        # snapshot of a new test case gets created here
        current: Optional[str] = self._read_snapshot(result_path)
        new: str = result.dumps()
        if current == new:
            return False
        result_path.write_text(new, encoding="utf-8")
        return True

    def _read_snapshot(self, result_path: Path) -> Optional[str]:
        """Read a snapshot file

        Args:
            result_path (Path): path of the snapshot

        Returns:
            Optional[str]: the snapshot text, or None if the file is missing
        """
        try:
            return result_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None

    def _result_path(self, test_path: Path) -> Path:
        """Get the snapshot path of a test case

        Args:
            test_path (Path): path of the test case

        Returns:
            Path: sibling path named after the test with ``.ref.json`` appended
        """
        return test_path.parent / (test_path.name + ".ref.json")

    def _print_diff(self, diff: Iterator[str]):
        """Print a unified diff, additions in green and removals in red

        Args:
            diff (Iterator[str]): lines of a unified diff
        """
        green: str = "\033[92m"
        red: str = "\033[91m"
        reset: str = "\033[0m"
        for line in diff:
            # Snapshots do not end with a newline, so their last line reaches
            # us unterminated; strip and re-add it to keep one line per print
            text: str = line.rstrip("\n")
            if text.startswith("+") and not text.startswith("+++"):
                print(f"{green}{text}{reset}")
            elif text.startswith("-") and not text.startswith("---"):
                print(f"{red}{text}{reset}")
            else:
                print(text)
        print()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-D",
"--base-dir",
help="Base directory containing test files",
type=Path,
default=DEFAULT_BASE_DIR,
)
subparsers = parser.add_subparsers(dest="subcommand")
update = subparsers.add_parser("update")
update.add_argument("-a", "--all", action="store_true")
update.add_argument("FILE", type=Path, nargs="*")
run = subparsers.add_parser("run")
run.add_argument("-a", "--all", action="store_true")
run.add_argument("FILE", type=Path, nargs="*")
args = parser.parse_args()
tester: Tester = Tester(args.base_dir)
match args.subcommand:
case "update":
if args.all:
tester.update_all_tests()
else:
tester.update_tests(args.FILE)
case "run":
success: bool
if args.all:
success = tester.run_all_tests()
else:
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,24 @@
// Simple custom type derived from floats
type Latitude<float>
type Longitude<float>
// Complex custom type, containing two values accessible through properties
type GeoLocation<Latitude, Longitude> {
lat: Latitude
lon: Longitude
}
type LatitudeDiff<float>
type LongitudeDiff<float>
// Simple operation defined on our custom types
op <Latitude> - <Latitude> = <LatitudeDiff>
op <Longitude> - <Longitude> = <LongitudeDiff>
// Simple custom type with a constraint
type Age<int + (0 <= _) + (_ < 150)>
// Predefined custom constraints that can be referenced in other definitions
constraint Positive = _ >= 0
constraint StrictlyPositive = _ > 0
//constraint Even = _ % 2 == 0

File diff suppressed because it is too large Load Diff