feat: add statements

This commit is contained in:
2026-02-05 17:00:37 +01:00
parent 09f116cea4
commit 0fdcf1f896
6 changed files with 93 additions and 35 deletions

View File

@@ -1 +1,2 @@
3 - (4 / 9 + 1) == 12
print(3 - (4 / 9 + 1) == 12)
print(40 + (8 - -1) / 3 - 0.5 * 2)

29
main.py
View File

@@ -1,11 +1,8 @@
from typing import Any
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
from src.ast.printer import AstPrinter
from src.ast.stmt import Stmt
from src.interpreter.interpreter import Interpreter
from src.lexer import Lexer
from src.parser.parser import Parser
from src.token import Token, TokenType
from src.token import Token
def main():
@@ -21,29 +18,13 @@ def main():
source = f.read()
lexer: Lexer = Lexer()
tokens: list[Token] = lexer.process(source, path)
print(tokens)
printer: AstPrinter = AstPrinter()
ast: Expr = BinaryExpr(
UnaryExpr(
Token(TokenType.MINUS, "-", None, None),
LiteralExpr(123)
),
Token(TokenType.STAR, "*", None, None),
GroupingExpr(LiteralExpr(45.67))
)
print(printer.print(ast))
print(list(filter(lambda t: t.type not in Parser.IGNORE, tokens)))
parser: Parser = Parser()
ast = parser.parse(tokens)
program: list[Stmt] = parser.parse(tokens)
if ast is None:
return
print(printer.print(ast))
interpreter: Interpreter = Interpreter()
result: Any = interpreter.interpret(ast)
print(f"Result: {result}")
interpreter.interpret(program)
if __name__ == '__main__':

41
src/ast/stmt.py Normal file
View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeVar, Generic
from src.ast.expr import Expr
T = TypeVar("T")
@dataclass(frozen=True)
class Stmt(ABC):
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T:
...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
...
@abstractmethod
def visit_print_stmt(self, stmt: PrintStmt) -> T:
...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
expression: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_expression_stmt(self)
@dataclass(frozen=True)
class PrintStmt(Stmt):
expression: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_print_stmt(self)

View File

@@ -1,21 +1,26 @@
from typing import Any
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt
from src.interpreter.error import PebbleRuntimeError
from src.pebble import Pebble
from src.token import TokenType, Token
class Interpreter(Expr.Visitor[Any]):
def interpret(self, expr: Expr) -> Any:
class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def interpret(self, statements: list[Stmt]) -> None:
try:
return self.evaluate(expr)
for stmt in statements:
self.execute(stmt)
except PebbleRuntimeError as e:
Pebble.runtime_error(e)
def evaluate(self, expr: Expr) -> Any:
return expr.accept(self)
def execute(self, stmt: Stmt) -> None:
stmt.accept(self)
def visit_binary_expr(self, expr: BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right)
@@ -75,6 +80,13 @@ class Interpreter(Expr.Visitor[Any]):
def visit_literal_expr(self, expr: LiteralExpr) -> Any:
return expr.value
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.evaluate(stmt.expression)
def visit_print_stmt(self, stmt: PrintStmt) -> None:
value: Any = self.evaluate(stmt.expression)
print(value)
@staticmethod
def is_truthy(value: Any) -> bool:
if value is None or value is False:

View File

@@ -14,4 +14,5 @@ KEYWORDS: dict[str, TokenType] = {
"false": TokenType.FALSE,
"true": TokenType.TRUE,
"null": TokenType.NULL,
"print": TokenType.PRINT,
}

View File

@@ -1,6 +1,5 @@
from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt
from src.parser.error import ParsingError
from src.pebble import Pebble
from src.token import Token, TokenType
@@ -25,15 +24,15 @@ class Parser:
Pebble.token_error(token, msg)
return ParsingError()
def parse(self, tokens: list[Token]) -> Optional[Expr]:
def parse(self, tokens: list[Token]) -> list[Stmt]:
self.tokens = list(filter(lambda t: t.type not in self.IGNORE, tokens))
self.current = 0
self.length = len(self.tokens)
try:
return self.expression()
except ParsingError:
return None
statements: list[Stmt] = []
while not self.is_at_end():
statements.append(self.statement())
return statements
def is_at_end(self) -> bool:
return self.current >= self.length
@@ -65,6 +64,12 @@ class Parser:
if not self.match(token_type):
raise self.error(self.peek(), error_msg)
def expect_eol(self, error_msg: str):
if self.is_at_end():
return
if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF):
raise self.error(self.peek(), error_msg)
# Parsing
def synchronize(self):
self.advance()
@@ -75,6 +80,23 @@ class Parser:
return
self.advance()
def statement(self) -> Stmt:
if self.match(TokenType.PRINT):
return self.print_stmt()
return self.expression_stmt()
def print_stmt(self) -> Stmt:
self.consume(TokenType.LEFT_PAREN, "Missing parentheses")
value: Expr = self.expression()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
self.expect_eol("Expected end of line after statement")
return PrintStmt(value)
def expression_stmt(self) -> Stmt:
value: Expr = self.expression()
self.expect_eol("Expected end of line after expression")
return ExpressionStmt(value)
def expression(self) -> Expr:
return self.equality()