fix(parser): prepare printer for midas printer

This commit is contained in:
2026-05-14 02:00:52 +02:00
parent 6d885a0449
commit 4d25b43a4e

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
from contextlib import contextmanager
from enum import Enum, auto
import io
from typing import Generator, Optional
from typing import Generator, Generic, Optional, Protocol, TypeVar
from core.ast.annotations import Expr, TypeExpr, SchemaExpr, SchemaElementExpr
import core.ast.annotations as a
class _Level(Enum):
@@ -12,7 +14,14 @@ class _Level(Enum):
LAST = auto()
class AnnotationAstPrinter(Expr.Visitor[None]):
class Expr(Protocol):
def accept(self, printer: AstPrinter) -> None: ...
T = TypeVar("T", bound=Expr)
class AstPrinter(Generic[T]):
LAST_CHILD = "└── "
CHILD = "├── "
VERTICAL = ""
@@ -23,7 +32,7 @@ class AnnotationAstPrinter(Expr.Visitor[None]):
self._idx: Optional[int] = None
self._buf: io.StringIO = io.StringIO()
def print(self, expr: Expr):
def print(self, expr: T):
self._buf = io.StringIO()
expr.accept(self)
return self._buf.getvalue()
@@ -60,7 +69,7 @@ class AnnotationAstPrinter(Expr.Visitor[None]):
return "".join(parts)
def _write_optional_child(
self, label: str, child: Optional[Expr], *, last: bool = False
self, label: str, child: Optional[T], *, last: bool = False
):
if last:
self._mark_last()
@@ -71,13 +80,15 @@ class AnnotationAstPrinter(Expr.Visitor[None]):
with self._child_level(last=True):
child.accept(self)
def visit_type_expr(self, expr: TypeExpr):
class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None]):
def visit_type_expr(self, expr: a.TypeExpr):
self._write_line("TypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_optional_child("schema", expr.schema, last=True)
def visit_schema_expr(self, expr: SchemaExpr):
def visit_schema_expr(self, expr: a.SchemaExpr):
self._write_line("SchemaExpr")
with self._child_level():
for i, elmt in enumerate(expr.elements):
@@ -86,7 +97,7 @@ class AnnotationAstPrinter(Expr.Visitor[None]):
self._mark_last()
elmt.accept(self)
def visit_schema_element_expr(self, expr: SchemaElementExpr):
def visit_schema_element_expr(self, expr: a.SchemaElementExpr):
self._write_line("SchemaElementExpr")
with self._child_level():
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
@@ -94,23 +105,23 @@ class AnnotationAstPrinter(Expr.Visitor[None]):
self._write_optional_child("type", expr.type, last=True)
class AnnotationPrinter(Expr.Visitor[str]):
def print(self, expr: Expr):
class AnnotationPrinter(a.Expr.Visitor[str]):
def print(self, expr: a.Expr):
return expr.accept(self)
def visit_type_expr(self, expr: TypeExpr) -> str:
def visit_type_expr(self, expr: a.TypeExpr) -> str:
schema: str = ""
if expr.schema is not None:
schema = expr.schema.accept(self)
return f"{expr.name.lexeme}{schema}"
def visit_schema_expr(self, expr: SchemaExpr) -> str:
def visit_schema_expr(self, expr: a.SchemaExpr) -> str:
res: str = expr.left.lexeme
res += ", ".join(elmt.accept(self) for elmt in expr.elements)
res += expr.right.lexeme
return res
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str:
parts: list[str] = []
if expr.name is not None:
parts.append(expr.name.lexeme)