feat(formatter): add basic formatter

This commit is contained in:
2026-02-06 13:24:14 +01:00
parent 50e838b4bb
commit 6552eaff90
2 changed files with 115 additions and 0 deletions

View File

@@ -1,4 +1,5 @@
from src.ast.stmt import Stmt from src.ast.stmt import Stmt
from src.formatter import Formatter
from src.interpreter.interpreter import Interpreter from src.interpreter.interpreter import Interpreter
from src.lexer import Lexer from src.lexer import Lexer
from src.parser.parser import Parser from src.parser.parser import Parser
@@ -26,6 +27,10 @@ def main():
interpreter: Interpreter = Interpreter() interpreter: Interpreter = Interpreter()
interpreter.interpret(program) interpreter.interpret(program)
formatter: Formatter = Formatter()
with open("formatted.peb", "w") as f:
f.write(formatter.print(program))
if __name__ == '__main__': if __name__ == '__main__':
main() main()

110
src/formatter.py Normal file
View File

@@ -0,0 +1,110 @@
from typing import Any
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, statements: list[Stmt]) -> str:
self.level = 0
res: str = ""
for stmt in statements:
res += self.format(stmt)
return res
def format(self, obj: Expr | Stmt) -> str:
return obj.accept(self)
def visit_assign_expr(self, expr: AssignExpr) -> str:
return f"{expr.name.lexeme} = {self.format(expr.value)}"
def visit_binary_expr(self, expr: BinaryExpr) -> str:
return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}"
def visit_logical_expr(self, expr: LogicalExpr) -> str:
return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}"
def visit_unary_expr(self, expr: UnaryExpr) -> str:
return f"{expr.operator.lexeme}{self.format(expr.right)}"
def visit_call_expr(self, expr: CallExpr) -> str:
return f"{self.format(expr.callee)}({', '.join(self.format(arg) for arg in expr.arguments)})"
def visit_grouping_expr(self, expr: GroupingExpr) -> str:
return f"({self.format(expr.expression)})"
def visit_literal_expr(self, expr: LiteralExpr) -> str:
value: Any = expr.value
if isinstance(value, float):
if value.is_integer():
value = int(value)
return str(value)
if value is False:
return "false"
if value is True:
return "true"
if value is None:
return "null"
if isinstance(value, str):
return f'"{value}"'
return str(value)
def visit_variable_expr(self, expr: VariableExpr) -> str:
return expr.name.lexeme
def visit_block_stmt(self, stmt: BlockStmt) -> str:
res: str = self.indented("{\n")
self.level += 1
for sub_stmt in stmt.statements:
res += self.format(sub_stmt)
self.level -= 1
res += self.indented("}\n")
return res
def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
return self.indented(self.format(stmt.expression) + "\n")
def visit_if_stmt(self, stmt: IfStmt) -> str:
res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}")
res = res.rstrip("\n")
if stmt.else_branch is not None:
res += f" else {self.format(stmt.else_branch)}"
res = res.rstrip("\n")
res += "\n"
return res
def visit_print_stmt(self, stmt: PrintStmt) -> str:
return self.indented(f"print({self.format(stmt.expression)})\n")
def visit_while_stmt(self, stmt: WhileStmt) -> str:
return self.indented(f"while {self.format(stmt.condition)} " + self.format(stmt.body).rstrip("\n") + "\n")
def visit_for_stmt(self, stmt: ForStmt) -> str:
res: str = self.indented(f"for {stmt.variable.lexeme} ")
if stmt.start is not None:
res += f"{stmt.start_token.lexeme} {self.format(stmt.start)} "
if stmt.end is not None:
res += f"{stmt.end_token.lexeme} {self.format(stmt.end)} "
if stmt.step is not None:
res += f"{stmt.step_token.lexeme} {self.format(stmt.step)} "
res += self.format(stmt.body).rstrip("\n") + "\n"
return res
def visit_let_stmt(self, stmt: LetStmt) -> str:
res: str = self.indented(f"let {stmt.name.lexeme}")
if stmt.initializer is not None:
res += f" = {stmt.initializer.accept(self)}"
res += "\n"
return res