fix(parser)!: remove annotation lexer and parser
This commit is contained in:
@@ -1,107 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Generic, Optional, TypeVar
|
|
||||||
|
|
||||||
from lexer.token import Token
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Stmt(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
||||||
|
|
||||||
class Visitor(ABC, Generic[T]):
|
|
||||||
@abstractmethod
|
|
||||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> T: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AnnotationStmt(Stmt):
|
|
||||||
name: Token
|
|
||||||
schema: Optional[SchemaExpr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_annotation_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Expr(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
|
||||||
|
|
||||||
class Visitor(ABC, Generic[T]):
|
|
||||||
@abstractmethod
|
|
||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_literal_expr(self, expr: LiteralExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_schema_expr(self, expr: SchemaExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> T: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class WildcardExpr(Expr):
|
|
||||||
token: Token
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_wildcard_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class LiteralExpr(Expr):
|
|
||||||
value: Any
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_literal_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TypeExpr(Expr):
|
|
||||||
name: Token
|
|
||||||
constraints: list[ConstraintExpr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_type_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConstraintExpr(Expr):
|
|
||||||
left: Expr
|
|
||||||
op: Token
|
|
||||||
right: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_constraint_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SchemaExpr(Expr):
|
|
||||||
left: Token
|
|
||||||
elements: list[Expr]
|
|
||||||
right: Token
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_schema_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SchemaElementExpr(Expr):
|
|
||||||
name: Optional[Token]
|
|
||||||
type: Optional[Expr]
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_schema_element_expr(self)
|
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
import io
|
|
||||||
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
from typing import Generator, Generic, Optional, Protocol, TypeVar
|
||||||
|
|
||||||
import core.ast.annotations as a
|
|
||||||
import core.ast.midas as m
|
import core.ast.midas as m
|
||||||
|
|
||||||
|
|
||||||
@@ -84,113 +83,6 @@ class AstPrinter(Generic[T]):
|
|||||||
child.accept(self)
|
child.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None], a.Stmt.Visitor[None]):
|
|
||||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> None:
|
|
||||||
self._write_line("AnnotationStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
|
||||||
self._write_optional_child("schema", stmt.schema, last=True)
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: a.TypeExpr):
|
|
||||||
self._write_line("TypeExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
|
||||||
self._write_line("constraints", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
for i, constraint in enumerate(expr.constraints):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(expr.constraints) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
constraint.accept(self)
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> None:
|
|
||||||
self._write_line("ConstraintExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("left")
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
expr.left.accept(self)
|
|
||||||
|
|
||||||
self._write_line(f"operator: {expr.op.lexeme}")
|
|
||||||
|
|
||||||
self._write_line("right", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
self._mark_last()
|
|
||||||
expr.right.accept(self)
|
|
||||||
|
|
||||||
def visit_schema_expr(self, expr: a.SchemaExpr):
|
|
||||||
self._write_line("SchemaExpr")
|
|
||||||
with self._child_level():
|
|
||||||
for i, elmt in enumerate(expr.elements):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(expr.elements) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
elmt.accept(self)
|
|
||||||
|
|
||||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
|
|
||||||
self._write_line("SchemaElementExpr")
|
|
||||||
with self._child_level():
|
|
||||||
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
|
|
||||||
self._write_line(f"name: {name_text}")
|
|
||||||
self._write_optional_child("type", expr.type, last=True)
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> None:
|
|
||||||
self._write_line("WildcardExpr")
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> None:
|
|
||||||
self._write_line("LiteralExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f"value: {expr.value}", last=True)
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationPrinter(a.Expr.Visitor[str], a.Stmt.Visitor[str]):
|
|
||||||
def print(self, expr: a.Expr | a.Stmt):
|
|
||||||
return expr.accept(self)
|
|
||||||
|
|
||||||
def visit_annotation_stmt(self, stmt: a.AnnotationStmt) -> str:
|
|
||||||
schema: str = ""
|
|
||||||
if stmt.schema is not None:
|
|
||||||
schema = stmt.schema.accept(self)
|
|
||||||
return f"{stmt.name.lexeme}{schema}"
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: a.TypeExpr) -> str:
|
|
||||||
parts: list[str] = [expr.name.lexeme]
|
|
||||||
for constraint in expr.constraints:
|
|
||||||
parts.append("(" + constraint.accept(self) + ")")
|
|
||||||
return " + ".join(parts)
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: a.ConstraintExpr) -> str:
|
|
||||||
parts: list[str] = [
|
|
||||||
expr.left.accept(self),
|
|
||||||
expr.op.lexeme,
|
|
||||||
expr.right.accept(self),
|
|
||||||
]
|
|
||||||
return " ".join(parts)
|
|
||||||
|
|
||||||
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
|
|
||||||
res: str = expr.left.lexeme
|
|
||||||
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
|
|
||||||
res += expr.right.lexeme
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
|
|
||||||
parts: list[str] = []
|
|
||||||
if expr.name is not None:
|
|
||||||
parts.append(expr.name.lexeme)
|
|
||||||
|
|
||||||
if expr.type is None:
|
|
||||||
parts.append("_")
|
|
||||||
else:
|
|
||||||
parts.append(expr.type.accept(self))
|
|
||||||
return ": ".join(parts)
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: a.WildcardExpr) -> str:
|
|
||||||
return "_"
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: a.LiteralExpr) -> str:
|
|
||||||
return str(expr.value)
|
|
||||||
|
|
||||||
|
|
||||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
def visit_type_stmt(self, stmt: m.TypeStmt):
|
||||||
self._write_line("TypeStmt")
|
self._write_line("TypeStmt")
|
||||||
@@ -289,6 +181,7 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
|||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f"value: {expr.value}", last=True)
|
self._write_line(f"value: {expr.value}", last=True)
|
||||||
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||||
def __init__(self, indent: int = 4):
|
def __init__(self, indent: int = 4):
|
||||||
self.indent: int = indent
|
self.indent: int = indent
|
||||||
@@ -302,11 +195,8 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
|||||||
return expr.accept(self)
|
return expr.accept(self)
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt):
|
def visit_type_stmt(self, stmt: m.TypeStmt):
|
||||||
bases: list[str] = [
|
bases: list[str] = [b.accept(self) for b in stmt.bases]
|
||||||
b.accept(self)
|
|
||||||
for b in stmt.bases
|
|
||||||
]
|
|
||||||
|
|
||||||
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
|
res: str = self.indented(f"type {stmt.name.lexeme}<{', '.join(bases)}>")
|
||||||
if stmt.body is not None:
|
if stmt.body is not None:
|
||||||
res += " {\n"
|
res += " {\n"
|
||||||
@@ -348,8 +238,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
|||||||
|
|
||||||
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
def visit_type_body_expr(self, expr: m.TypeBodyExpr):
|
||||||
properties: list[str] = [
|
properties: list[str] = [
|
||||||
self.indented(prop.accept(self))
|
self.indented(prop.accept(self)) for prop in expr.properties
|
||||||
for prop in expr.properties
|
|
||||||
]
|
]
|
||||||
return "\n".join(properties)
|
return "\n".join(properties)
|
||||||
|
|
||||||
@@ -357,4 +246,4 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
|||||||
return "_"
|
return "_"
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr):
|
def visit_literal_expr(self, expr: m.LiteralExpr):
|
||||||
return str(expr.value)
|
return str(expr.value)
|
||||||
|
|||||||
@@ -1,102 +0,0 @@
|
|||||||
from lexer.base import Lexer
|
|
||||||
from lexer.keyword import ANNOTATION_KEYWORDS
|
|
||||||
from lexer.token import TokenType
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationLexer(Lexer):
|
|
||||||
def scan_token(self) -> None:
|
|
||||||
char: str = self.advance()
|
|
||||||
match char:
|
|
||||||
case "(":
|
|
||||||
self.add_token(TokenType.LEFT_PAREN)
|
|
||||||
case ")":
|
|
||||||
self.add_token(TokenType.RIGHT_PAREN)
|
|
||||||
case "[":
|
|
||||||
self.add_token(TokenType.LEFT_BRACKET)
|
|
||||||
case "]":
|
|
||||||
self.add_token(TokenType.RIGHT_BRACKET)
|
|
||||||
case "<":
|
|
||||||
self.add_token(
|
|
||||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
|
||||||
)
|
|
||||||
case ">":
|
|
||||||
self.add_token(
|
|
||||||
TokenType.GREATER_EQUAL if self.match("=") else TokenType.GREATER
|
|
||||||
)
|
|
||||||
case "=":
|
|
||||||
self.add_token(
|
|
||||||
TokenType.EQUAL_EQUAL if self.match("=") else TokenType.EQUAL
|
|
||||||
)
|
|
||||||
case "!":
|
|
||||||
if self.match("="):
|
|
||||||
self.add_token(TokenType.BANG_EQUAL)
|
|
||||||
else:
|
|
||||||
self.error("Unexpected single bang. Did you mean '!=' ?")
|
|
||||||
case ":":
|
|
||||||
self.add_token(TokenType.COLON)
|
|
||||||
case ",":
|
|
||||||
self.add_token(TokenType.COMMA)
|
|
||||||
case "_":
|
|
||||||
self.add_token(TokenType.UNDERSCORE)
|
|
||||||
case "+":
|
|
||||||
self.add_token(TokenType.PLUS)
|
|
||||||
case "#":
|
|
||||||
self.scan_comment()
|
|
||||||
case "\n":
|
|
||||||
self.add_token(TokenType.NEWLINE)
|
|
||||||
case " " | "\r" | "\t":
|
|
||||||
# Consume all whitespace characters until EOL or EOF
|
|
||||||
while (
|
|
||||||
self.peek().isspace()
|
|
||||||
and self.peek() != "\n"
|
|
||||||
and not self.is_at_end()
|
|
||||||
):
|
|
||||||
self.advance()
|
|
||||||
self.add_token(TokenType.WHITESPACE)
|
|
||||||
case _:
|
|
||||||
if char.isdigit():
|
|
||||||
self.scan_number()
|
|
||||||
elif char.isalpha():
|
|
||||||
self.scan_identifier()
|
|
||||||
else:
|
|
||||||
self.error("Unexpected character")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def scan_number(self):
|
|
||||||
"""Scan the rest of number and add it as a token
|
|
||||||
|
|
||||||
This method handles both simple integers and floats. Scientific notation
|
|
||||||
and base prefixes (0x, 0b, 0o) are not supported
|
|
||||||
"""
|
|
||||||
while self.peek().isdigit():
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
if self.peek() == "." and self.peek_next().isdigit():
|
|
||||||
self.advance()
|
|
||||||
while self.peek().isdigit():
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
value: float = float(self.source[self.start : self.idx])
|
|
||||||
self.add_token(TokenType.NUMBER, value)
|
|
||||||
|
|
||||||
def scan_identifier(self):
|
|
||||||
"""Scan the rest of an identifier and add it as a token
|
|
||||||
|
|
||||||
An identifier starts with a letter, followed by any number of
|
|
||||||
alphanumerical characters or underscores
|
|
||||||
"""
|
|
||||||
while self.peek().isalnum() or self.peek() == "_":
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
lexeme: str = self.source[self.start : self.idx]
|
|
||||||
token_type: TokenType = ANNOTATION_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
|
||||||
self.add_token(token_type)
|
|
||||||
|
|
||||||
def scan_comment(self):
|
|
||||||
"""Scan the rest of a comment and add it as a token
|
|
||||||
|
|
||||||
A comment starts with a `#` character and ends at the EOL/EOF
|
|
||||||
"""
|
|
||||||
while self.peek() != "\n" and not self.is_at_end():
|
|
||||||
self.advance()
|
|
||||||
self.add_token(TokenType.COMMENT)
|
|
||||||
@@ -1,12 +1,6 @@
|
|||||||
from lexer.token import TokenType
|
from lexer.token import TokenType
|
||||||
|
|
||||||
ANNOTATION_KEYWORDS: dict[str, TokenType] = {
|
KEYWORDS: dict[str, TokenType] = {
|
||||||
"True": TokenType.TRUE,
|
|
||||||
"False": TokenType.FALSE,
|
|
||||||
"None": TokenType.NONE,
|
|
||||||
}
|
|
||||||
|
|
||||||
MIDAS_KEYWORDS: dict[str, TokenType] = {
|
|
||||||
"type": TokenType.TYPE,
|
"type": TokenType.TYPE,
|
||||||
"op": TokenType.OP,
|
"op": TokenType.OP,
|
||||||
"constraint": TokenType.CONSTRAINT,
|
"constraint": TokenType.CONSTRAINT,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from lexer.base import Lexer
|
from lexer.base import Lexer
|
||||||
from lexer.keyword import MIDAS_KEYWORDS
|
from lexer.keyword import KEYWORDS
|
||||||
from lexer.token import TokenType
|
from lexer.token import TokenType
|
||||||
|
|
||||||
|
|
||||||
@@ -102,7 +102,7 @@ class MidasLexer(Lexer):
|
|||||||
self.advance()
|
self.advance()
|
||||||
|
|
||||||
lexeme: str = self.source[self.start : self.idx]
|
lexeme: str = self.source[self.start : self.idx]
|
||||||
token_type: TokenType = MIDAS_KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
token_type: TokenType = KEYWORDS.get(lexeme, TokenType.IDENTIFIER)
|
||||||
self.add_token(token_type)
|
self.add_token(token_type)
|
||||||
|
|
||||||
def scan_comment(self):
|
def scan_comment(self):
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from core.ast.annotations import (
|
|
||||||
AnnotationStmt,
|
|
||||||
ConstraintExpr,
|
|
||||||
Expr,
|
|
||||||
LiteralExpr,
|
|
||||||
SchemaElementExpr,
|
|
||||||
SchemaExpr,
|
|
||||||
Stmt,
|
|
||||||
TypeExpr,
|
|
||||||
WildcardExpr,
|
|
||||||
)
|
|
||||||
from lexer.token import Token, TokenType
|
|
||||||
from parser.base import Parser
|
|
||||||
from parser.errors import ParsingError
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationParser(Parser):
|
|
||||||
"""A simple parser for custom type annotations"""
|
|
||||||
|
|
||||||
SYNC_BOUNDARY: set[TokenType] = set()
|
|
||||||
|
|
||||||
def parse(self) -> Optional[Stmt]:
|
|
||||||
stmt: Optional[Stmt] = None
|
|
||||||
try:
|
|
||||||
stmt = self.annotation()
|
|
||||||
except ParsingError:
|
|
||||||
self.synchronize()
|
|
||||||
if not self.is_at_end():
|
|
||||||
self.error(self.peek(), "Extra tokens")
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
def synchronize(self):
|
|
||||||
"""Skip tokens until a synchronization boundary is found
|
|
||||||
|
|
||||||
This method allows gracefully recovering from a parse error
|
|
||||||
to a safe place and continue parsing
|
|
||||||
"""
|
|
||||||
self.advance()
|
|
||||||
while not self.is_at_end():
|
|
||||||
if self.peek().type in self.SYNC_BOUNDARY:
|
|
||||||
return
|
|
||||||
self.advance()
|
|
||||||
|
|
||||||
def annotation(self) -> AnnotationStmt:
|
|
||||||
"""Parse an annotation
|
|
||||||
|
|
||||||
An annotation is written as `Type` or `Type[Schema]`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AnnotationStmt: the parsed annotation statement
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type identifier")
|
|
||||||
schema: Optional[SchemaExpr] = None
|
|
||||||
if self.match(TokenType.LEFT_BRACKET):
|
|
||||||
schema = self.schema()
|
|
||||||
return AnnotationStmt(name=name, schema=schema)
|
|
||||||
|
|
||||||
def type_expr(self) -> TypeExpr:
|
|
||||||
"""Parse a type expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TypeExpr: the parsed type expression
|
|
||||||
"""
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
|
||||||
constraints: list[ConstraintExpr] = []
|
|
||||||
|
|
||||||
while not self.is_at_end() and self.match(TokenType.PLUS):
|
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before type constraint")
|
|
||||||
constraints.append(self.constraint_expr())
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after type constraint")
|
|
||||||
|
|
||||||
return TypeExpr(name=name, constraints=constraints)
|
|
||||||
|
|
||||||
def constraint_expr(self) -> ConstraintExpr:
|
|
||||||
"""Parse a type constraint
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConstraintExpr: the parsed type constraint expression
|
|
||||||
"""
|
|
||||||
|
|
||||||
left: Expr = self.constraint_value()
|
|
||||||
op: Token = self.constraint_operator()
|
|
||||||
right: Expr = self.constraint_value()
|
|
||||||
return ConstraintExpr(left=left, op=op, right=right)
|
|
||||||
|
|
||||||
def constraint_value(self) -> Expr:
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
|
||||||
return WildcardExpr(self.previous())
|
|
||||||
return self.literal()
|
|
||||||
|
|
||||||
def literal(self) -> LiteralExpr:
|
|
||||||
if self.match(TokenType.FALSE):
|
|
||||||
return LiteralExpr(False)
|
|
||||||
if self.match(TokenType.TRUE):
|
|
||||||
return LiteralExpr(True)
|
|
||||||
if self.match(TokenType.NONE):
|
|
||||||
return LiteralExpr(None)
|
|
||||||
|
|
||||||
if self.match(TokenType.NUMBER):
|
|
||||||
return LiteralExpr(self.previous().value)
|
|
||||||
|
|
||||||
raise self.error(self.peek(), "Expected literal")
|
|
||||||
|
|
||||||
def constraint_operator(self) -> Token:
|
|
||||||
if self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.EQUAL_EQUAL, TokenType.BANG_EQUAL):
|
|
||||||
return self.previous()
|
|
||||||
raise self.error(self.peek(), "Expected constraint operator")
|
|
||||||
|
|
||||||
def schema(self) -> SchemaExpr:
|
|
||||||
"""Parse a schema definition
|
|
||||||
|
|
||||||
A comma separated list of schema elements
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SchemaExpr: the parsed schema expression
|
|
||||||
"""
|
|
||||||
left: Token = self.previous()
|
|
||||||
elements: list[Expr] = []
|
|
||||||
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
|
||||||
elements.append(self.schema_element())
|
|
||||||
if not self.check(TokenType.RIGHT_BRACKET):
|
|
||||||
self.consume(TokenType.COMMA, "Expected ',' between schema elements")
|
|
||||||
|
|
||||||
right: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed schema")
|
|
||||||
return SchemaExpr(left=left, elements=elements, right=right)
|
|
||||||
|
|
||||||
def schema_element(self) -> SchemaElementExpr:
|
|
||||||
"""Parse a schema element
|
|
||||||
|
|
||||||
An anonymous element (`_`), a type, an untyped named column (`name: _`),
|
|
||||||
or a named column (`name: Type`)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SchemaElementExpr: the parsed schema element expression
|
|
||||||
"""
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
|
||||||
return SchemaElementExpr(name=None, type=None)
|
|
||||||
|
|
||||||
if not self.check(TokenType.IDENTIFIER):
|
|
||||||
raise self.error(self.peek(), "Expected schema element")
|
|
||||||
|
|
||||||
name: Optional[Token] = None
|
|
||||||
type: Optional[TypeExpr] = None
|
|
||||||
if self.check_next(TokenType.COLON):
|
|
||||||
name = self.advance()
|
|
||||||
self.advance()
|
|
||||||
if not self.match(TokenType.UNDERSCORE):
|
|
||||||
type = self.type_expr()
|
|
||||||
return SchemaElementExpr(name=name, type=type)
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lexer.annotations import AnnotationLexer
|
|
||||||
from lexer.token import Token, TokenType
|
|
||||||
|
|
||||||
|
|
||||||
def scan(source: str) -> list[Token]:
|
|
||||||
return AnnotationLexer(source).process()
|
|
||||||
|
|
||||||
|
|
||||||
def assert_n_tokens(tokens: list[Token], n: int):
|
|
||||||
assert len(tokens) == n + 1
|
|
||||||
assert tokens[-1].type == TokenType.EOF
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("(", TokenType.LEFT_PAREN),
|
|
||||||
(")", TokenType.RIGHT_PAREN),
|
|
||||||
("[", TokenType.LEFT_BRACKET),
|
|
||||||
("]", TokenType.RIGHT_BRACKET),
|
|
||||||
(":", TokenType.COLON),
|
|
||||||
(",", TokenType.COMMA),
|
|
||||||
("_", TokenType.UNDERSCORE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_punctuation(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("+", TokenType.PLUS),
|
|
||||||
(">", TokenType.GREATER),
|
|
||||||
(">=", TokenType.GREATER_EQUAL),
|
|
||||||
("<", TokenType.LESS),
|
|
||||||
("<=", TokenType.LESS_EQUAL),
|
|
||||||
("=", TokenType.EQUAL),
|
|
||||||
("==", TokenType.EQUAL_EQUAL),
|
|
||||||
("!=", TokenType.BANG_EQUAL),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_operators(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("a", TokenType.IDENTIFIER),
|
|
||||||
("foo", TokenType.IDENTIFIER),
|
|
||||||
("foo1", TokenType.IDENTIFIER),
|
|
||||||
("foo_", TokenType.IDENTIFIER),
|
|
||||||
("foo_bar1_baz2", TokenType.IDENTIFIER),
|
|
||||||
("FOO_BAR1_BAZ2", TokenType.IDENTIFIER),
|
|
||||||
("True", TokenType.TRUE),
|
|
||||||
("False", TokenType.FALSE),
|
|
||||||
("None", TokenType.NONE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_identifiers_keywords(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("#", TokenType.COMMENT),
|
|
||||||
("# This is a comment", TokenType.COMMENT),
|
|
||||||
(" ", TokenType.WHITESPACE),
|
|
||||||
("\t", TokenType.WHITESPACE),
|
|
||||||
("\r", TokenType.WHITESPACE),
|
|
||||||
(" \t \t", TokenType.WHITESPACE),
|
|
||||||
("\n", TokenType.NEWLINE),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_misc(src: str, expected: TokenType):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected_type,expected_value",
|
|
||||||
[
|
|
||||||
("0", TokenType.NUMBER, 0),
|
|
||||||
("0.0", TokenType.NUMBER, 0),
|
|
||||||
("1234.56", TokenType.NUMBER, 1234.56),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_literals(src: str, expected_type: TokenType, expected_value: Any):
|
|
||||||
tokens: list[Token] = scan(src)
|
|
||||||
assert_n_tokens(tokens, 1)
|
|
||||||
assert tokens[0].type == expected_type
|
|
||||||
assert tokens[0].value == expected_value
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_bang_error():
|
|
||||||
with pytest.raises(SyntaxError):
|
|
||||||
scan("!")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src",
|
|
||||||
[
|
|
||||||
"-",
|
|
||||||
"*",
|
|
||||||
"/",
|
|
||||||
"{",
|
|
||||||
"}",
|
|
||||||
"@",
|
|
||||||
'"',
|
|
||||||
"'",
|
|
||||||
".",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_unexpected_character(src: str):
|
|
||||||
with pytest.raises(SyntaxError):
|
|
||||||
scan(src)
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.ast.annotations import (
|
|
||||||
AnnotationStmt,
|
|
||||||
ConstraintExpr,
|
|
||||||
Expr,
|
|
||||||
LiteralExpr,
|
|
||||||
SchemaElementExpr,
|
|
||||||
SchemaExpr,
|
|
||||||
Stmt,
|
|
||||||
TypeExpr,
|
|
||||||
WildcardExpr,
|
|
||||||
)
|
|
||||||
from lexer.annotations import AnnotationLexer
|
|
||||||
from lexer.position import Position
|
|
||||||
from lexer.token import Token
|
|
||||||
from parser.annotations import AnnotationParser
|
|
||||||
|
|
||||||
|
|
||||||
class AstSerializer(Stmt.Visitor[str], Expr.Visitor[str]):
|
|
||||||
def serialize(self, stmt: Stmt):
|
|
||||||
return stmt.accept(self)
|
|
||||||
|
|
||||||
def visit_annotation_stmt(self, stmt: AnnotationStmt) -> str:
|
|
||||||
schema: str = ""
|
|
||||||
if stmt.schema is not None:
|
|
||||||
schema = " " + stmt.schema.accept(self)
|
|
||||||
return f"(annotation {stmt.name.lexeme}{schema})"
|
|
||||||
|
|
||||||
def visit_schema_expr(self, expr: SchemaExpr) -> str:
|
|
||||||
elements: list[str] = [elmt.accept(self) for elmt in expr.elements]
|
|
||||||
return f"(schema {' '.join(elements)})"
|
|
||||||
|
|
||||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
|
|
||||||
name: str = expr.name.lexeme if expr.name is not None else "_"
|
|
||||||
type: str = expr.type.accept(self) if expr.type is not None else "_"
|
|
||||||
return f"({name} {type})"
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
|
||||||
res: str = f"({expr.name.lexeme}"
|
|
||||||
for constraint in expr.constraints:
|
|
||||||
res += " " + constraint.accept(self)
|
|
||||||
res += ")"
|
|
||||||
return res
|
|
||||||
|
|
||||||
def visit_constraint_expr(self, expr: ConstraintExpr) -> str:
|
|
||||||
return f"(constraint {expr.left.accept(self)} {expr.op.lexeme} {expr.right.accept(self)})"
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> str:
|
|
||||||
return "(_)"
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: LiteralExpr) -> str:
|
|
||||||
return f"({expr.value})"
|
|
||||||
|
|
||||||
|
|
||||||
def parse(source: str) -> Optional[Stmt]:
|
|
||||||
tokens: list[Token] = AnnotationLexer(source).process()
|
|
||||||
return AnnotationParser(tokens).parse()
|
|
||||||
|
|
||||||
|
|
||||||
def must_parse(source: str) -> Stmt:
|
|
||||||
stmt: Optional[Stmt] = parse(source)
|
|
||||||
assert stmt is not None
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def ast_str(source: str) -> str:
|
|
||||||
stmt: Stmt = must_parse(source)
|
|
||||||
return AstSerializer().serialize(stmt)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,expected",
|
|
||||||
[
|
|
||||||
("Type", "(annotation Type)"),
|
|
||||||
("Type[]", "(annotation Type (schema ))"),
|
|
||||||
(
|
|
||||||
"""
|
|
||||||
Frame[
|
|
||||||
verified: bool,
|
|
||||||
birth_year: int,
|
|
||||||
height: float + ( _ > 0 ) + ( _ < 250 ),
|
|
||||||
name: str,
|
|
||||||
date: datetime,
|
|
||||||
float, # unnamed
|
|
||||||
unknown: _, # untyped
|
|
||||||
_ # unnamed and untyped
|
|
||||||
]
|
|
||||||
""",
|
|
||||||
"(annotation Frame (schema (verified (bool)) (birth_year (int)) (height (float (constraint (_) > (0.0)) (constraint (_) < (250.0)))) (name (str)) (date (datetime)) (_ (float)) (unknown _) (_ _)))",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_expressions(src: str, expected: str):
|
|
||||||
assert ast_str(src) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"src,pos,should_fail",
|
|
||||||
[
|
|
||||||
("", (1, 1), True),
|
|
||||||
("42", (1, 1), True),
|
|
||||||
("True", (1, 1), True),
|
|
||||||
("Type[", (1, 6), True),
|
|
||||||
("Type[] Type2", (1, 8), False),
|
|
||||||
("Type[bool:]", (1, 11), True),
|
|
||||||
("Type[3]", (1, 6), True),
|
|
||||||
("Type[bool float]", (1, 11), True),
|
|
||||||
("Type[bool (_ < 2)]", (1, 11), True),
|
|
||||||
("Type[bool + _ < 2)]", (1, 13), True),
|
|
||||||
("Type[bool + (_ < 2]", (1, 19), True),
|
|
||||||
("Type[bool + (< 2)]", (1, 14), True),
|
|
||||||
("Type[bool + (_ + 2)]", (1, 16), True),
|
|
||||||
("Type[bool + (Foo + Bar)]", (1, 14), True),
|
|
||||||
# ("Type[bool,]", (1, 11), True), # trailing comma is accepted, TODO: update parser or EBNF
|
|
||||||
("Type[bool, Type[]]", (1, 16), True),
|
|
||||||
("Type[foo: 3]", (1, 11), True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parsing_error(src: str, pos: tuple[int, int], should_fail: bool):
|
|
||||||
tokens: list[Token] = AnnotationLexer(src).process()
|
|
||||||
parser: AnnotationParser = AnnotationParser(tokens)
|
|
||||||
stmt: Optional[Stmt] = parser.parse()
|
|
||||||
if should_fail:
|
|
||||||
assert stmt is None
|
|
||||||
assert len(parser.errors) != 0
|
|
||||||
error_pos: Position = parser.errors[0].token.position
|
|
||||||
assert (error_pos.line, error_pos.column) == pos
|
|
||||||
Reference in New Issue
Block a user