feat: add variable definition and reference

This commit is contained in:
2026-02-05 23:03:32 +01:00
parent 0fdcf1f896
commit 058a57909f
4 changed files with 87 additions and 9 deletions

View File

@@ -33,6 +33,10 @@ class Expr(ABC):
def visit_literal_expr(self, expr: LiteralExpr) -> T: def visit_literal_expr(self, expr: LiteralExpr) -> T:
... ...
@abstractmethod
def visit_variable_expr(self, expr: VariableExpr) -> T:
...
@dataclass(frozen=True) @dataclass(frozen=True)
class BinaryExpr(Expr): class BinaryExpr(Expr):
@@ -67,3 +71,11 @@ class LiteralExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self) return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class VariableExpr(Expr):
name: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_variable_expr(self)

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Generic from typing import TypeVar, Generic, Optional
from src.ast.expr import Expr from src.ast.expr import Expr
from src.token import Token
T = TypeVar("T") T = TypeVar("T")
@@ -24,6 +25,10 @@ class Stmt(ABC):
def visit_print_stmt(self, stmt: PrintStmt) -> T: def visit_print_stmt(self, stmt: PrintStmt) -> T:
... ...
@abstractmethod
def visit_let_stmt(self, stmt: LetStmt) -> T:
...
@dataclass(frozen=True) @dataclass(frozen=True)
class ExpressionStmt(Stmt): class ExpressionStmt(Stmt):
@@ -39,3 +44,12 @@ class PrintStmt(Stmt):
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_print_stmt(self) return visitor.visit_print_stmt(self)
@dataclass(frozen=True)
class LetStmt(Stmt):
name: Token
initializer: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_let_stmt(self)

View File

@@ -1,14 +1,34 @@
from typing import Any from typing import Any
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt, LetStmt
from src.interpreter.error import PebbleRuntimeError from src.interpreter.error import PebbleRuntimeError
from src.pebble import Pebble from src.pebble import Pebble
from src.token import TokenType, Token from src.token import TokenType, Token
class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
class Environment:
def __init__(self):
self.values: dict[str, Any] = {}
def define(self, name: str, value: Any):
self.values[name] = value
def get(self, name: Token) -> Any:
try:
return self.values[name.lexeme]
except IndexError:
raise PebbleRuntimeError(name, f"Undefined variable '{name.lexeme}'.")
def clear(self):
self.values = {}
def __init__(self):
self.env: Interpreter.Environment = Interpreter.Environment()
def interpret(self, statements: list[Stmt]) -> None: def interpret(self, statements: list[Stmt]) -> None:
self.env.clear()
try: try:
for stmt in statements: for stmt in statements:
self.execute(stmt) self.execute(stmt)
@@ -80,6 +100,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def visit_literal_expr(self, expr: LiteralExpr) -> Any: def visit_literal_expr(self, expr: LiteralExpr) -> Any:
return expr.value return expr.value
def visit_variable_expr(self, expr: VariableExpr) -> Any:
return self.env.get(expr.name)
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.evaluate(stmt.expression) self.evaluate(stmt.expression)
@@ -87,6 +110,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
value: Any = self.evaluate(stmt.expression) value: Any = self.evaluate(stmt.expression)
print(value) print(value)
def visit_let_stmt(self, stmt: LetStmt) -> None:
value: Any = None
if stmt.initializer is not None:
value = self.evaluate(stmt.initializer)
self.env.define(stmt.name.lexeme, value)
@staticmethod @staticmethod
def is_truthy(value: Any) -> bool: def is_truthy(value: Any) -> bool:
if value is None or value is False: if value is None or value is False:

View File

@@ -1,5 +1,7 @@
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr from typing import Optional
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt
from src.parser.error import ParsingError from src.parser.error import ParsingError
from src.pebble import Pebble from src.pebble import Pebble
from src.token import Token, TokenType from src.token import Token, TokenType
@@ -31,7 +33,7 @@ class Parser:
statements: list[Stmt] = [] statements: list[Stmt] = []
while not self.is_at_end(): while not self.is_at_end():
statements.append(self.statement()) statements.append(self.declaration())
return statements return statements
def is_at_end(self) -> bool: def is_at_end(self) -> bool:
@@ -60,8 +62,9 @@ class Parser:
return True return True
return False return False
def consume(self, token_type: TokenType, error_msg: str): def consume(self, token_type: TokenType, error_msg: str) -> Token:
if not self.match(token_type): if self.check(token_type):
return self.advance()
raise self.error(self.peek(), error_msg) raise self.error(self.peek(), error_msg)
def expect_eol(self, error_msg: str): def expect_eol(self, error_msg: str):
@@ -80,6 +83,23 @@ class Parser:
return return
self.advance() self.advance()
def declaration(self) -> Optional[Stmt]:
try:
if self.match(TokenType.LET):
return self.var_declaration()
return self.statement()
except ParsingError:
self.synchronize()
return None
def var_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
initializer: Optional[Expr] = None
if self.match(TokenType.EQUAL):
initializer = self.expression()
self.expect_eol("Expected end of line after variable initialization")
return LetStmt(name, initializer)
def statement(self) -> Stmt: def statement(self) -> Stmt:
if self.match(TokenType.PRINT): if self.match(TokenType.PRINT):
return self.print_stmt() return self.print_stmt()
@@ -150,6 +170,9 @@ class Parser:
if self.match(TokenType.NUMBER, TokenType.STRING): if self.match(TokenType.NUMBER, TokenType.STRING):
return LiteralExpr(self.previous().value) return LiteralExpr(self.previous().value)
if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous())
if self.match(TokenType.LEFT_PAREN): if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.expression() expr: Expr = self.expression()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis") self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")