diff --git a/src/ast/expr.py b/src/ast/expr.py index 9db771a..f7ba222 100644 --- a/src/ast/expr.py +++ b/src/ast/expr.py @@ -33,6 +33,10 @@ class Expr(ABC): def visit_literal_expr(self, expr: LiteralExpr) -> T: ... + @abstractmethod + def visit_variable_expr(self, expr: VariableExpr) -> T: + ... + @dataclass(frozen=True) class BinaryExpr(Expr): @@ -67,3 +71,11 @@ class LiteralExpr(Expr): def accept(self, visitor: Expr.Visitor[T]) -> T: return visitor.visit_literal_expr(self) + + +@dataclass(frozen=True) +class VariableExpr(Expr): + name: Token + + def accept(self, visitor: Expr.Visitor[T]) -> T: + return visitor.visit_variable_expr(self) diff --git a/src/ast/stmt.py b/src/ast/stmt.py index db76e7d..95f8355 100644 --- a/src/ast/stmt.py +++ b/src/ast/stmt.py @@ -2,9 +2,10 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TypeVar, Generic +from typing import TypeVar, Generic, Optional from src.ast.expr import Expr +from src.token import Token T = TypeVar("T") @@ -24,6 +25,10 @@ class Stmt(ABC): def visit_print_stmt(self, stmt: PrintStmt) -> T: ... + @abstractmethod + def visit_let_stmt(self, stmt: LetStmt) -> T: + ... + @dataclass(frozen=True) class ExpressionStmt(Stmt): @@ -39,3 +44,12 @@ class PrintStmt(Stmt): def accept(self, visitor: Stmt.Visitor[T]) -> T: return visitor.visit_print_stmt(self) + + +@dataclass(frozen=True) +class LetStmt(Stmt): + name: Token + initializer: Optional[Expr] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_let_stmt(self) diff --git a/src/interpreter/interpreter.py b/src/interpreter/interpreter.py index 73a0f31..d49f03e 100644 --- a/src/interpreter/interpreter.py +++ b/src/interpreter/interpreter.py @@ -1,14 +1,34 @@ from typing import Any -from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr -from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt +from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr +from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt, LetStmt from src.interpreter.error import PebbleRuntimeError from src.pebble import Pebble from src.token import TokenType, Token class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): + class Environment: + def __init__(self): + self.values: dict[str, Any] = {} + + def define(self, name: str, value: Any): + self.values[name] = value + + def get(self, name: Token) -> Any: + try: + return self.values[name.lexeme] + except IndexError: + raise PebbleRuntimeError(name, f"Undefined variable '{name.lexeme}'.") + + def clear(self): + self.values = {} + + def __init__(self): + self.env: Interpreter.Environment = Interpreter.Environment() + def interpret(self, statements: list[Stmt]) -> None: + self.env.clear() try: for stmt in statements: self.execute(stmt) @@ -80,6 +100,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): def visit_literal_expr(self, expr: LiteralExpr) -> Any: return expr.value + def visit_variable_expr(self, expr: VariableExpr) -> Any: + return self.env.get(expr.name) + def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: self.evaluate(stmt.expression) @@ -87,6 +110,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): value: Any = self.evaluate(stmt.expression) print(value) + def visit_let_stmt(self, stmt: LetStmt) -> None: + value: Any = None + if stmt.initializer is not None: + value = self.evaluate(stmt.initializer) + self.env.define(stmt.name.lexeme, value) + @staticmethod def is_truthy(value: Any) -> bool: if value is None or value is False: diff --git a/src/parser/parser.py b/src/parser/parser.py index b8cd205..b854eaf 100644 --- a/src/parser/parser.py +++ b/src/parser/parser.py @@ -1,5 +1,7 @@ -from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr -from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt +from typing import Optional + +from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr +from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt from src.parser.error import ParsingError from src.pebble import Pebble from src.token import Token, TokenType @@ -31,7 +33,7 @@ class Parser: statements: list[Stmt] = [] while not self.is_at_end(): - statements.append(self.statement()) + statements.append(self.declaration()) return statements def is_at_end(self) -> bool: @@ -60,9 +62,10 @@ class Parser: return True return False - def consume(self, token_type: TokenType, error_msg: str): - if not self.match(token_type): - raise self.error(self.peek(), error_msg) + def consume(self, token_type: TokenType, error_msg: str) -> Token: + if self.check(token_type): + return self.advance() + raise self.error(self.peek(), error_msg) def expect_eol(self, error_msg: str): if self.is_at_end(): @@ -80,6 +83,23 @@ class Parser: return self.advance() + def declaration(self) -> Optional[Stmt]: + try: + if self.match(TokenType.LET): + return self.var_declaration() + return self.statement() + except ParsingError: + self.synchronize() + return None + + def var_declaration(self) -> Stmt: + name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.") + initializer: Optional[Expr] = None + if self.match(TokenType.EQUAL): + initializer = self.expression() + self.expect_eol("Expected end of line after variable initialization") + return LetStmt(name, initializer) + def statement(self) -> Stmt: if self.match(TokenType.PRINT): return self.print_stmt() @@ -150,6 +170,9 @@ class Parser: if self.match(TokenType.NUMBER, TokenType.STRING): return LiteralExpr(self.previous().value) + if self.match(TokenType.IDENTIFIER): + return VariableExpr(self.previous()) + if self.match(TokenType.LEFT_PAREN): expr: Expr = self.expression() self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")