From abf6787946c2eb4a16244be8e6105092036120d9 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Wed, 20 May 2026 15:45:55 +0200 Subject: [PATCH] fix(parser)!: remove annotation lexer and parser --- core/ast/annotations.py | 107 ----------------- core/ast/printer.py | 123 +------------------- lexer/annotations.py | 102 ----------------- lexer/keyword.py | 8 +- lexer/midas.py | 4 +- parser/annotations.py | 152 ------------------------- tests/lexer/test_annotation_lexer.py | 129 --------------------- tests/parser/test_annotation_parser.py | 130 --------------------- 8 files changed, 9 insertions(+), 746 deletions(-) delete mode 100644 core/ast/annotations.py delete mode 100644 lexer/annotations.py delete mode 100644 parser/annotations.py delete mode 100644 tests/lexer/test_annotation_lexer.py delete mode 100644 tests/parser/test_annotation_parser.py diff --git a/core/ast/annotations.py b/core/ast/annotations.py deleted file mode 100644 index a885e29..0000000 --- a/core/ast/annotations.py +++ /dev/null @@ -1,107 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar - -from lexer.token import Token - -T = TypeVar("T") - - -@dataclass(frozen=True) -class Stmt(ABC): - @abstractmethod - def accept(self, visitor: Visitor[T]) -> T: ... - - class Visitor(ABC, Generic[T]): - @abstractmethod - def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ... - - -@dataclass(frozen=True) -class AnnotationStmt(Stmt): - name: Token - schema: Optional[SchemaExpr] - - def accept(self, visitor: Stmt.Visitor[T]) -> T: - return visitor.visit_annotation_stmt(self) - - -@dataclass(frozen=True) -class Expr(ABC): - @abstractmethod - def accept(self, visitor: Visitor[T]) -> T: ... - - class Visitor(ABC, Generic[T]): - @abstractmethod - def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ... - - @abstractmethod - def visit_literal_expr(self, expr: LiteralExpr) -> T: ... - - @abstractmethod - def visit_type_expr(self, expr: TypeExpr) -> T: ... - - @abstractmethod - def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ... - - @abstractmethod - def visit_schema_expr(self, expr: SchemaExpr) -> T: ... - - @abstractmethod - def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ... - - -@dataclass(frozen=True) -class WildcardExpr(Expr): - token: Token - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_wildcard_expr(self) - - -@dataclass(frozen=True) -class LiteralExpr(Expr): - value: Any - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_literal_expr(self) - - -@dataclass(frozen=True) -class TypeExpr(Expr): - name: Token - constraints: list[ConstraintExpr] - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_type_expr(self) - - -@dataclass(frozen=True) -class ConstraintExpr(Expr): - left: Expr - op: Token - right: Expr - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_constraint_expr(self) - - -@dataclass(frozen=True) -class SchemaExpr(Expr): - left: Token - elements: list[Expr] - right: Token - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_schema_expr(self) - - -@dataclass(frozen=True) -class SchemaElementExpr(Expr): - name: Optional[Token] - type: Optional[Expr] - - def accept(self, visitor: Expr.Visitor[T]) -> T: - return visitor.visit_schema_element_expr(self) diff --git a/core/ast/printer.py b/core/ast/printer.py index 086c581..3e13f55 100644 --- a/core/ast/printer.py +++ b/core/ast/printer.py @@ -1,11 +1,10 @@ from __future__ import annotations +import io from contextlib import contextmanager from enum import Enum, auto -import io from typing import Generator, Generic, Optional, Protocol, TypeVar -import core.ast.annotations as a import core.ast.midas as m @@ -84,113 +83,6 @@ class AstPrinter(Generic[T]): child.accept(self) -class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]): - def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None: - self._write_line("AnnotationStmt") - with self._child_level(): - self._write_line(f'name: "{stmt.name.lexeme}"') - self._write_optional_child("schema", stmt.schema, last=True) - - def visit_type_expr(self, expr: a.TypeExpr): - self._write_line("TypeExpr") - with self._child_level(): - self._write_line(f'name: "{expr.name.lexeme}"') - self._write_line("constraints", last=True) - with self._child_level(): - for i, constraint in enumerate(expr.constraints): - self._idx = i - if i == len(expr.constraints) - 1: - self._mark_last() - constraint.accept(self) - - def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None: - self._write_line("ConstraintExpr") - with self._child_level(): - self._write_line("left") - with self._child_level(): - self._mark_last() - expr.left.accept(self) - - self._write_line(f"operator: {expr.op.lexeme}") - - self._write_line("right", last=True) - with self._child_level(): - self._mark_last() - expr.right.accept(self) - - def visit_schema_expr(self, expr: a.SchemaExpr): - self._write_line("SchemaExpr") - with self._child_level(): - for i, elmt in enumerate(expr.elements): - self._idx = i - if i == len(expr.elements) - 1: - self._mark_last() - elmt.accept(self) - - def visit_schema_element_expr(self, expr: a.SchemaElementExpr): - self._write_line("SchemaElementExpr") - with self._child_level(): - name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"' - self._write_line(f"name: {name_text}") - self._write_optional_child("type", expr.type, last=True) - - def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None: - self._write_line("WildcardExpr") - - def visit_literal_expr(self, expr: a.LiteralExpr) -> None: - self._write_line("LiteralExpr") - with self._child_level(): - self._write_line(f"value: {expr.value}", last=True) - - -class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]): - def print(self, expr: a.Expr | a.Stmt): - return expr.accept(self) - - def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str: - schema: str = "" - if stmt.schema is not None: - schema = stmt.schema.accept(self) - return f"{stmt.name.lexeme}{schema}" - - def visit_type_expr(self, expr: a.TypeExpr) -> str: - parts: list[str] = [expr.name.lexeme] - for constraint in expr.constraints: - parts.append("(" + constraint.accept(self) + ")") - return " + ".join(parts) - - def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str: - parts: list[str] = [ - expr.left.accept(self), - expr.op.lexeme, - expr.right.accept(self), - ] - return " ".join(parts) - - def visit_schema_expr(self, expr: a.SchemaExpr) -> str: - res: str = expr.left.lexeme - res += ", ".join(elmt.accept(self) for elmt in expr.elements) - res += expr.right.lexeme - return res - - def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str: - parts: list[str] = [] - if expr.name is not None: - parts.append(expr.name.lexeme) - - if expr.type is None: - parts.append("_") - else: - parts.append(expr.type.accept(self)) - return ": ".join(parts) - - def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str: - return "_" - - def visit_literal_expr(self, expr: a.LiteralExpr) -> str: - return str(expr.value) - - class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): def visit_type_stmt(self, stmt: m.TypeStmt): self._write_line("TypeStmt") @@ -289,6 +181,7 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): with self._child_level(): self._write_line(f"value: {expr.value}", last=True) + class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): def __init__(self, indent: int = 4): self.indent: int = indent @@ -302,11 +195,8 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): return expr.accept(self) def visit_type_stmt(self, stmt: m.TypeStmt): - bases: list[str] = [ - b.accept(self) - for b in stmt.bases - ] - + bases: list[str] = [b.accept(self) for b in stmt.bases] + res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>") if stmt.body is not None: res += " {\n" @@ -348,8 +238,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): def visit_type_body_expr(self, expr: m.TypeBodyExpr): properties: list[str] = [ - self.indented(prop.accept(self)) - for prop in expr.properties + self.indented(prop.accept(self)) for prop in expr.properties ] return "\n".join(properties) @@ -357,4 +246,4 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): return "_" def visit_literal_expr(self, expr: m.LiteralExpr): - return str(expr.value) \ No newline at end of file + return str(expr.value) diff --git a/lexer/annotations.py b/lexer/annotations.py deleted file mode 100644 index ae9faae..0000000 --- a/lexer/annotations.py +++ /dev/null @@ -1,102 +0,0 @@ -from lexer.base import Lexer -from lexer.keyword import ANNOTATION_KEYWORDS -from lexer.token import TokenType - - -class AnnotationLexer(Lexer): - def scan_token(self) -> None: - char: str = self.advance() - match char: - case "(": - self.add_token(TokenType.LEFT_PAREN) - case ")": - self.add_token(TokenType.RIGHT_PAREN) - case "[": - self.add_token(TokenType.LEFT_BRACKET) - case "]": - self.add_token(TokenType.RIGHT_BRACKET) - case "<": - self.add_token( - TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS - ) - case ">": - self.add_token( - TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER - ) - case "=": - self.add_token( - TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL - ) - case "!": - if self.match("="): - self.add_token(TokenType.BANG_EQUAL) - else: - self.error("Unexpected single bang. Did you mean '!=' ?") - case ":": - self.add_token(TokenType.COLON) - case ",": - self.add_token(TokenType.COMMA) - case "_": - self.add_token(TokenType.UNDERSCORE) - case "+": - self.add_token(TokenType.PLUS) - case "#": - self.scan_comment() - case "\n": - self.add_token(TokenType.NEWLINE) - case " " | "\r" | "\t": - # Consume all whitespace characters until EOL or EOF - while ( - self.peek().isspace() - and self.peek() != "\n" - and not self.is_at_end() - ): - self.advance() - self.add_token(TokenType.WHITESPACE) - case _: - if char.isdigit(): - self.scan_number() - elif char.isalpha(): - self.scan_identifier() - else: - self.error("Unexpected character") - return None - - def scan_number(self): - """Scan the rest of number and add it as a token - - This method handles both simple integers and floats. Scientific notation - and base prefixes (0x, 0b, 0o) are not supported - """ - while self.peek().isdigit(): - self.advance() - - if self.peek() == "." and self.peek_next().isdigit(): - self.advance() - while self.peek().isdigit(): - self.advance() - - value: float = float(self.source[self.start : self.idx]) - self.add_token(TokenType.NUMBER, value) - - def scan_identifier(self): - """Scan the rest of an identifier and add it as a token - - An identifier starts with a letter, followed by any number of - alphanumerical characters or underscores - """ - while self.peek().isalnum() or self.peek() == "_": - self.advance() - - lexeme: str = self.source[self.start : self.idx] - token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER) - self.add_token(token_type) - - def scan_comment(self): - """Scan the rest of a comment and add it as a token - - A comment starts with a `#` character and ends at the EOL/EOF - """ - while self.peek() != "\n" and not self.is_at_end(): - self.advance() - self.add_token(TokenType.COMMENT) diff --git a/lexer/keyword.py b/lexer/keyword.py index b66f21a..2ab3f45 100644 --- a/lexer/keyword.py +++ b/lexer/keyword.py @@ -1,12 +1,6 @@ from lexer.token import TokenType -ANNOTATION_KEYWORDS: dict[str, TokenType] = { - "True": TokenType.TRUE, - "False": TokenType.FALSE, - "None": TokenType.NONE, -} - -MIDAS_KEYWORDS: dict[str, TokenType] = { +KEYWORDS: dict[str, TokenType] = { "type": TokenType.TYPE, "op": TokenType.OP, "constraint": TokenType.CONSTRAINT, diff --git a/lexer/midas.py b/lexer/midas.py index ad29a68..b366f09 100644 --- a/lexer/midas.py +++ b/lexer/midas.py @@ -1,5 +1,5 @@ from lexer.base import Lexer -from lexer.keyword import MIDAS_KEYWORDS +from lexer.keyword import KEYWORDS from lexer.token import TokenType @@ -102,7 +102,7 @@ class MidasLexer(Lexer): self.advance() lexeme: str = self.source[self.start : self.idx] - token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER) + token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER) self.add_token(token_type) def scan_comment(self): diff --git a/parser/annotations.py b/parser/annotations.py deleted file mode 100644 index 0bf99d6..0000000 --- a/parser/annotations.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import Optional - -from core.ast.annotations import ( - AnnotationStmt, - ConstraintExpr, - Expr, - LiteralExpr, - SchemaElementExpr, - SchemaExpr, - Stmt, - TypeExpr, - WildcardExpr, -) -from lexer.token import Token, TokenType -from parser.base import Parser -from parser.errors import ParsingError - - -class AnnotationParser(Parser): - """A simple parser for custom type annotations""" - - SYNC_BOUNDARY: set[TokenType] = set() - - def parse(self) -> Optional[Stmt]: - stmt: Optional[Stmt] = None - try: - stmt = self.annotation() - except ParsingError: - self.synchronize() - if not self.is_at_end(): - self.error(self.peek(), "Extra tokens") - return stmt - - def synchronize(self): - """Skip tokens until a synchronization boundary is found - - This method allows gracefully recovering from a parse error - to a safe place and continue parsing - """ - self.advance() - while not self.is_at_end(): - if self.peek().type in self.SYNC_BOUNDARY: - return - self.advance() - - def annotation(self) -> AnnotationStmt: - """Parse an annotation - - An annotation is written as `Type` or `Type[Schema]` - - Returns: - AnnotationStmt: the parsed annotation statement - """ - - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier") - schema: Optional[SchemaExpr] = None - if self.match(TokenType.LEFT_BRACKET): - schema = self.schema() - return AnnotationStmt(name=name, schema=schema) - - def type_expr(self) -> TypeExpr: - """Parse a type expression - - Returns: - TypeExpr: the parsed type expression - """ - name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") - constraints: list[ConstraintExpr] = [] - - while not self.is_at_end() and self.match(TokenType.PLUS): - self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint") - constraints.append(self.constraint_expr()) - self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint") - - return TypeExpr(name=name, constraints=constraints) - - def constraint_expr(self) -> ConstraintExpr: - """Parse a type constraint - - Returns: - ConstraintExpr: the parsed type constraint expression - """ - - left: Expr = self.constraint_value() - op: Token = self.constraint_operator() - right: Expr = self.constraint_value() - return ConstraintExpr(left=left, op=op, right=right) - - def constraint_value(self) -> Expr: - if self.match(TokenType.UNDERSCORE): - return WildcardExpr(self.previous()) - return self.literal() - - def literal(self) -> LiteralExpr: - if self.match(TokenType.FALSE): - return LiteralExpr(False) - if self.match(TokenType.TRUE): - return LiteralExpr(True) - if self.match(TokenType.NONE): - return LiteralExpr(None) - - if self.match(TokenType.NUMBER): - return LiteralExpr(self.previous().value) - - raise self.error(self.peek(), "Expected literal") - - def constraint_operator(self) -> Token: - if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL): - return self.previous() - raise self.error(self.peek(), "Expected constraint operator") - - def schema(self) -> SchemaExpr: - """Parse a schema definition - - A comma separated list of schema elements - - Returns: - SchemaExpr: the parsed schema expression - """ - left: Token = self.previous() - elements: list[Expr] = [] - while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end(): - elements.append(self.schema_element()) - if not self.check(TokenType.RIGHT_BRACKET): - self.consume(TokenType.COMMA, "Expected ',' between schema elements") - - right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema") - return SchemaExpr(left=left, elements=elements, right=right) - - def schema_element(self) -> SchemaElementExpr: - """Parse a schema element - - An anonymous element (`_`), a type, an untyped named column (`name: _`), - or a named column (`name: Type`) - - Returns: - SchemaElementExpr: the parsed schema element expression - """ - if self.match(TokenType.UNDERSCORE): - return SchemaElementExpr(name=None, type=None) - - if not self.check(TokenType.IDENTIFIER): - raise self.error(self.peek(), "Expected schema element") - - name: Optional[Token] = None - type: Optional[TypeExpr] = None - if self.check_next(TokenType.COLON): - name = self.advance() - self.advance() - if not self.match(TokenType.UNDERSCORE): - type = self.type_expr() - return SchemaElementExpr(name=name, type=type) diff --git a/tests/lexer/test_annotation_lexer.py b/tests/lexer/test_annotation_lexer.py deleted file mode 100644 index 33a83a1..0000000 --- a/tests/lexer/test_annotation_lexer.py +++ /dev/null @@ -1,129 +0,0 @@ -from typing import Any - -import pytest - -from lexer.annotations import AnnotationLexer -from lexer.token import Token, TokenType - - -def scan(source: str) -> list[Token]: - return AnnotationLexer(source).process() - - -def assert_n_tokens(tokens: list[Token], n: int): - assert len(tokens) == n + 1 - assert tokens[-1].type == TokenType.EOF - - -@pytest.mark.parametrize( - "src,expected", - [ - ("(", TokenType.LEFT_PAREN), - (")", TokenType.RIGHT_PAREN), - ("[", TokenType.LEFT_BRACKET), - ("]", TokenType.RIGHT_BRACKET), - (":", TokenType.COLON), - (",", TokenType.COMMA), - ("_", TokenType.UNDERSCORE), - ], -) -def test_punctuation(src: str, expected: TokenType): - tokens: list[Token] = scan(src) - assert_n_tokens(tokens, 1) - assert tokens[0].type == expected - - -@pytest.mark.parametrize( - "src,expected", - [ - ("+", TokenType.PLUS), - (">", TokenType.GREATER), - (">=", TokenType.GREATER_EQUAL), - ("<", TokenType.LESS), - ("<=", TokenType.LESS_EQUAL), - ("=", TokenType.EQUAL), - ("==", TokenType.EQUAL_EQUAL), - ("!=", TokenType.BANG_EQUAL), - ], -) -def test_operators(src: str, expected: TokenType): - tokens: list[Token] = scan(src) - assert_n_tokens(tokens, 1) - assert tokens[0].type == expected - - -@pytest.mark.parametrize( - "src,expected", - [ - ("a", TokenType.IDENTIFIER), - ("foo", TokenType.IDENTIFIER), - ("foo1", TokenType.IDENTIFIER), - ("foo_", TokenType.IDENTIFIER), - ("foo_bar1_baz2", TokenType.IDENTIFIER), - ("FOO_BAR1_BAZ2", TokenType.IDENTIFIER), - ("True", TokenType.TRUE), - ("False", TokenType.FALSE), - ("None", TokenType.NONE), - ], -) -def test_identifiers_keywords(src: str, expected: TokenType): - tokens: list[Token] = scan(src) - assert_n_tokens(tokens, 1) - assert tokens[0].type == expected - - -@pytest.mark.parametrize( - "src,expected", - [ - ("#", TokenType.COMMENT), - ("# This is a comment", TokenType.COMMENT), - (" ", TokenType.WHITESPACE), - ("\t", TokenType.WHITESPACE), - ("\r", TokenType.WHITESPACE), - (" \t \t", TokenType.WHITESPACE), - ("\n", TokenType.NEWLINE), - ], -) -def test_misc(src: str, expected: TokenType): - tokens: list[Token] = scan(src) - assert_n_tokens(tokens, 1) - assert tokens[0].type == expected - - -@pytest.mark.parametrize( - "src,expected_type,expected_value", - [ - ("0", TokenType.NUMBER, 0), - ("0.0", TokenType.NUMBER, 0), - ("1234.56", TokenType.NUMBER, 1234.56), - ], -) -def test_literals(src: str, expected_type: TokenType, expected_value: Any): - tokens: list[Token] = scan(src) - assert_n_tokens(tokens, 1) - assert tokens[0].type == expected_type - assert tokens[0].value == expected_value - - -def test_single_bang_error(): - with pytest.raises(SyntaxError): - scan("!") - - -@pytest.mark.parametrize( - "src", - [ - "-", - "*", - "/", - "{", - "}", - "@", - '"', - "'", - ".", - ], -) -def test_unexpected_character(src: str): - with pytest.raises(SyntaxError): - scan(src) diff --git a/tests/parser/test_annotation_parser.py b/tests/parser/test_annotation_parser.py deleted file mode 100644 index 9c034dd..0000000 --- a/tests/parser/test_annotation_parser.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import Optional - -import pytest - -from core.ast.annotations import ( - AnnotationStmt, - ConstraintExpr, - Expr, - LiteralExpr, - SchemaElementExpr, - SchemaExpr, - Stmt, - TypeExpr, - WildcardExpr, -) -from lexer.annotations import AnnotationLexer -from lexer.position import Position -from lexer.token import Token -from parser.annotations import AnnotationParser - - -class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]): - def serialize(self, stmt: Stmt): - return stmt.accept(self) - - def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str: - schema: str = "" - if stmt.schema is not None: - schema = " " + stmt.schema.accept(self) - return f"(annotation {stmt.name.lexeme}{schema})" - - def visit_schema_expr(self, expr: SchemaExpr) -> str: - elements: list[str] = [elmt.accept(self) for elmt in expr.elements] - return f"(schema {' '.join(elements)})" - - def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str: - name: str = expr.name.lexeme if expr.name is not None else "_" - type: str = expr.type.accept(self) if expr.type is not None else "_" - return f"({name} {type})" - - def visit_type_expr(self, expr: TypeExpr) -> str: - res: str = f"({expr.name.lexeme}" - for constraint in expr.constraints: - res += " " + constraint.accept(self) - res += ")" - return res - - def visit_constraint_expr(self, expr: ConstraintExpr) -> str: - return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})" - - def visit_wildcard_expr(self, expr: WildcardExpr) -> str: - return "(_)" - - def visit_literal_expr(self, expr: LiteralExpr) -> str: - return f"({expr.value})" - - -def parse(source: str) -> Optional[Stmt]: - tokens: list[Token] = AnnotationLexer(source).process() - return AnnotationParser(tokens).parse() - - -def must_parse(source: str) -> Stmt: - stmt: Optional[Stmt] = parse(source) - assert stmt is not None - return stmt - - -def ast_str(source: str) -> str: - stmt: Stmt = must_parse(source) - return AstSerializer().serialize(stmt) - - -@pytest.mark.parametrize( - "src,expected", - [ - ("Type", "(annotation Type)"), - ("Type[]", "(annotation Type (schema ))"), - ( - """ - Frame[ - verified: bool, - birth_year: int, - height: float + ( _ > 0 ) + ( _ < 250 ), - name: str, - date: datetime, - float, # unnamed - unknown: _, # untyped - _ # unnamed and untyped - ] - """, - "(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))", - ), - ], -) -def test_expressions(src: str, expected: str): - assert ast_str(src) == expected - - -@pytest.mark.parametrize( - "src,pos,should_fail", - [ - ("", (1, 1), True), - ("42", (1, 1), True), - ("True", (1, 1), True), - ("Type[", (1, 6), True), - ("Type[] Type2", (1, 8), False), - ("Type[bool:]", (1, 11), True), - ("Type[3]", (1, 6), True), - ("Type[bool float]", (1, 11), True), - ("Type[bool (_ < 2)]", (1, 11), True), - ("Type[bool + _ < 2)]", (1, 13), True), - ("Type[bool + (_ < 2]", (1, 19), True), - ("Type[bool + (< 2)]", (1, 14), True), - ("Type[bool + (_ + 2)]", (1, 16), True), - ("Type[bool + (Foo + Bar)]", (1, 14), True), - # ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF - ("Type[bool, Type[]]", (1, 16), True), - ("Type[foo: 3]", (1, 11), True), - ], -) -def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool): - tokens: list[Token] = AnnotationLexer(src).process() - parser: AnnotationParser = AnnotationParser(tokens) - stmt: Optional[Stmt] = parser.parse() - if should_fail: - assert stmt is None - assert len(parser.errors) != 0 - error_pos: Position = parser.errors[0].token.position - assert (error_pos.line, error_pos.column) == pos