feat: add basic resolver

This commit is contained in:
2026-02-06 15:05:53 +01:00
parent 0175212026
commit 4a6c00598f
6 changed files with 176 additions and 6 deletions

View File

@@ -0,0 +1,10 @@
let a = "global"
{
fun show_a() {
print(a)
}
show_a()
let a = "block"
show_a()
}

View File

@@ -1,13 +1,14 @@
from src.ast.stmt import Stmt
from src.formatter import Formatter
from src.interpreter.interpreter import Interpreter
from src.interpreter.resolver import Resolver
from src.lexer import Lexer
from src.parser.parser import Parser
from src.token import Token
def main():
path: str = "examples/14_closure.peb"
path: str = "examples/15_resolution.peb"
source: str = ""
with open(path, "r") as f:
source = f.read()
@@ -19,6 +20,9 @@ def main():
program: list[Stmt] = parser.parse()
interpreter: Interpreter = Interpreter()
resolver: Resolver = Resolver(interpreter)
resolver.resolve(*program)
interpreter.interpret(program)
formatter: Formatter = Formatter()

View File

@@ -32,3 +32,15 @@ class Environment:
def clear(self):
self.values = {}
def get_at(self, distance: int, name: str) -> Any:
return self.ancestor(distance).values.get(name)
def assign_at(self, distance: int, name: Token, value: Any):
self.ancestor(distance).values[name.lexeme] = value
def ancestor(self, distance: int) -> Environment:
env: Environment = self
for i in range(distance):
env = env.enclosing
return env

View File

@@ -19,10 +19,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def __init__(self):
self.global_env: GlobalEnvironment = GlobalEnvironment()
self.env: Environment = self.global_env
self.locals: dict[Expr, int] = {}
def interpret(self, statements: list[Stmt]) -> None:
self.global_env = GlobalEnvironment()
self.env = self.global_env
try:
for stmt in statements:
self.execute(stmt)
@@ -44,9 +43,22 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
finally:
self.env = previous_env
def resolve(self, expr: Expr, depth: int) -> None:
self.locals[expr] = depth
def look_up_variable(self, name: Token, expr: Expr):
distance: int = self.locals.get(expr)
if distance is not None:
return self.env.get_at(distance, name.lexeme)
return self.global_env.get(name)
def visit_assign_expr(self, expr: AssignExpr) -> Any:
value: Any = self.evaluate(expr.value)
self.env.assign(expr.name, value)
distance: int = self.locals.get(expr)
if distance is not None:
self.env.assign_at(distance, expr.name, value)
else:
self.global_env.assign(expr.name, value)
return value
def visit_logical_expr(self, expr: LogicalExpr) -> Any:
@@ -137,7 +149,7 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
return expr.value
def visit_variable_expr(self, expr: VariableExpr) -> Any:
return self.env.get(expr.name)
return self.look_up_variable(expr.name, expr)
def visit_block_stmt(self, stmt: BlockStmt) -> None:
self.execute_block(stmt.statements, Environment(self.env))

132
src/interpreter/resolver.py Normal file
View File

@@ -0,0 +1,132 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from src.ast.expr import Expr, LogicalExpr, VariableExpr, LiteralExpr, GroupingExpr, CallExpr, UnaryExpr, BinaryExpr, \
AssignExpr
from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \
ExpressionStmt, BlockStmt
from src.pebble import Pebble
from src.token import Token
if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter
class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
def __init__(self, interpreter: Interpreter):
self.interpreter: Interpreter = interpreter
self.scopes: list[dict[str, bool]] = []
def resolve(self, *objects: Expr | Stmt) -> None:
for obj in objects:
obj.accept(self)
def begin_scope(self) -> None:
self.scopes.append({})
def end_scope(self) -> None:
self.scopes.pop()
def declare(self, name: Token) -> None:
if len(self.scopes) == 0:
return
self.scopes[-1][name.lexeme] = False
def define(self, name: Token) -> None:
if len(self.scopes) == 0:
return
self.scopes[-1][name.lexeme] = True
def resolve_local(self, expr: Expr, name: Token) -> None:
for i, scope in enumerate(reversed(self.scopes)):
if name.lexeme in scope:
self.interpreter.resolve(expr, i)
def resolve_function(self, function: FunctionStmt) -> None:
self.begin_scope()
for param in function.params:
self.declare(param)
self.define(param)
self.resolve(*function.body)
self.end_scope()
def visit_assign_expr(self, expr: AssignExpr) -> None:
self.resolve(expr.value)
self.resolve_local(expr, expr.name)
def visit_binary_expr(self, expr: BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_unary_expr(self, expr: UnaryExpr) -> None:
self.resolve(expr.right)
def visit_call_expr(self, expr: CallExpr) -> None:
self.resolve(expr.callee)
for arg in expr.arguments:
self.resolve(arg)
def visit_grouping_expr(self, expr: GroupingExpr) -> None:
self.resolve(expr.expression)
def visit_literal_expr(self, expr: LiteralExpr) -> None:
pass
def visit_variable_expr(self, expr: VariableExpr) -> None:
if len(self.scopes) != 0 and self.scopes[-1].get(expr.name.lexeme) is False:
Pebble.token_error(expr.name, "Variable is not initialized.")
self.resolve_local(expr, expr.name)
def visit_logical_expr(self, expr: LogicalExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
def visit_block_stmt(self, stmt: BlockStmt) -> None:
self.begin_scope()
self.resolve(*stmt.statements)
self.end_scope()
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.resolve(stmt.expression)
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
self.declare(stmt.name)
self.define(stmt.name)
self.resolve_function(stmt)
def visit_if_stmt(self, stmt: IfStmt) -> None:
self.resolve(stmt.condition)
self.resolve(stmt.then_branch)
if stmt.else_branch is not None:
self.resolve(stmt.else_branch)
def visit_print_stmt(self, stmt: PrintStmt) -> None:
self.resolve(stmt.expression)
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
if stmt.value is not None:
self.resolve(stmt.value)
def visit_let_stmt(self, stmt: LetStmt) -> None:
self.declare(stmt.name)
if stmt.initializer is not None:
self.resolve(stmt.initializer)
self.define(stmt.name)
def visit_while_stmt(self, stmt: WhileStmt) -> None:
self.resolve(stmt.condition)
self.resolve(stmt.body)
def visit_for_stmt(self, stmt: ForStmt) -> None:
self.begin_scope()
self.declare(stmt.variable)
self.define(stmt.variable)
if stmt.start is not None:
self.resolve(stmt.start)
if stmt.end is not None:
self.resolve(stmt.end)
if stmt.step is not None:
self.resolve(stmt.step)
self.resolve(stmt.body)
self.end_scope()

View File

@@ -65,7 +65,7 @@ class TokenType(Enum):
NEWLINE = auto()
@dataclass
@dataclass(frozen=True)
class Token:
type: TokenType
lexeme: str