feat: add statements

This commit is contained in:
2026-02-05 17:00:37 +01:00
parent 09f116cea4
commit 0fdcf1f896
6 changed files with 93 additions and 35 deletions

View File

@@ -1 +1,2 @@
3 - (4 / 9 + 1) == 12 print(3 - (4 / 9 + 1) == 12)
print(40 + (8 - -1) / 3 - 0.5 * 2)

29
main.py
View File

@@ -1,11 +1,8 @@
from typing import Any from src.ast.stmt import Stmt
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
from src.ast.printer import AstPrinter
from src.interpreter.interpreter import Interpreter from src.interpreter.interpreter import Interpreter
from src.lexer import Lexer from src.lexer import Lexer
from src.parser.parser import Parser from src.parser.parser import Parser
from src.token import Token, TokenType from src.token import Token
def main(): def main():
@@ -21,29 +18,13 @@ def main():
source = f.read() source = f.read()
lexer: Lexer = Lexer() lexer: Lexer = Lexer()
tokens: list[Token] = lexer.process(source, path) tokens: list[Token] = lexer.process(source, path)
print(tokens) print(list(filter(lambda t: t.type not in Parser.IGNORE, tokens)))
printer: AstPrinter = AstPrinter()
ast: Expr = BinaryExpr(
UnaryExpr(
Token(TokenType.MINUS, "-", None, None),
LiteralExpr(123)
),
Token(TokenType.STAR, "*", None, None),
GroupingExpr(LiteralExpr(45.67))
)
print(printer.print(ast))
parser: Parser = Parser() parser: Parser = Parser()
ast = parser.parse(tokens) program: list[Stmt] = parser.parse(tokens)
if ast is None:
return
print(printer.print(ast))
interpreter: Interpreter = Interpreter() interpreter: Interpreter = Interpreter()
result: Any = interpreter.interpret(ast) interpreter.interpret(program)
print(f"Result: {result}")
if __name__ == '__main__': if __name__ == '__main__':

41
src/ast/stmt.py Normal file
View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeVar, Generic
from src.ast.expr import Expr
T = TypeVar("T")
@dataclass(frozen=True)
class Stmt(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T:
...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
...
@abstractmethod
def visit_print_stmt(self, stmt: PrintStmt) -> T:
...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
expression: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_expression_stmt(self)
@dataclass(frozen=True)
class PrintStmt(Stmt):
expression: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_print_stmt(self)

View File

@@ -1,21 +1,26 @@
from typing import Any from typing import Any
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt
from src.interpreter.error import PebbleRuntimeError from src.interpreter.error import PebbleRuntimeError
from src.pebble import Pebble from src.pebble import Pebble
from src.token import TokenType, Token from src.token import TokenType, Token
class Interpreter(Expr.Visitor[Any]): class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def interpret(self, expr: Expr) -> Any: def interpret(self, statements: list[Stmt]) -> None:
try: try:
return self.evaluate(expr) for stmt in statements:
self.execute(stmt)
except PebbleRuntimeError as e: except PebbleRuntimeError as e:
Pebble.runtime_error(e) Pebble.runtime_error(e)
def evaluate(self, expr: Expr) -> Any: def evaluate(self, expr: Expr) -> Any:
return expr.accept(self) return expr.accept(self)
def execute(self, stmt: Stmt) -> None:
stmt.accept(self)
def visit_binary_expr(self, expr: BinaryExpr) -> Any: def visit_binary_expr(self, expr: BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left) left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right) right: Any = self.evaluate(expr.right)
@@ -75,6 +80,13 @@ class Interpreter(Expr.Visitor[Any]):
def visit_literal_expr(self, expr: LiteralExpr) -> Any: def visit_literal_expr(self, expr: LiteralExpr) -> Any:
return expr.value return expr.value
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.evaluate(stmt.expression)
def visit_print_stmt(self, stmt: PrintStmt) -> None:
value: Any = self.evaluate(stmt.expression)
print(value)
@staticmethod @staticmethod
def is_truthy(value: Any) -> bool: def is_truthy(value: Any) -> bool:
if value is None or value is False: if value is None or value is False:

View File

@@ -14,4 +14,5 @@ KEYWORDS: dict[str, TokenType] = {
"false": TokenType.FALSE, "false": TokenType.FALSE,
"true": TokenType.TRUE, "true": TokenType.TRUE,
"null": TokenType.NULL, "null": TokenType.NULL,
"print": TokenType.PRINT,
} }

View File

@@ -1,6 +1,5 @@
from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt
from src.parser.error import ParsingError from src.parser.error import ParsingError
from src.pebble import Pebble from src.pebble import Pebble
from src.token import Token, TokenType from src.token import Token, TokenType
@@ -25,15 +24,15 @@ class Parser:
Pebble.token_error(token, msg) Pebble.token_error(token, msg)
return ParsingError() return ParsingError()
def parse(self, tokens: list[Token]) -> Optional[Expr]: def parse(self, tokens: list[Token]) -> list[Stmt]:
self.tokens = list(filter(lambda t: t.type not in self.IGNORE, tokens)) self.tokens = list(filter(lambda t: t.type not in self.IGNORE, tokens))
self.current = 0 self.current = 0
self.length = len(self.tokens) self.length = len(self.tokens)
try: statements: list[Stmt] = []
return self.expression() while not self.is_at_end():
except ParsingError: statements.append(self.statement())
return None return statements
def is_at_end(self) -> bool: def is_at_end(self) -> bool:
return self.current >= self.length return self.current >= self.length
@@ -65,6 +64,12 @@ class Parser:
if not self.match(token_type): if not self.match(token_type):
raise self.error(self.peek(), error_msg) raise self.error(self.peek(), error_msg)
def expect_eol(self, error_msg: str):
if self.is_at_end():
return
if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF):
raise self.error(self.peek(), error_msg)
# Parsing # Parsing
def synchronize(self): def synchronize(self):
self.advance() self.advance()
@@ -75,6 +80,23 @@ class Parser:
return return
self.advance() self.advance()
def statement(self) -> Stmt:
if self.match(TokenType.PRINT):
return self.print_stmt()
return self.expression_stmt()
def print_stmt(self) -> Stmt:
self.consume(TokenType.LEFT_PAREN, "Missing parentheses")
value: Expr = self.expression()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
self.expect_eol("Expected end of line after statement")
return PrintStmt(value)
def expression_stmt(self) -> Stmt:
value: Expr = self.expression()
self.expect_eol("Expected end of line after expression")
return ExpressionStmt(value)
def expression(self) -> Expr: def expression(self) -> Expr:
return self.equality() return self.equality()