refactor(parser): improve AST printer
refactored the messy AST printer impletation with Claude to use a context manager, an enum and extract common functions Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,84 +1,97 @@
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
import io
|
||||
from typing import Generator, Optional
|
||||
|
||||
from core.ast.annotations import Expr, TypeExpr, SchemaExpr, SchemaElementExpr
|
||||
|
||||
|
||||
class AnnotationAstPrinter(Expr.Visitor[str]):
|
||||
class _Level(Enum):
|
||||
EMPTY = auto()
|
||||
ACTIVE = auto()
|
||||
LAST = auto()
|
||||
|
||||
|
||||
class AnnotationAstPrinter(Expr.Visitor[None]):
|
||||
LAST_CHILD = "└── "
|
||||
CHILD = "├── "
|
||||
VERTICAL = "│ "
|
||||
EMPTY = " "
|
||||
|
||||
def __init__(self):
|
||||
self.level: int = 0
|
||||
self.idx: Optional[int] = None
|
||||
self.last: bool = False
|
||||
self.levels: list[int] = []
|
||||
self._levels: list[_Level] = []
|
||||
self._idx: Optional[int] = None
|
||||
self._buf: io.StringIO = io.StringIO()
|
||||
|
||||
def print(self, expr: Expr):
|
||||
return expr.accept(self)
|
||||
self._buf = io.StringIO()
|
||||
expr.accept(self)
|
||||
return self._buf.getvalue()
|
||||
|
||||
def print_line(self, text: str) -> str:
|
||||
indent: str = ""
|
||||
for enabled in self.levels[:-1]:
|
||||
if enabled:
|
||||
indent += self.VERTICAL
|
||||
@contextmanager
|
||||
def _child_level(self, last: bool = False) -> Generator[None, None, None]:
|
||||
self._levels.append(_Level.LAST if last else _Level.ACTIVE)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._levels.pop()
|
||||
|
||||
def _mark_last(self):
|
||||
if self._levels:
|
||||
self._levels[-1] = _Level.LAST
|
||||
|
||||
def _write_line(self, text: str):
|
||||
indent: str = self._build_indent()
|
||||
if self._idx is not None:
|
||||
text = f"[{self._idx}] {text}"
|
||||
self._idx = None
|
||||
self._buf.write(indent + text + "\n")
|
||||
|
||||
def _build_indent(self) -> str:
|
||||
parts: list[str] = []
|
||||
for level in self._levels[:-1]:
|
||||
parts.append(self.EMPTY if level == _Level.EMPTY else self.VERTICAL)
|
||||
if self._levels:
|
||||
if self._levels[-1] == _Level.LAST:
|
||||
parts.append(self.LAST_CHILD)
|
||||
self._levels[-1] = _Level.EMPTY
|
||||
else:
|
||||
indent += self.EMPTY
|
||||
parts.append(self.CHILD)
|
||||
return "".join(parts)
|
||||
|
||||
if len(self.levels) > 0:
|
||||
if self.levels[-1] == 2:
|
||||
indent += self.LAST_CHILD
|
||||
self.levels[-1] = 0
|
||||
def _write_optional_child(
|
||||
self, label: str, child: Optional[Expr], *, last: bool = False
|
||||
):
|
||||
if last:
|
||||
self._mark_last()
|
||||
if child is None:
|
||||
self._write_line(f"{label}: None")
|
||||
else:
|
||||
indent += self.CHILD
|
||||
if self.idx is not None:
|
||||
text = f"[{self.idx}] {text}"
|
||||
self.idx = None
|
||||
return indent + text + "\n"
|
||||
self._write_line(label)
|
||||
with self._child_level(last=True):
|
||||
child.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> str:
|
||||
res: str = self.print_line("TypeExpr")
|
||||
self.levels.append(1)
|
||||
res += self.print_line(f'name: "{expr.name.lexeme}"')
|
||||
self.levels[-1] = 2
|
||||
if expr.schema is None:
|
||||
res += self.print_line("schema: None")
|
||||
else:
|
||||
res += self.print_line("schema")
|
||||
self.levels.append(2)
|
||||
res += expr.schema.accept(self)
|
||||
self.levels.pop()
|
||||
self.levels.pop()
|
||||
return res
|
||||
def visit_type_expr(self, expr: TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_optional_child("schema", expr.schema, last=True)
|
||||
|
||||
def visit_schema_expr(self, expr: SchemaExpr) -> str:
|
||||
res: str = self.print_line("SchemaExpr")
|
||||
self.levels.append(1)
|
||||
def visit_schema_expr(self, expr: SchemaExpr):
|
||||
self._write_line("SchemaExpr")
|
||||
with self._child_level():
|
||||
for i, elmt in enumerate(expr.elements):
|
||||
self.idx = i
|
||||
self._idx = i
|
||||
if i == len(expr.elements) - 1:
|
||||
self.levels[-1] = 2
|
||||
res += elmt.accept(self)
|
||||
self.levels.pop()
|
||||
return res
|
||||
self._mark_last()
|
||||
elmt.accept(self)
|
||||
|
||||
def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str:
|
||||
res: str = self.print_line("SchemaElementExpr")
|
||||
self.levels.append(1)
|
||||
res += self.print_line(
|
||||
"name: " + ("None" if expr.name is None else f'"{expr.name.lexeme}"')
|
||||
)
|
||||
self.levels[-1] = 2
|
||||
if expr.type is None:
|
||||
res += self.print_line("type: None")
|
||||
else:
|
||||
res += self.print_line("type")
|
||||
self.levels.append(2)
|
||||
res += expr.type.accept(self)
|
||||
self.levels.pop()
|
||||
self.levels.pop()
|
||||
return res
|
||||
def visit_schema_element_expr(self, expr: SchemaElementExpr):
|
||||
self._write_line("SchemaElementExpr")
|
||||
with self._child_level():
|
||||
name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"'
|
||||
self._write_line(f"name: {name_text}")
|
||||
self._write_optional_child("type", expr.type, last=True)
|
||||
|
||||
|
||||
class AnnotationPrinter(Expr.Visitor[str]):
|
||||
|
||||
Reference in New Issue
Block a user