feat: add variable definition and reference

This commit is contained in:
2026-02-05 23:03:32 +01:00
parent 0fdcf1f896
commit 058a57909f
4 changed files with 87 additions and 9 deletions

View File

@@ -33,6 +33,10 @@ class Expr(ABC):
def visit_literal_expr(self, expr: LiteralExpr) -> T:
...
@abstractmethod
def visit_variable_expr(self, expr: VariableExpr) -> T:
...
@dataclass(frozen=True)
class BinaryExpr(Expr):
@@ -67,3 +71,11 @@ class LiteralExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_literal_expr(self)
@dataclass(frozen=True)
class VariableExpr(Expr):
name: Token
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_variable_expr(self)

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Optional
from src.ast.expr import Expr
from src.token import Token
T = TypeVar("T")
@@ -24,6 +25,10 @@ class Stmt(ABC):
def visit_print_stmt(self, stmt: PrintStmt) -> T:
...
@abstractmethod
def visit_let_stmt(self, stmt: LetStmt) -> T:
...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
@@ -39,3 +44,12 @@ class PrintStmt(Stmt):
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_print_stmt(self)
@dataclass(frozen=True)
class LetStmt(Stmt):
name: Token
initializer: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_let_stmt(self)

View File

@@ -1,14 +1,34 @@
from typing import Any
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt, LetStmt
from src.interpreter.error import PebbleRuntimeError
from src.pebble import Pebble
from src.token import TokenType, Token
class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
class Environment:
def __init__(self):
self.values: dict[str, Any] = {}
def define(self, name: str, value: Any):
self.values[name] = value
def get(self, name: Token) -> Any:
try:
return self.values[name.lexeme]
except IndexError:
raise PebbleRuntimeError(name, f"Undefined variable '{name.lexeme}'.")
def clear(self):
self.values = {}
def __init__(self):
self.env: Interpreter.Environment = Interpreter.Environment()
def interpret(self, statements: list[Stmt]) -> None:
self.env.clear()
try:
for stmt in statements:
self.execute(stmt)
@@ -80,6 +100,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def visit_literal_expr(self, expr: LiteralExpr) -> Any:
return expr.value
def visit_variable_expr(self, expr: VariableExpr) -> Any:
return self.env.get(expr.name)
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.evaluate(stmt.expression)
@@ -87,6 +110,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
value: Any = self.evaluate(stmt.expression)
print(value)
def visit_let_stmt(self, stmt: LetStmt) -> None:
value: Any = None
if stmt.initializer is not None:
value = self.evaluate(stmt.initializer)
self.env.define(stmt.name.lexeme, value)
@staticmethod
def is_truthy(value: Any) -> bool:
if value is None or value is False:

View File

@@ -1,5 +1,7 @@
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt
from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt
from src.parser.error import ParsingError
from src.pebble import Pebble
from src.token import Token, TokenType
@@ -31,7 +33,7 @@ class Parser:
statements: list[Stmt] = []
while not self.is_at_end():
statements.append(self.statement())
statements.append(self.declaration())
return statements
def is_at_end(self) -> bool:
@@ -60,9 +62,10 @@ class Parser:
return True
return False
def consume(self, token_type: TokenType, error_msg: str):
if not self.match(token_type):
raise self.error(self.peek(), error_msg)
def consume(self, token_type: TokenType, error_msg: str) -> Token:
if self.check(token_type):
return self.advance()
raise self.error(self.peek(), error_msg)
def expect_eol(self, error_msg: str):
if self.is_at_end():
@@ -80,6 +83,23 @@ class Parser:
return
self.advance()
def declaration(self) -> Optional[Stmt]:
try:
if self.match(TokenType.LET):
return self.var_declaration()
return self.statement()
except ParsingError:
self.synchronize()
return None
def var_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
initializer: Optional[Expr] = None
if self.match(TokenType.EQUAL):
initializer = self.expression()
self.expect_eol("Expected end of line after variable initialization")
return LetStmt(name, initializer)
def statement(self) -> Stmt:
if self.match(TokenType.PRINT):
return self.print_stmt()
@@ -150,6 +170,9 @@ class Parser:
if self.match(TokenType.NUMBER, TokenType.STRING):
return LiteralExpr(self.previous().value)
if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous())
if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.expression()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")