188 lines
6.6 KiB
Python
188 lines
6.6 KiB
Python
from enum import Enum, auto
|
|
from typing import Any
|
|
|
|
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
|
|
CallExpr, SetExpr, GetExpr, ThisExpr, SuperExpr, FStringExpr, FStringEmbedExpr
|
|
from src.ast.stmt import Stmt, LetStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt, \
|
|
ReturnStmt, BreakStmt, ContinueStmt, ClassStmt
|
|
|
|
|
|
class ClassType(Enum):
    """Tracks whether the formatter is currently emitting a class body.

    Values are spelled out explicitly; they match what ``auto()`` would
    assign (counting from 1 in declaration order).
    """

    # Not inside a class body (top level, or inside a function body).
    NONE = 1
    # Directly inside a class body: functions here are printed as methods.
    CLASS = 2
|
|
|
|
|
|
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
    """Pretty-print expression and statement ASTs back to source text.

    Implements both the expression and the statement visitor interfaces.
    Expression visitors return bare inline text; statement visitors return
    one or more lines, each prefixed with the current indentation and ending
    in a newline.
    """

    def __init__(self, indent: int = 4):
        # Number of spaces emitted per nesting level.
        self.indent: int = indent
        # Current nesting depth; raised while formatting block/class/function bodies.
        self.level: int = 0
        # Whether we are directly inside a class body (methods omit ``fun``).
        self.current_class: ClassType = ClassType.NONE

    def indented(self, text: str) -> str:
        """Return *text* prefixed with spaces for the current nesting level."""
        return " " * (self.level * self.indent) + text

    def print(self, statements: list[Stmt]) -> str:
        """Format a list of top-level statements, starting from nesting level 0."""
        self.level = 0
        return "".join(self.format(stmt) for stmt in statements)

    def format(self, obj: Expr | Stmt) -> str:
        """Dispatch *obj* to its matching ``visit_*`` method and return the text."""
        return obj.accept(self)

    def _format_inline(self, stmt: Stmt) -> str:
        """Format a statement that continues a header line (if/while/for bodies).

        The nested visitor indents its first line for the current level, but
        the header line already carries that indentation, so leading
        whitespace is stripped. Trailing newlines are stripped too, so the
        caller decides where the line ends.
        """
        return self.format(stmt).lstrip().rstrip("\n")

    # ----------------------------------------------------------------- exprs

    def visit_assign_expr(self, expr: AssignExpr) -> str:
        """Render ``name = value``."""
        return f"{expr.name.lexeme} = {self.format(expr.value)}"

    def visit_binary_expr(self, expr: BinaryExpr) -> str:
        """Render ``left op right``; explicit parentheses come from GroupingExpr nodes."""
        return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}"

    def visit_logical_expr(self, expr: LogicalExpr) -> str:
        """Render ``left op right`` for logical operators."""
        return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}"

    def visit_set_expr(self, expr: SetExpr) -> str:
        """Render a property assignment ``object.name = value``."""
        return f"{self.format(expr.object)}.{expr.name.lexeme} = {self.format(expr.value)}"

    def visit_this_expr(self, expr: ThisExpr) -> str:
        """Render the ``this`` keyword as written in the source token."""
        return expr.keyword.lexeme

    def visit_super_expr(self, expr: SuperExpr) -> str:
        """Render ``super.method``."""
        return f"{expr.keyword.lexeme}.{expr.method.lexeme}"

    def visit_unary_expr(self, expr: UnaryExpr) -> str:
        """Render a prefix operator directly followed by its operand."""
        return f"{expr.operator.lexeme}{self.format(expr.right)}"

    def visit_call_expr(self, expr: CallExpr) -> str:
        """Render ``callee(arg1, arg2, ...)``."""
        args = ", ".join(self.format(arg) for arg in expr.arguments)
        return f"{self.format(expr.callee)}({args})"

    def visit_get_expr(self, expr: GetExpr) -> str:
        """Render a property access ``object.name``."""
        return f"{self.format(expr.object)}.{expr.name.lexeme}"

    def visit_grouping_expr(self, expr: GroupingExpr) -> str:
        """Render a parenthesised sub-expression."""
        return f"({self.format(expr.expression)})"

    def visit_literal_expr(self, expr: LiteralExpr) -> str:
        """Render a literal value in source spelling.

        Integral floats drop the ``.0`` suffix, booleans become
        ``true``/``false``, ``None`` becomes ``null`` and strings are wrapped
        in double quotes. Any other value falls back to ``str()``.
        """
        value: Any = expr.value
        if isinstance(value, float):
            # Print 3.0 as "3"; non-integral floats keep Python's repr.
            return str(int(value)) if value.is_integer() else str(value)
        if value is False:
            return "false"
        if value is True:
            return "true"
        if value is None:
            return "null"
        if isinstance(value, str):
            # NOTE(review): embedded quotes/backslashes are emitted verbatim,
            # not escaped — confirm against the lexer's escape rules.
            return f'"{value}"'
        return str(value)

    def visit_fstring_expr(self, expr: FStringExpr) -> str:
        """Render an f-string: literal parts verbatim, embedded parts via their visitor."""
        body = "".join(
            self.format(part) if isinstance(part, FStringEmbedExpr) else part.value
            for part in expr.parts
        )
        return f'f"{body}"'

    def visit_fstring_embed_expr(self, expr: FStringEmbedExpr) -> str:
        """Render an embedded f-string expression wrapped in braces."""
        return "{" + self.format(expr.expression) + "}"

    def visit_variable_expr(self, expr: VariableExpr) -> str:
        """Render a variable reference by its name."""
        return expr.name.lexeme

    # ----------------------------------------------------------------- stmts

    def visit_block_stmt(self, stmt: BlockStmt) -> str:
        """Render ``{ ... }`` with the contained statements one level deeper."""
        res: str = self.indented("{\n")
        self.level += 1
        res += "".join(self.format(sub_stmt) for sub_stmt in stmt.statements)
        self.level -= 1
        res += self.indented("}\n")
        return res

    def visit_class_stmt(self, stmt: ClassStmt) -> str:
        """Render a class declaration, its optional ``< Superclass``, and its methods."""
        header: str = f"class {stmt.name.lexeme} "
        if stmt.superclass is not None:
            header += f"< {stmt.superclass.name.lexeme} "
        res: str = self.indented(header + "{\n")

        self.level += 1
        enclosing_class: ClassType = self.current_class
        # Functions formatted while this is set are printed as methods (no ``fun``).
        self.current_class = ClassType.CLASS
        res += "".join(self.format(method) for method in stmt.methods)
        self.current_class = enclosing_class
        self.level -= 1

        res += self.indented("}\n")
        return res

    def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
        """Render an expression statement on its own indented line."""
        return self.indented(self.format(stmt.expression) + "\n")

    def visit_function_stmt(self, stmt: FunctionStmt) -> str:
        """Render a function declaration.

        Directly inside a class body the ``fun`` keyword is omitted (method
        syntax). The class context is reset while the body is formatted, so
        functions nested inside a method get ``fun`` again.
        """
        res: str = self.indented("")
        if self.current_class is not ClassType.CLASS:
            res += "fun "
        params = ", ".join(param.lexeme for param in stmt.params)
        res += f"{stmt.name.lexeme}({params}) {{\n"

        self.level += 1
        enclosing_class: ClassType = self.current_class
        self.current_class = ClassType.NONE
        res += "".join(self.format(sub_stmt) for sub_stmt in stmt.body)
        self.current_class = enclosing_class
        self.level -= 1

        res += self.indented("}\n")
        return res

    def visit_if_stmt(self, stmt: IfStmt) -> str:
        """Render ``if cond <then> [else <else>]`` with branches kept on the header line."""
        res: str = self.indented(f"if {self.format(stmt.condition)} {self._format_inline(stmt.then_branch)}")
        if stmt.else_branch is not None:
            res += f" else {self._format_inline(stmt.else_branch)}"
        return res + "\n"

    def visit_return_stmt(self, stmt: ReturnStmt) -> str:
        """Render ``return`` with its optional value."""
        res: str = self.indented("return")
        if stmt.value is not None:
            res += " " + self.format(stmt.value).rstrip("\n")
        return res + "\n"

    def visit_while_stmt(self, stmt: WhileStmt) -> str:
        """Render ``while cond <body>`` with the body starting on the header line."""
        # The body must be lstripped (via _format_inline), otherwise a nested
        # block's own indentation lands between the condition and ``{``.
        return self.indented(f"while {self.format(stmt.condition)} {self._format_inline(stmt.body)}\n")

    def visit_for_stmt(self, stmt: ForStmt) -> str:
        """Render ``for var [start-clause] [end-clause] [step-clause] <body>``.

        Each optional clause is printed using its keyword token's lexeme
        followed by the formatted expression.
        """
        res: str = self.indented(f"for {stmt.variable.lexeme} ")
        if stmt.start is not None:
            res += f"{stmt.start_token.lexeme} {self.format(stmt.start)} "
        if stmt.end is not None:
            res += f"{stmt.end_token.lexeme} {self.format(stmt.end)} "
        if stmt.step is not None:
            res += f"{stmt.step_token.lexeme} {self.format(stmt.step)} "
        # Same as while: strip the body's own indentation before ``{``.
        res += self._format_inline(stmt.body) + "\n"
        return res

    def visit_let_stmt(self, stmt: LetStmt) -> str:
        """Render ``let name [= initializer]``."""
        res: str = self.indented(f"let {stmt.name.lexeme}")
        if stmt.initializer is not None:
            res += f" = {self.format(stmt.initializer)}"
        return res + "\n"

    def visit_break_stmt(self, stmt: BreakStmt) -> str:
        """Render the ``break`` keyword on its own line."""
        return self.indented(f"{stmt.keyword.lexeme}\n")

    def visit_continue_stmt(self, stmt: ContinueStmt) -> str:
        """Render the ``continue`` keyword on its own line."""
        return self.indented(f"{stmt.keyword.lexeme}\n")
|