diff --git a/examples/15_resolution.peb b/examples/15_resolution.peb new file mode 100644 index 0000000..9c0c80b --- /dev/null +++ b/examples/15_resolution.peb @@ -0,0 +1,10 @@ +let a = "global" +{ + fun show_a() { + print(a) + } + + show_a() + let a = "block" + show_a() +} diff --git a/main.py b/main.py index ec5d8ed..4c5e0e4 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,14 @@ from src.ast.stmt import Stmt from src.formatter import Formatter from src.interpreter.interpreter import Interpreter +from src.interpreter.resolver import Resolver from src.lexer import Lexer from src.parser.parser import Parser from src.token import Token def main(): - path: str = "examples/14_closure.peb" + path: str = "examples/15_resolution.peb" source: str = "" with open(path, "r") as f: source = f.read() @@ -19,6 +20,9 @@ def main(): program: list[Stmt] = parser.parse() interpreter: Interpreter = Interpreter() + resolver: Resolver = Resolver(interpreter) + + resolver.resolve(*program) interpreter.interpret(program) formatter: Formatter = Formatter() diff --git a/src/interpreter/environment.py b/src/interpreter/environment.py index 496c495..e99fc49 100644 --- a/src/interpreter/environment.py +++ b/src/interpreter/environment.py @@ -32,3 +32,15 @@ class Environment: def clear(self): self.values = {} + + def get_at(self, distance: int, name: str) -> Any: + return self.ancestor(distance).values.get(name) + + def assign_at(self, distance: int, name: Token, value: Any): + self.ancestor(distance).values[name.lexeme] = value + + def ancestor(self, distance: int) -> Environment: + env: Environment = self + for i in range(distance): + env = env.enclosing + return env diff --git a/src/interpreter/interpreter.py b/src/interpreter/interpreter.py index d685242..0d8206a 100644 --- a/src/interpreter/interpreter.py +++ b/src/interpreter/interpreter.py @@ -19,10 +19,9 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): def __init__(self): self.global_env: GlobalEnvironment = GlobalEnvironment() self.env: Environment = self.global_env + self.locals: dict[Expr, int] = {} def interpret(self, statements: list[Stmt]) -> None: - self.global_env = GlobalEnvironment() - self.env = self.global_env try: for stmt in statements: self.execute(stmt) @@ -44,9 +43,22 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): finally: self.env = previous_env + def resolve(self, expr: Expr, depth: int) -> None: + self.locals[expr] = depth + + def look_up_variable(self, name: Token, expr: Expr): + distance: int = self.locals.get(expr) + if distance is not None: + return self.env.get_at(distance, name.lexeme) + return self.global_env.get(name) + def visit_assign_expr(self, expr: AssignExpr) -> Any: value: Any = self.evaluate(expr.value) - self.env.assign(expr.name, value) + distance: int = self.locals.get(expr) + if distance is not None: + self.env.assign_at(distance, expr.name, value) + else: + self.global_env.assign(expr.name, value) return value def visit_logical_expr(self, expr: LogicalExpr) -> Any: @@ -137,7 +149,7 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): return expr.value def visit_variable_expr(self, expr: VariableExpr) -> Any: - return self.env.get(expr.name) + return self.look_up_variable(expr.name, expr) def visit_block_stmt(self, stmt: BlockStmt) -> None: self.execute_block(stmt.statements, Environment(self.env)) diff --git a/src/interpreter/resolver.py b/src/interpreter/resolver.py new file mode 100644 index 0000000..002ff13 --- /dev/null +++ b/src/interpreter/resolver.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from src.ast.expr import Expr, LogicalExpr, VariableExpr, LiteralExpr, GroupingExpr, CallExpr, UnaryExpr, BinaryExpr, \ + AssignExpr +from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \ + ExpressionStmt, BlockStmt +from src.pebble import Pebble +from src.token import Token + +if TYPE_CHECKING: + from src.interpreter.interpreter import Interpreter + + +class Resolver(Expr.Visitor[None], Stmt.Visitor[None]): + def __init__(self, interpreter: Interpreter): + self.interpreter: Interpreter = interpreter + self.scopes: list[dict[str, bool]] = [] + + def resolve(self, *objects: Expr | Stmt) -> None: + for obj in objects: + obj.accept(self) + + def begin_scope(self) -> None: + self.scopes.append({}) + + def end_scope(self) -> None: + self.scopes.pop() + + def declare(self, name: Token) -> None: + if len(self.scopes) == 0: + return + self.scopes[-1][name.lexeme] = False + + def define(self, name: Token) -> None: + if len(self.scopes) == 0: + return + self.scopes[-1][name.lexeme] = True + + def resolve_local(self, expr: Expr, name: Token) -> None: + for i, scope in enumerate(reversed(self.scopes)): + if name.lexeme in scope: + self.interpreter.resolve(expr, i) + + def resolve_function(self, function: FunctionStmt) -> None: + self.begin_scope() + for param in function.params: + self.declare(param) + self.define(param) + self.resolve(*function.body) + self.end_scope() + + def visit_assign_expr(self, expr: AssignExpr) -> None: + self.resolve(expr.value) + self.resolve_local(expr, expr.name) + + def visit_binary_expr(self, expr: BinaryExpr) -> None: + self.resolve(expr.left) + self.resolve(expr.right) + + def visit_unary_expr(self, expr: UnaryExpr) -> None: + self.resolve(expr.right) + + def visit_call_expr(self, expr: CallExpr) -> None: + self.resolve(expr.callee) + for arg in expr.arguments: + self.resolve(arg) + + def visit_grouping_expr(self, expr: GroupingExpr) -> None: + self.resolve(expr.expression) + + def visit_literal_expr(self, expr: LiteralExpr) -> None: + pass + + def visit_variable_expr(self, expr: VariableExpr) -> None: + if len(self.scopes) != 0 and self.scopes[-1].get(expr.name.lexeme) is False: + Pebble.token_error(expr.name, "Variable is not initialized.") + self.resolve_local(expr, expr.name) + + def visit_logical_expr(self, expr: LogicalExpr) -> None: + self.resolve(expr.left) + self.resolve(expr.right) + + def visit_block_stmt(self, stmt: BlockStmt) -> None: + self.begin_scope() + self.resolve(*stmt.statements) + self.end_scope() + + def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: + self.resolve(stmt.expression) + + def visit_function_stmt(self, stmt: FunctionStmt) -> None: + self.declare(stmt.name) + self.define(stmt.name) + self.resolve_function(stmt) + + def visit_if_stmt(self, stmt: IfStmt) -> None: + self.resolve(stmt.condition) + self.resolve(stmt.then_branch) + if stmt.else_branch is not None: + self.resolve(stmt.else_branch) + + def visit_print_stmt(self, stmt: PrintStmt) -> None: + self.resolve(stmt.expression) + + def visit_return_stmt(self, stmt: ReturnStmt) -> None: + if stmt.value is not None: + self.resolve(stmt.value) + + def visit_let_stmt(self, stmt: LetStmt) -> None: + self.declare(stmt.name) + if stmt.initializer is not None: + self.resolve(stmt.initializer) + self.define(stmt.name) + + def visit_while_stmt(self, stmt: WhileStmt) -> None: + self.resolve(stmt.condition) + self.resolve(stmt.body) + + def visit_for_stmt(self, stmt: ForStmt) -> None: + self.begin_scope() + self.declare(stmt.variable) + self.define(stmt.variable) + if stmt.start is not None: + self.resolve(stmt.start) + if stmt.end is not None: + self.resolve(stmt.end) + if stmt.step is not None: + self.resolve(stmt.step) + self.resolve(stmt.body) + self.end_scope() diff --git a/src/token.py b/src/token.py index b8f446f..7a8e6f2 100644 --- a/src/token.py +++ b/src/token.py @@ -65,7 +65,7 @@ class TokenType(Enum): NEWLINE = auto() -@dataclass +@dataclass(frozen=True) class Token: type: TokenType lexeme: str