feat: add basic resolver

This commit is contained in:
2026-02-06 15:05:53 +01:00
parent 0175212026
commit 4a6c00598f
6 changed files with 176 additions and 6 deletions

View File

@@ -0,0 +1,10 @@
let a = "global"
{
fun show_a() {
print(a)
}
show_a()
let a = "block"
show_a()
}

View File

@@ -1,13 +1,14 @@
from src.ast.stmt import Stmt from src.ast.stmt import Stmt
from src.formatter import Formatter from src.formatter import Formatter
from src.interpreter.interpreter import Interpreter from src.interpreter.interpreter import Interpreter
from src.interpreter.resolver import Resolver
from src.lexer import Lexer from src.lexer import Lexer
from src.parser.parser import Parser from src.parser.parser import Parser
from src.token import Token from src.token import Token
def main(): def main():
path: str = "examples/14_closure.peb" path: str = "examples/15_resolution.peb"
source: str = "" source: str = ""
with open(path, "r") as f: with open(path, "r") as f:
source = f.read() source = f.read()
@@ -19,6 +20,9 @@ def main():
program: list[Stmt] = parser.parse() program: list[Stmt] = parser.parse()
interpreter: Interpreter = Interpreter() interpreter: Interpreter = Interpreter()
resolver: Resolver = Resolver(interpreter)
resolver.resolve(*program)
interpreter.interpret(program) interpreter.interpret(program)
formatter: Formatter = Formatter() formatter: Formatter = Formatter()

View File

@@ -32,3 +32,15 @@ class Environment:
def clear(self): def clear(self):
self.values = {} self.values = {}
def get_at(self, distance: int, name: str) -> Any:
return self.ancestor(distance).values.get(name)
def assign_at(self, distance: int, name: Token, value: Any):
self.ancestor(distance).values[name.lexeme] = value
def ancestor(self, distance: int) -> Environment:
env: Environment = self
for i in range(distance):
env = env.enclosing
return env

View File

@@ -19,10 +19,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def __init__(self): def __init__(self):
self.global_env: GlobalEnvironment = GlobalEnvironment() self.global_env: GlobalEnvironment = GlobalEnvironment()
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[Expr, int] = {}
def interpret(self, statements: list[Stmt]) -> None: def interpret(self, statements: list[Stmt]) -> None:
self.global_env = GlobalEnvironment()
self.env = self.global_env
try: try:
for stmt in statements: for stmt in statements:
self.execute(stmt) self.execute(stmt)
@@ -44,9 +43,22 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
finally: finally:
self.env = previous_env self.env = previous_env
def resolve(self, expr: Expr, depth: int) -> None:
self.locals[expr] = depth
def look_up_variable(self, name: Token, expr: Expr):
distance: int = self.locals.get(expr)
if distance is not None:
return self.env.get_at(distance, name.lexeme)
return self.global_env.get(name)
def visit_assign_expr(self, expr: AssignExpr) -> Any: def visit_assign_expr(self, expr: AssignExpr) -> Any:
value: Any = self.evaluate(expr.value) value: Any = self.evaluate(expr.value)
self.env.assign(expr.name, value) distance: int = self.locals.get(expr)
if distance is not None:
self.env.assign_at(distance, expr.name, value)
else:
self.global_env.assign(expr.name, value)
return value return value
def visit_logical_expr(self, expr: LogicalExpr) -> Any: def visit_logical_expr(self, expr: LogicalExpr) -> Any:
@@ -137,7 +149,7 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
return expr.value return expr.value
def visit_variable_expr(self, expr: VariableExpr) -> Any: def visit_variable_expr(self, expr: VariableExpr) -> Any:
return self.env.get(expr.name) return self.look_up_variable(expr.name, expr)
def visit_block_stmt(self, stmt: BlockStmt) -> None: def visit_block_stmt(self, stmt: BlockStmt) -> None:
self.execute_block(stmt.statements, Environment(self.env)) self.execute_block(stmt.statements, Environment(self.env))

