feat(resolver): add error for top-level return statement

This commit is contained in:
2026-02-06 15:16:06 +01:00
parent c6eb1ab2c9
commit 7e844a5007
2 changed files with 16 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum, auto
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from src.ast.stmt import FunctionStmt from src.ast.stmt import FunctionStmt
@@ -11,6 +12,11 @@ if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter from src.interpreter.interpreter import Interpreter
class FunctionType(Enum):
NONE = auto()
FUNCTION = auto()
class PebbleFunction(PebbleCallable): class PebbleFunction(PebbleCallable):
def __init__(self, declaration: FunctionStmt, closure: Environment): def __init__(self, declaration: FunctionStmt, closure: Environment):
self.declaration: FunctionStmt = declaration self.declaration: FunctionStmt = declaration

View File

@@ -6,6 +6,7 @@ from src.ast.expr import Expr, LogicalExpr, VariableExpr, LiteralExpr, GroupingE
AssignExpr AssignExpr
from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \ from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \
ExpressionStmt, BlockStmt ExpressionStmt, BlockStmt
from src.core.function import FunctionType
from src.pebble import Pebble from src.pebble import Pebble
from src.token import Token from src.token import Token
@@ -17,6 +18,7 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
def __init__(self, interpreter: Interpreter): def __init__(self, interpreter: Interpreter):
self.interpreter: Interpreter = interpreter self.interpreter: Interpreter = interpreter
self.scopes: list[dict[str, bool]] = [] self.scopes: list[dict[str, bool]] = []
self.current_func: FunctionType = FunctionType.NONE
def resolve(self, *objects: Expr | Stmt) -> None: def resolve(self, *objects: Expr | Stmt) -> None:
for obj in objects: for obj in objects:
@@ -46,13 +48,16 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
if name.lexeme in scope: if name.lexeme in scope:
self.interpreter.resolve(expr, i) self.interpreter.resolve(expr, i)
def resolve_function(self, function: FunctionStmt) -> None: def resolve_function(self, function: FunctionStmt, type: FunctionType) -> None:
enclosing_func: FunctionType = self.current_func
self.current_func = type
self.begin_scope() self.begin_scope()
for param in function.params: for param in function.params:
self.declare(param) self.declare(param)
self.define(param) self.define(param)
self.resolve(*function.body) self.resolve(*function.body)
self.end_scope() self.end_scope()
self.current_func = enclosing_func
def visit_assign_expr(self, expr: AssignExpr) -> None: def visit_assign_expr(self, expr: AssignExpr) -> None:
self.resolve(expr.value) self.resolve(expr.value)
@@ -96,7 +101,7 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
def visit_function_stmt(self, stmt: FunctionStmt) -> None: def visit_function_stmt(self, stmt: FunctionStmt) -> None:
self.declare(stmt.name) self.declare(stmt.name)
self.define(stmt.name) self.define(stmt.name)
self.resolve_function(stmt) self.resolve_function(stmt, FunctionType.FUNCTION)
def visit_if_stmt(self, stmt: IfStmt) -> None: def visit_if_stmt(self, stmt: IfStmt) -> None:
self.resolve(stmt.condition) self.resolve(stmt.condition)
@@ -108,6 +113,9 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
self.resolve(stmt.expression) self.resolve(stmt.expression)
def visit_return_stmt(self, stmt: ReturnStmt) -> None: def visit_return_stmt(self, stmt: ReturnStmt) -> None:
if self.current_func == FunctionType.NONE:
Pebble.token_error(stmt.keyword, "Cannot return from top-level scope.")
if stmt.value is not None: if stmt.value is not None:
self.resolve(stmt.value) self.resolve(stmt.value)