Files
pebble/src/formatter.py

188 lines
6.6 KiB
Python

from enum import Enum, auto
from typing import Any
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
CallExpr, SetExpr, GetExpr, ThisExpr, SuperExpr, FStringExpr, FStringEmbedExpr
from src.ast.stmt import Stmt, LetStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt, BreakStmt, ContinueStmt, ClassStmt
class ClassType(Enum):
    """Syntactic context the Formatter is currently in.

    NONE  -- not directly inside a class body (top level or a function body).
    CLASS -- formatting the direct members (methods) of a class body.
    """

    NONE = 1  # the value auto() assigns to the first member
    CLASS = 2
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
    """Pretty-print an AST back into source text.

    Every ``visit_*`` method returns a string. Statement visitors return text
    prefixed with the current indentation and terminated by a newline;
    expression visitors return unindented text with no trailing newline.
    """

    def __init__(self, indent: int = 4):
        # Number of spaces emitted per nesting level.
        self.indent: int = indent
        # Current nesting depth; raised while formatting block/class/function bodies.
        self.level: int = 0
        # CLASS while formatting the direct members of a class body; methods are
        # then printed without the ``fun`` keyword.
        self.current_class: ClassType = ClassType.NONE

    def indented(self, text: str) -> str:
        """Return *text* prefixed with the indentation for the current level."""
        return " " * (self.level * self.indent) + text

    def print(self, statements: list[Stmt]) -> str:
        """Format a whole program given as a list of top-level statements.

        Nesting state is reset first, so a previous call that raised midway
        cannot leak indentation or class context into this one.
        """
        self.level = 0
        self.current_class = ClassType.NONE
        return "".join(self.format(stmt) for stmt in statements)

    def format(self, obj: Expr | Stmt) -> str:
        """Dispatch *obj* to its matching ``visit_*`` method and return the text."""
        return obj.accept(self)

    def _inline(self, stmt: Stmt) -> str:
        """Format *stmt* for placement after a keyword on the same line.

        The statement's own leading indentation and trailing newlines are
        removed. A block body therefore opens with ``{`` directly on the header
        line, while its closing ``}`` keeps the current level's indentation.
        """
        return self.format(stmt).lstrip().rstrip("\n")

    # ----------------------------------------------------------------- exprs

    def visit_assign_expr(self, expr: AssignExpr) -> str:
        """Render ``name = value``."""
        return f"{expr.name.lexeme} = {self.format(expr.value)}"

    def visit_binary_expr(self, expr: BinaryExpr) -> str:
        """Render ``left op right``; parentheses come only from GroupingExpr nodes."""
        return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}"

    def visit_logical_expr(self, expr: LogicalExpr) -> str:
        """Render ``left op right`` for a logical operator."""
        return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}"

    def visit_set_expr(self, expr: SetExpr) -> str:
        """Render a property assignment ``object.name = value``."""
        return f"{self.format(expr.object)}.{expr.name.lexeme} = {self.format(expr.value)}"

    def visit_this_expr(self, expr: ThisExpr) -> str:
        """Render the ``this`` keyword as written in the source."""
        return expr.keyword.lexeme

    def visit_super_expr(self, expr: SuperExpr) -> str:
        """Render ``super.method``."""
        return f"{expr.keyword.lexeme}.{expr.method.lexeme}"

    def visit_unary_expr(self, expr: UnaryExpr) -> str:
        """Render a prefix operator applied directly (no space) to its operand."""
        return f"{expr.operator.lexeme}{self.format(expr.right)}"

    def visit_call_expr(self, expr: CallExpr) -> str:
        """Render ``callee(arg1, arg2, ...)``."""
        args: str = ", ".join(self.format(arg) for arg in expr.arguments)
        return f"{self.format(expr.callee)}({args})"

    def visit_get_expr(self, expr: GetExpr) -> str:
        """Render a property access ``object.name``."""
        return f"{self.format(expr.object)}.{expr.name.lexeme}"

    def visit_grouping_expr(self, expr: GroupingExpr) -> str:
        """Render a parenthesised sub-expression."""
        return f"({self.format(expr.expression)})"

    def visit_literal_expr(self, expr: LiteralExpr) -> str:
        """Render a literal value.

        Integral floats print without a fractional part (``3.0`` -> ``3``).
        ``False``/``True``/``None`` print as ``false``/``true``/``null``.
        Strings are wrapped in double quotes.
        """
        value: Any = expr.value
        if isinstance(value, float):
            # Numbers appear to be stored as floats; drop ".0" for integral ones.
            return str(int(value)) if value.is_integer() else str(value)
        if value is False:
            return "false"
        if value is True:
            return "true"
        if value is None:
            return "null"
        if isinstance(value, str):
            # NOTE(review): the text is emitted verbatim between quotes. If it
            # can contain '"' or characters the lexer expects escaped, the
            # output will not re-parse. Confirm against the lexer's string rules.
            return f'"{value}"'
        return str(value)

    def visit_fstring_expr(self, expr: FStringExpr) -> str:
        """Render an f-string, re-embedding ``{...}`` parts in their original order."""
        body: str = "".join(
            self.format(part) if isinstance(part, FStringEmbedExpr) else part.value
            for part in expr.parts
        )
        return f'f"{body}"'

    def visit_fstring_embed_expr(self, expr: FStringEmbedExpr) -> str:
        """Render one embedded f-string expression as ``{expr}``."""
        return "{" + self.format(expr.expression) + "}"

    def visit_variable_expr(self, expr: VariableExpr) -> str:
        """Render a variable reference by name."""
        return expr.name.lexeme

    # ----------------------------------------------------------------- stmts

    def visit_block_stmt(self, stmt: BlockStmt) -> str:
        """Render ``{ ... }`` with each inner statement one level deeper."""
        opening: str = self.indented("{\n")
        self.level += 1
        inner: str = "".join(self.format(sub_stmt) for sub_stmt in stmt.statements)
        self.level -= 1
        return opening + inner + self.indented("}\n")

    def visit_class_stmt(self, stmt: ClassStmt) -> str:
        """Render ``class Name [< Super] { methods }``."""
        header: str = f"class {stmt.name.lexeme} "
        if stmt.superclass is not None:
            header += f"< {stmt.superclass.name.lexeme} "
        res: str = self.indented(header + "{\n")
        self.level += 1
        enclosing_class: ClassType = self.current_class
        # Direct members of a class body are methods, printed without "fun".
        self.current_class = ClassType.CLASS
        res += "".join(self.format(method) for method in stmt.methods)
        self.current_class = enclosing_class
        self.level -= 1
        return res + self.indented("}\n")

    def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
        """Render an expression on its own indented line."""
        return self.indented(self.format(stmt.expression) + "\n")

    def visit_function_stmt(self, stmt: FunctionStmt) -> str:
        """Render a function, or a method when directly inside a class body."""
        keyword: str = "" if self.current_class == ClassType.CLASS else "fun "
        params: str = ", ".join(param.lexeme for param in stmt.params)
        res: str = self.indented(f"{keyword}{stmt.name.lexeme}({params}) {{\n")
        self.level += 1
        enclosing_class: ClassType = self.current_class
        # Functions nested inside a method body are ordinary functions again.
        self.current_class = ClassType.NONE
        res += "".join(self.format(sub_stmt) for sub_stmt in stmt.body)
        self.current_class = enclosing_class
        self.level -= 1
        return res + self.indented("}\n")

    def visit_if_stmt(self, stmt: IfStmt) -> str:
        """Render ``if cond <then> [else <else>]`` with the branches on the header line."""
        res: str = self.indented(
            f"if {self.format(stmt.condition)} {self._inline(stmt.then_branch)}"
        )
        if stmt.else_branch is not None:
            res += f" else {self._inline(stmt.else_branch)}"
        return res + "\n"

    def visit_return_stmt(self, stmt: ReturnStmt) -> str:
        """Render ``return`` with an optional value."""
        res: str = self.indented("return")
        if stmt.value is not None:
            res += " " + self.format(stmt.value).rstrip("\n")
        return res + "\n"

    def visit_while_stmt(self, stmt: WhileStmt) -> str:
        """Render ``while cond <body>``.

        The body goes through ``_inline`` (like ``if``), so a nested loop's
        ``{`` does not pick up a second copy of the indentation.
        """
        return self.indented(f"while {self.format(stmt.condition)} {self._inline(stmt.body)}\n")

    def visit_for_stmt(self, stmt: ForStmt) -> str:
        """Render ``for var [start_tok start] [end_tok end] [step_tok step] <body>``."""
        res: str = self.indented(f"for {stmt.variable.lexeme} ")
        if stmt.start is not None:
            res += f"{stmt.start_token.lexeme} {self.format(stmt.start)} "
        if stmt.end is not None:
            res += f"{stmt.end_token.lexeme} {self.format(stmt.end)} "
        if stmt.step is not None:
            res += f"{stmt.step_token.lexeme} {self.format(stmt.step)} "
        # _inline strips the body's own leading indentation (same as if/while).
        return res + self._inline(stmt.body) + "\n"

    def visit_let_stmt(self, stmt: LetStmt) -> str:
        """Render ``let name [= initializer]``."""
        res: str = self.indented(f"let {stmt.name.lexeme}")
        if stmt.initializer is not None:
            res += f" = {self.format(stmt.initializer)}"
        return res + "\n"

    def visit_break_stmt(self, stmt: BreakStmt) -> str:
        """Render the ``break`` keyword on its own line."""
        return self.indented(f"{stmt.keyword.lexeme}\n")

    def visit_continue_stmt(self, stmt: ContinueStmt) -> str:
        """Render the ``continue`` keyword on its own line."""
        return self.indented(f"{stmt.keyword.lexeme}\n")