diff --git a/core/ast/printer.py b/core/ast/printer.py index c9d6841..c624637 100644 --- a/core/ast/printer.py +++ b/core/ast/printer.py @@ -1,84 +1,97 @@ -from typing import Optional +from contextlib import contextmanager +from enum import Enum, auto +import io +from typing import Generator, Optional from core.ast.annotations import Expr, TypeExpr, SchemaExpr, SchemaElementExpr -class AnnotationAstPrinter(Expr.Visitor[str]): +class _Level(Enum): + EMPTY = auto() + ACTIVE = auto() + LAST = auto() + + +class AnnotationAstPrinter(Expr.Visitor[None]): LAST_CHILD = "└── " CHILD = "├── " VERTICAL = "│ " EMPTY = " " def __init__(self): - self.level: int = 0 - self.idx: Optional[int] = None - self.last: bool = False - self.levels: list[int] = [] + self._levels: list[_Level] = [] + self._idx: Optional[int] = None + self._buf: io.StringIO = io.StringIO() def print(self, expr: Expr): - return expr.accept(self) + self._buf = io.StringIO() + expr.accept(self) + return self._buf.getvalue() - def print_line(self, text: str) -> str: - indent: str = "" - for enabled in self.levels[:-1]: - if enabled: - indent += self.VERTICAL + @contextmanager + def _child_level(self, last: bool = False) -> Generator[None, None, None]: + self._levels.append(_Level.LAST if last else _Level.ACTIVE) + try: + yield + finally: + self._levels.pop() + + def _mark_last(self): + if self._levels: + self._levels[-1] = _Level.LAST + + def _write_line(self, text: str): + indent: str = self._build_indent() + if self._idx is not None: + text = f"[{self._idx}] {text}" + self._idx = None + self._buf.write(indent + text + "\n") + + def _build_indent(self) -> str: + parts: list[str] = [] + for level in self._levels[:-1]: + parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL) + if self._levels: + if self._levels[-1] == _Level.LAST: + parts.append(self.LAST_CHILD) + self._levels[-1] = _Level.EMPTY else: - indent += self.EMPTY + parts.append(self.CHILD) + return "".join(parts) - if len(self.levels) > 0: - if self.levels[-1] == 2: - indent += self.LAST_CHILD - self.levels[-1] = 0 - else: - indent += self.CHILD - if self.idx is not None: - text = f"[{self.idx}] {text}" - self.idx = None - return indent + text + "\n" - - def visit_type_expr(self, expr: TypeExpr) -> str: - res: str = self.print_line("TypeExpr") - self.levels.append(1) - res += self.print_line(f'name: "{expr.name.lexeme}"') - self.levels[-1] = 2 - if expr.schema is None: - res += self.print_line("schema: None") + def _write_optional_child( + self, label: str, child: Optional[Expr], *, last: bool = False + ): + if last: + self._mark_last() + if child is None: + self._write_line(f"{label}: None") else: - res += self.print_line("schema") - self.levels.append(2) - res += expr.schema.accept(self) - self.levels.pop() - self.levels.pop() - return res + self._write_line(label) + with self._child_level(last=True): + child.accept(self) - def visit_schema_expr(self, expr: SchemaExpr) -> str: - res: str = self.print_line("SchemaExpr") - self.levels.append(1) - for i, elmt in enumerate(expr.elements): - self.idx = i - if i == len(expr.elements) - 1: - self.levels[-1] = 2 - res += elmt.accept(self) - self.levels.pop() - return res + def visit_type_expr(self, expr: TypeExpr): + self._write_line("TypeExpr") + with self._child_level(): + self._write_line(f'name: "{expr.name.lexeme}"') + self._write_optional_child("schema", expr.schema, last=True) - def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str: - res: str = self.print_line("SchemaElementExpr") - self.levels.append(1) - res += self.print_line( - "name: " + ("None" if expr.name is None else f'"{expr.name.lexeme}"') - ) - self.levels[-1] = 2 - if expr.type is None: - res += self.print_line("type: None") - else: - res += self.print_line("type") - self.levels.append(2) - res += expr.type.accept(self) - self.levels.pop() - self.levels.pop() - return res + def visit_schema_expr(self, expr: SchemaExpr): + self._write_line("SchemaExpr") + with self._child_level(): + for i, elmt in enumerate(expr.elements): + self._idx = i + if i == len(expr.elements) - 1: + self._mark_last() + elmt.accept(self) + + def visit_schema_element_expr(self, expr: SchemaElementExpr): + self._write_line("SchemaElementExpr") + with self._child_level(): + name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"' + self._write_line(f"name: {name_text}") + self._write_optional_child("type", expr.type, last=True) class AnnotationPrinter(Expr.Visitor[str]):