diff --git a/main.py b/main.py index 7ee7e42..ab744f7 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ from src.ast.stmt import Stmt +from src.formatter import Formatter from src.interpreter.interpreter import Interpreter from src.lexer import Lexer from src.parser.parser import Parser @@ -26,6 +27,10 @@ def main(): interpreter: Interpreter = Interpreter() interpreter.interpret(program) + formatter: Formatter = Formatter() + with open("formatted.peb", "w") as f: + f.write(formatter.print(program)) + if __name__ == '__main__': main() diff --git a/src/formatter.py b/src/formatter.py new file mode 100644 index 0000000..ee9e742 --- /dev/null +++ b/src/formatter.py @@ -0,0 +1,110 @@ +from typing import Any + +from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \ + CallExpr +from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt + + +class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): + def __init__(self, indent: int = 4): + self.indent: int = indent + self.level: int = 0 + + def indented(self, text: str) -> str: + return " " * (self.level * self.indent) + text + + def print(self, statements: list[Stmt]) -> str: + self.level = 0 + res: str = "" + for stmt in statements: + res += self.format(stmt) + return res + + def format(self, obj: Expr | Stmt) -> str: + return obj.accept(self) + + def visit_assign_expr(self, expr: AssignExpr) -> str: + return f"{expr.name.lexeme} = {self.format(expr.value)}" + + def visit_binary_expr(self, expr: BinaryExpr) -> str: + return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}" + + def visit_logical_expr(self, expr: LogicalExpr) -> str: + return f"{self.format(expr.left)} {expr.operator.lexeme} {self.format(expr.right)}" + + def visit_unary_expr(self, expr: UnaryExpr) -> str: + return f"{expr.operator.lexeme}{self.format(expr.right)}" + + def visit_call_expr(self, expr: CallExpr) -> str: + return f"{self.format(expr.callee)}({', '.join(self.format(arg) for arg in expr.arguments)})" + + def visit_grouping_expr(self, expr: GroupingExpr) -> str: + return f"({self.format(expr.expression)})" + + def visit_literal_expr(self, expr: LiteralExpr) -> str: + value: Any = expr.value + if isinstance(value, float): + if value.is_integer(): + value = int(value) + return str(value) + if value is False: + return "false" + if value is True: + return "true" + if value is None: + return "null" + if isinstance(value, str): + return f'"{value}"' + return str(value) + + def visit_variable_expr(self, expr: VariableExpr) -> str: + return expr.name.lexeme + + def visit_block_stmt(self, stmt: BlockStmt) -> str: + res: str = self.indented("{\n") + self.level += 1 + for sub_stmt in stmt.statements: + res += self.format(sub_stmt) + self.level -= 1 + res += self.indented("}\n") + return res + + def visit_expression_stmt(self, stmt: ExpressionStmt) -> str: + return self.indented(self.format(stmt.expression) + "\n") + + def visit_if_stmt(self, stmt: IfStmt) -> str: + res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}") + res = res.rstrip("\n") + if stmt.else_branch is not None: + res += f" else {self.format(stmt.else_branch)}" + res = res.rstrip("\n") + res += "\n" + return res + + def visit_print_stmt(self, stmt: PrintStmt) -> str: + return self.indented(f"print({self.format(stmt.expression)})\n") + + def visit_while_stmt(self, stmt: WhileStmt) -> str: + return self.indented(f"while {self.format(stmt.condition)} " + self.format(stmt.body).rstrip("\n") + "\n") + + def visit_for_stmt(self, stmt: ForStmt) -> str: + res: str = self.indented(f"for {stmt.variable.lexeme} ") + + if stmt.start is not None: + res += f"{stmt.start_token.lexeme} {self.format(stmt.start)} " + + if stmt.end is not None: + res += f"{stmt.end_token.lexeme} {self.format(stmt.end)} " + + if stmt.step is not None: + res += f"{stmt.step_token.lexeme} {self.format(stmt.step)} " + + res += self.format(stmt.body).rstrip("\n") + "\n" + return res + + def visit_let_stmt(self, stmt: LetStmt) -> str: + res: str = self.indented(f"let {stmt.name.lexeme}") + if stmt.initializer is not None: + res += f" = {stmt.initializer.accept(self)}" + res += "\n" + return res