diff --git a/core/ast/printer.py b/core/ast/printer.py index c624637..4394834 100644 --- a/core/ast/printer.py +++ b/core/ast/printer.py @@ -1,9 +1,11 @@ +from __future__ import annotations + from contextlib import contextmanager from enum import Enum, auto import io -from typing import Generator, Optional +from typing import Generator, Generic, Optional, Protocol, TypeVar -from core.ast.annotations import Expr, TypeExpr, SchemaExpr, SchemaElementExpr +import core.ast.annotations as a class _Level(Enum): @@ -12,7 +14,14 @@ class _Level(Enum): LAST = auto() -class AnnotationAstPrinter(Expr.Visitor[None]): +class Expr(Protocol): + def accept(self, printer: AstPrinter) -> None: ... + + +T = TypeVar("T", bound=Expr) + + +class AstPrinter(Generic[T]): LAST_CHILD = "└── " CHILD = "├── " VERTICAL = "│ " @@ -23,7 +32,7 @@ class AnnotationAstPrinter(Expr.Visitor[None]): self._idx: Optional[int] = None self._buf: io.StringIO = io.StringIO() - def print(self, expr: Expr): + def print(self, expr: T): self._buf = io.StringIO() expr.accept(self) return self._buf.getvalue() @@ -60,7 +69,7 @@ class AnnotationAstPrinter(Expr.Visitor[None]): return "".join(parts) def _write_optional_child( - self, label: str, child: Optional[Expr], *, last: bool = False + self, label: str, child: Optional[T], *, last: bool = False ): if last: self._mark_last() @@ -71,13 +80,15 @@ class AnnotationAstPrinter(Expr.Visitor[None]): with self._child_level(last=True): child.accept(self) - def visit_type_expr(self, expr: TypeExpr): + +class AnnotationAstPrinter(AstPrinter, a.Expr.Visitor[None]): + def visit_type_expr(self, expr: a.TypeExpr): self._write_line("TypeExpr") with self._child_level(): self._write_line(f'name: "{expr.name.lexeme}"') self._write_optional_child("schema", expr.schema, last=True) - def visit_schema_expr(self, expr: SchemaExpr): + def visit_schema_expr(self, expr: a.SchemaExpr): self._write_line("SchemaExpr") with self._child_level(): for i, elmt in enumerate(expr.elements): @@ -86,7 +97,7 @@ class AnnotationAstPrinter(Expr.Visitor[None]): self._mark_last() elmt.accept(self) - def visit_schema_element_expr(self, expr: SchemaElementExpr): + def visit_schema_element_expr(self, expr: a.SchemaElementExpr): self._write_line("SchemaElementExpr") with self._child_level(): name_text: str = "None" if expr.name is None else f'"{expr.name.lexeme}"' @@ -94,23 +105,23 @@ class AnnotationAstPrinter(Expr.Visitor[None]): self._write_optional_child("type", expr.type, last=True) -class AnnotationPrinter(Expr.Visitor[str]): - def print(self, expr: Expr): +class AnnotationPrinter(a.Expr.Visitor[str]): + def print(self, expr: a.Expr): return expr.accept(self) - def visit_type_expr(self, expr: TypeExpr) -> str: + def visit_type_expr(self, expr: a.TypeExpr) -> str: schema: str = "" if expr.schema is not None: schema = expr.schema.accept(self) return f"{expr.name.lexeme}{schema}" - def visit_schema_expr(self, expr: SchemaExpr) -> str: + def visit_schema_expr(self, expr: a.SchemaExpr) -> str: res: str = expr.left.lexeme res += ", ".join(elmt.accept(self) for elmt in expr.elements) res += expr.right.lexeme return res - def visit_schema_element_expr(self, expr: SchemaElementExpr) -> str: + def visit_schema_element_expr(self, expr: a.SchemaElementExpr) -> str: parts: list[str] = [] if expr.name is not None: parts.append(expr.name.lexeme)