132
src/interpreter/resolver.py Normal file
View File

@@ -0,0 +1,132 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from src.ast.expr import Expr, LogicalExpr, VariableExpr, LiteralExpr, GroupingExpr, CallExpr, UnaryExpr, BinaryExpr, \
AssignExpr
from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \
ExpressionStmt, BlockStmt
from src.pebble import Pebble
from src.token import Token
if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter
class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
def __init__(self, interpreter: Interpreter):
self.interpreter: Interpreter = interpreter
self.scopes: list[dict[str, bool]] = []
def resolve(self, *objects: Expr | Stmt) -> None:
for obj in objects:
obj.accept(self)
def begin_scope(self) -> None:
self.scopes.append({})
def end_scope(self) -> None:
self.scopes.pop()
def declare(self, name: Token) -> None:
if len(self.scopes) == 0:
return
self.scopes[-1][name.lexeme] = False
def define(self, name: Token) -> None:
if len(self.scopes) == 0:
return
self.scopes[-1][name.lexeme] = True
def resolve_local(self, expr: Expr, name: Token) -> None:
for i, scope in enumerate(reversed(self.scopes)):
if name.lexeme in scope:
self.interpreter.resolve(expr, i)
def resolve_function(self, function: FunctionStmt) -> None:
self.begin_scope()
for param in function.params:
self.declare(param)
self.define(param)
self.resolve(*function.body)
self.end_scope()
def visit_assign_expr(self, expr: AssignExpr) -> None:
self.resolve(expr.value)
self.resolve_local(expr, expr.name)
def visit_binary_expr(self, expr: BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_unary_expr(self, expr: UnaryExpr) -> None:
self.resolve(expr.right)
def visit_call_expr(self, expr: CallExpr) -> None:
self.resolve(expr.callee)
for arg in expr.arguments:
self.resolve(arg)
def visit_grouping_expr(self, expr: GroupingExpr) -> None:
self.resolve(expr.expression)
def visit_literal_expr(self, expr: LiteralExpr) -> None:
pass
def visit_variable_expr(self, expr: VariableExpr) -> None:
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name.lexeme) is False:
Pebble.token_error(expr.name, "Variable is not initialized.")
self.resolve_local(expr, expr.name)
def visit_logical_expr(self, expr: LogicalExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_block_stmt(self, stmt: BlockStmt) -> None:
self.begin_scope()
self.resolve(*stmt.statements)
self.end_scope()
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.resolve(stmt.expression)
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
self.declare(stmt.name)
self.define(stmt.name)
self.resolve_function(stmt)
def visit_if_stmt(self, stmt: IfStmt) -> None:
self.resolve(stmt.condition)
self.resolve(stmt.then_branch)
if stmt.else_branch is not None:
self.resolve(stmt.else_branch)
def visit_print_stmt(self, stmt: PrintStmt) -> None:
self.resolve(stmt.expression)
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
if stmt.value is not None:
self.resolve(stmt.value)
def visit_let_stmt(self, stmt: LetStmt) -> None:
self.declare(stmt.name)
if stmt.initializer is not None:
self.resolve(stmt.initializer)
self.define(stmt.name)
def visit_while_stmt(self, stmt: WhileStmt) -> None:
self.resolve(stmt.condition)
self.resolve(stmt.body)
def visit_for_stmt(self, stmt: ForStmt) -> None:
self.begin_scope()
self.declare(stmt.variable)
self.define(stmt.variable)
if stmt.start is not None:
self.resolve(stmt.start)
if stmt.end is not None:
self.resolve(stmt.end)
if stmt.step is not None:
self.resolve(stmt.step)
self.resolve(stmt.body)
self.end_scope()

View File

@@ -65,7 +65,7 @@ class TokenType(Enum):
NEWLINE = auto() NEWLINE = auto()
@dataclass @dataclass(frozen=True)
class Token: class Token:
type: TokenType type: TokenType
lexeme: str lexeme: str