diff --git a/examples/07_math.peb b/examples/07_math.peb index 559650d..120a878 100644 --- a/examples/07_math.peb +++ b/examples/07_math.peb @@ -1 +1,2 @@ -3 - (4 / 9 + 1) == 12 \ No newline at end of file +print(3 - (4 / 9 + 1) == 12) +print(40 + (8 - -1) / 3 - 0.5 * 2) \ No newline at end of file diff --git a/main.py b/main.py index 91871f5..344d80b 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,8 @@ -from typing import Any - -from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr -from src.ast.printer import AstPrinter +from src.ast.stmt import Stmt from src.interpreter.interpreter import Interpreter from src.lexer import Lexer from src.parser.parser import Parser -from src.token import Token, TokenType +from src.token import Token def main(): @@ -21,29 +18,13 @@ def main(): source = f.read() lexer: Lexer = Lexer() tokens: list[Token] = lexer.process(source, path) - print(tokens) - - printer: AstPrinter = AstPrinter() - ast: Expr = BinaryExpr( - UnaryExpr( - Token(TokenType.MINUS, "-", None, None), - LiteralExpr(123) - ), - Token(TokenType.STAR, "*", None, None), - GroupingExpr(LiteralExpr(45.67)) - ) - print(printer.print(ast)) + print(list(filter(lambda t: t.type not in Parser.IGNORE, tokens))) parser: Parser = Parser() - ast = parser.parse(tokens) + program: list[Stmt] = parser.parse(tokens) - if ast is None: - return - - print(printer.print(ast)) interpreter: Interpreter = Interpreter() - result: Any = interpreter.interpret(ast) - print(f"Result: {result}") + interpreter.interpret(program) if __name__ == '__main__': diff --git a/src/ast/stmt.py b/src/ast/stmt.py new file mode 100644 index 0000000..db76e7d --- /dev/null +++ b/src/ast/stmt.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TypeVar, Generic + +from src.ast.expr import Expr + +T = TypeVar("T") + + +@dataclass(frozen=True) +class Stmt(ABC): + @abstractmethod + def accept(self, visitor: Visitor[T]) -> T: + ... + + class Visitor(ABC, Generic[T]): + @abstractmethod + def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: + ... + + @abstractmethod + def visit_print_stmt(self, stmt: PrintStmt) -> T: + ... + + +@dataclass(frozen=True) +class ExpressionStmt(Stmt): + expression: Expr + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_expression_stmt(self) + + +@dataclass(frozen=True) +class PrintStmt(Stmt): + expression: Expr + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_print_stmt(self) diff --git a/src/interpreter/interpreter.py b/src/interpreter/interpreter.py index 8671fba..73a0f31 100644 --- a/src/interpreter/interpreter.py +++ b/src/interpreter/interpreter.py @@ -1,21 +1,26 @@ from typing import Any from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr +from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt from src.interpreter.error import PebbleRuntimeError from src.pebble import Pebble from src.token import TokenType, Token -class Interpreter(Expr.Visitor[Any]): - def interpret(self, expr: Expr) -> Any: +class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): + def interpret(self, statements: list[Stmt]) -> None: try: - return self.evaluate(expr) + for stmt in statements: + self.execute(stmt) except PebbleRuntimeError as e: Pebble.runtime_error(e) def evaluate(self, expr: Expr) -> Any: return expr.accept(self) + def execute(self, stmt: Stmt) -> None: + stmt.accept(self) + def visit_binary_expr(self, expr: BinaryExpr) -> Any: left: Any = self.evaluate(expr.left) right: Any = self.evaluate(expr.right) @@ -75,6 +80,13 @@ class Interpreter(Expr.Visitor[Any]): def visit_literal_expr(self, expr: LiteralExpr) -> Any: return expr.value + def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: + self.evaluate(stmt.expression) + + def visit_print_stmt(self, stmt: PrintStmt) -> None: + value: Any = self.evaluate(stmt.expression) + print(value) + @staticmethod def is_truthy(value: Any) -> bool: if value is None or value is False: diff --git a/src/keyword.py b/src/keyword.py index a3b7bc6..fb369a1 100644 --- a/src/keyword.py +++ b/src/keyword.py @@ -14,4 +14,5 @@ KEYWORDS: dict[str, TokenType] = { "false": TokenType.FALSE, "true": TokenType.TRUE, "null": TokenType.NULL, + "print": TokenType.PRINT, } diff --git a/src/parser/parser.py b/src/parser/parser.py index 55c4b96..b8cd205 100644 --- a/src/parser/parser.py +++ b/src/parser/parser.py @@ -1,6 +1,5 @@ -from typing import Optional - from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr +from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt from src.parser.error import ParsingError from src.pebble import Pebble from src.token import Token, TokenType @@ -25,15 +24,15 @@ class Parser: Pebble.token_error(token, msg) return ParsingError() - def parse(self, tokens: list[Token]) -> Optional[Expr]: + def parse(self, tokens: list[Token]) -> list[Stmt]: self.tokens = list(filter(lambda t: t.type not in self.IGNORE, tokens)) self.current = 0 self.length = len(self.tokens) - try: - return self.expression() - except ParsingError: - return None + statements: list[Stmt] = [] + while not self.is_at_end(): + statements.append(self.statement()) + return statements def is_at_end(self) -> bool: return self.current >= self.length @@ -65,6 +64,12 @@ class Parser: if not self.match(token_type): raise self.error(self.peek(), error_msg) + def expect_eol(self, error_msg: str): + if self.is_at_end(): + return + if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF): + raise self.error(self.peek(), error_msg) + # Parsing def synchronize(self): self.advance() @@ -75,6 +80,23 @@ class Parser: return self.advance() + def statement(self) -> Stmt: + if self.match(TokenType.PRINT): + return self.print_stmt() + return self.expression_stmt() + + def print_stmt(self) -> Stmt: + self.consume(TokenType.LEFT_PAREN, "Missing parentheses") + value: Expr = self.expression() + self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") + self.expect_eol("Expected end of line after statement") + return PrintStmt(value) + + def expression_stmt(self) -> Stmt: + value: Expr = self.expression() + self.expect_eol("Expected end of line after expression") + return ExpressionStmt(value) + def expression(self) -> Expr: return self.equality()