feat: add variable definition and reference
This commit is contained in:
@@ -33,6 +33,10 @@ class Expr(ABC):
|
|||||||
def visit_literal_expr(self, expr: LiteralExpr) -> T:
|
def visit_literal_expr(self, expr: LiteralExpr) -> T:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_variable_expr(self, expr: VariableExpr) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BinaryExpr(Expr):
|
class BinaryExpr(Expr):
|
||||||
@@ -67,3 +71,11 @@ class LiteralExpr(Expr):
|
|||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_literal_expr(self)
|
return visitor.visit_literal_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class VariableExpr(Expr):
|
||||||
|
name: Token
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_variable_expr(self)
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeVar, Generic
|
from typing import TypeVar, Generic, Optional
|
||||||
|
|
||||||
from src.ast.expr import Expr
|
from src.ast.expr import Expr
|
||||||
|
from src.token import Token
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -24,6 +25,10 @@ class Stmt(ABC):
|
|||||||
def visit_print_stmt(self, stmt: PrintStmt) -> T:
|
def visit_print_stmt(self, stmt: PrintStmt) -> T:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_let_stmt(self, stmt: LetStmt) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExpressionStmt(Stmt):
|
class ExpressionStmt(Stmt):
|
||||||
@@ -39,3 +44,12 @@ class PrintStmt(Stmt):
|
|||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_print_stmt(self)
|
return visitor.visit_print_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LetStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
initializer: Optional[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_let_stmt(self)
|
||||||
|
|||||||
@@ -1,14 +1,34 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
|
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr
|
||||||
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt
|
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt, LetStmt
|
||||||
from src.interpreter.error import PebbleRuntimeError
|
from src.interpreter.error import PebbleRuntimeError
|
||||||
from src.pebble import Pebble
|
from src.pebble import Pebble
|
||||||
from src.token import TokenType, Token
|
from src.token import TokenType, Token
|
||||||
|
|
||||||
|
|
||||||
class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
|
class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
|
||||||
|
class Environment:
|
||||||
|
def __init__(self):
|
||||||
|
self.values: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def define(self, name: str, value: Any):
|
||||||
|
self.values[name] = value
|
||||||
|
|
||||||
|
def get(self, name: Token) -> Any:
|
||||||
|
try:
|
||||||
|
return self.values[name.lexeme]
|
||||||
|
except IndexError:
|
||||||
|
raise PebbleRuntimeError(name, f"Undefined variable '{name.lexeme}'.")
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.values = {}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.env: Interpreter.Environment = Interpreter.Environment()
|
||||||
|
|
||||||
def interpret(self, statements: list[Stmt]) -> None:
|
def interpret(self, statements: list[Stmt]) -> None:
|
||||||
|
self.env.clear()
|
||||||
try:
|
try:
|
||||||
for stmt in statements:
|
for stmt in statements:
|
||||||
self.execute(stmt)
|
self.execute(stmt)
|
||||||
@@ -80,6 +100,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
|
|||||||
def visit_literal_expr(self, expr: LiteralExpr) -> Any:
|
def visit_literal_expr(self, expr: LiteralExpr) -> Any:
|
||||||
return expr.value
|
return expr.value
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: VariableExpr) -> Any:
|
||||||
|
return self.env.get(expr.name)
|
||||||
|
|
||||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
|
||||||
self.evaluate(stmt.expression)
|
self.evaluate(stmt.expression)
|
||||||
|
|
||||||
@@ -87,6 +110,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
|
|||||||
value: Any = self.evaluate(stmt.expression)
|
value: Any = self.evaluate(stmt.expression)
|
||||||
print(value)
|
print(value)
|
||||||
|
|
||||||
|
def visit_let_stmt(self, stmt: LetStmt) -> None:
|
||||||
|
value: Any = None
|
||||||
|
if stmt.initializer is not None:
|
||||||
|
value = self.evaluate(stmt.initializer)
|
||||||
|
self.env.define(stmt.name.lexeme, value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_truthy(value: Any) -> bool:
|
def is_truthy(value: Any) -> bool:
|
||||||
if value is None or value is False:
|
if value is None or value is False:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
|
from typing import Optional
|
||||||
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt
|
|
||||||
|
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr
|
||||||
|
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt
|
||||||
from src.parser.error import ParsingError
|
from src.parser.error import ParsingError
|
||||||
from src.pebble import Pebble
|
from src.pebble import Pebble
|
||||||
from src.token import Token, TokenType
|
from src.token import Token, TokenType
|
||||||
@@ -31,7 +33,7 @@ class Parser:
|
|||||||
|
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
while not self.is_at_end():
|
while not self.is_at_end():
|
||||||
statements.append(self.statement())
|
statements.append(self.declaration())
|
||||||
return statements
|
return statements
|
||||||
|
|
||||||
def is_at_end(self) -> bool:
|
def is_at_end(self) -> bool:
|
||||||
@@ -60,9 +62,10 @@ class Parser:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def consume(self, token_type: TokenType, error_msg: str):
|
def consume(self, token_type: TokenType, error_msg: str) -> Token:
|
||||||
if not self.match(token_type):
|
if self.check(token_type):
|
||||||
raise self.error(self.peek(), error_msg)
|
return self.advance()
|
||||||
|
raise self.error(self.peek(), error_msg)
|
||||||
|
|
||||||
def expect_eol(self, error_msg: str):
|
def expect_eol(self, error_msg: str):
|
||||||
if self.is_at_end():
|
if self.is_at_end():
|
||||||
@@ -80,6 +83,23 @@ class Parser:
|
|||||||
return
|
return
|
||||||
self.advance()
|
self.advance()
|
||||||
|
|
||||||
|
def declaration(self) -> Optional[Stmt]:
|
||||||
|
try:
|
||||||
|
if self.match(TokenType.LET):
|
||||||
|
return self.var_declaration()
|
||||||
|
return self.statement()
|
||||||
|
except ParsingError:
|
||||||
|
self.synchronize()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def var_declaration(self) -> Stmt:
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
|
||||||
|
initializer: Optional[Expr] = None
|
||||||
|
if self.match(TokenType.EQUAL):
|
||||||
|
initializer = self.expression()
|
||||||
|
self.expect_eol("Expected end of line after variable initialization")
|
||||||
|
return LetStmt(name, initializer)
|
||||||
|
|
||||||
def statement(self) -> Stmt:
|
def statement(self) -> Stmt:
|
||||||
if self.match(TokenType.PRINT):
|
if self.match(TokenType.PRINT):
|
||||||
return self.print_stmt()
|
return self.print_stmt()
|
||||||
@@ -150,6 +170,9 @@ class Parser:
|
|||||||
if self.match(TokenType.NUMBER, TokenType.STRING):
|
if self.match(TokenType.NUMBER, TokenType.STRING):
|
||||||
return LiteralExpr(self.previous().value)
|
return LiteralExpr(self.previous().value)
|
||||||
|
|
||||||
|
if self.match(TokenType.IDENTIFIER):
|
||||||
|
return VariableExpr(self.previous())
|
||||||
|
|
||||||
if self.match(TokenType.LEFT_PAREN):
|
if self.match(TokenType.LEFT_PAREN):
|
||||||
expr: Expr = self.expression()
|
expr: Expr = self.expression()
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||||
|
|||||||
Reference in New Issue
Block a user