refactor(ast): move expression visitor class inside Expr

This commit is contained in:
2026-02-05 16:30:30 +01:00
parent 9c2a1fb908
commit 7adcd3d93d
3 changed files with 21 additions and 22 deletions

View File

@@ -16,23 +16,22 @@ class Expr(ABC):
def accept(self, visitor: Visitor[T]) -> T: def accept(self, visitor: Visitor[T]) -> T:
... ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_binary_expr(self, expr: BinaryExpr) -> T:
...
class Visitor(ABC, Generic[T]): @abstractmethod
@abstractmethod def visit_unary_expr(self, expr: UnaryExpr) -> T:
def visit_binary_expr(self, expr: BinaryExpr) -> T: ...
...
@abstractmethod @abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: def visit_grouping_expr(self, expr: GroupingExpr) -> T:
... ...
@abstractmethod @abstractmethod
def visit_grouping_expr(self, expr: GroupingExpr) -> T: def visit_literal_expr(self, expr: LiteralExpr) -> T:
... ...
@abstractmethod
def visit_literal_expr(self, expr: LiteralExpr) -> T:
...
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -41,7 +40,7 @@ class BinaryExpr(Expr):
operator: Token operator: Token
right: Expr right: Expr
def accept(self, visitor: Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_binary_expr(self) return visitor.visit_binary_expr(self)
@@ -50,7 +49,7 @@ class UnaryExpr(Expr):
operator: Token operator: Token
right: Expr right: Expr
def accept(self, visitor: Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_unary_expr(self) return visitor.visit_unary_expr(self)
@@ -58,7 +57,7 @@ class UnaryExpr(Expr):
class GroupingExpr(Expr): class GroupingExpr(Expr):
expression: Expr expression: Expr
def accept(self, visitor: Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_grouping_expr(self) return visitor.visit_grouping_expr(self)
@@ -66,5 +65,5 @@ class GroupingExpr(Expr):
class LiteralExpr(Expr): class LiteralExpr(Expr):
value: Any value: Any
def accept(self, visitor: Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self) return visitor.visit_literal_expr(self)

View File

@@ -1,7 +1,7 @@
from src.ast.expr import Visitor, Expr, LiteralExpr, T, GroupingExpr, UnaryExpr, BinaryExpr from src.ast.expr import Expr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr
class AstPrinter(Visitor[str]): class AstPrinter(Expr.Visitor[str]):
def print(self, expr: Expr): def print(self, expr: Expr):
return expr.accept(self) return expr.accept(self)

View File

@@ -1,12 +1,12 @@
from typing import Any from typing import Any
from src.ast.expr import Visitor, LiteralExpr, T, GroupingExpr, UnaryExpr, BinaryExpr, Expr from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
from src.interpreter.error import PebbleRuntimeError from src.interpreter.error import PebbleRuntimeError
from src.pebble import Pebble from src.pebble import Pebble
from src.token import TokenType, Token from src.token import TokenType, Token
class Interpreter(Visitor[Any]): class Interpreter(Expr.Visitor[Any]):
def interpret(self, expr: Expr) -> Any: def interpret(self, expr: Expr) -> Any:
try: try:
return self.evaluate(expr) return self.evaluate(expr)