feat(resolver): add error for top-level return statement
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from src.ast.stmt import FunctionStmt
|
from src.ast.stmt import FunctionStmt
|
||||||
@@ -11,6 +12,11 @@ if TYPE_CHECKING:
|
|||||||
from src.interpreter.interpreter import Interpreter
|
from src.interpreter.interpreter import Interpreter
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionType(Enum):
|
||||||
|
NONE = auto()
|
||||||
|
FUNCTION = auto()
|
||||||
|
|
||||||
|
|
||||||
class PebbleFunction(PebbleCallable):
|
class PebbleFunction(PebbleCallable):
|
||||||
def __init__(self, declaration: FunctionStmt, closure: Environment):
|
def __init__(self, declaration: FunctionStmt, closure: Environment):
|
||||||
self.declaration: FunctionStmt = declaration
|
self.declaration: FunctionStmt = declaration
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from src.ast.expr import Expr, LogicalExpr, VariableExpr, LiteralExpr, GroupingE
|
|||||||
AssignExpr
|
AssignExpr
|
||||||
from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \
|
from src.ast.stmt import Stmt, ForStmt, WhileStmt, LetStmt, ReturnStmt, PrintStmt, IfStmt, FunctionStmt, \
|
||||||
ExpressionStmt, BlockStmt
|
ExpressionStmt, BlockStmt
|
||||||
|
from src.core.function import FunctionType
|
||||||
from src.pebble import Pebble
|
from src.pebble import Pebble
|
||||||
from src.token import Token
|
from src.token import Token
|
||||||
|
|
||||||
@@ -17,6 +18,7 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
|
|||||||
def __init__(self, interpreter: Interpreter):
|
def __init__(self, interpreter: Interpreter):
|
||||||
self.interpreter: Interpreter = interpreter
|
self.interpreter: Interpreter = interpreter
|
||||||
self.scopes: list[dict[str, bool]] = []
|
self.scopes: list[dict[str, bool]] = []
|
||||||
|
self.current_func: FunctionType = FunctionType.NONE
|
||||||
|
|
||||||
def resolve(self, *objects: Expr | Stmt) -> None:
|
def resolve(self, *objects: Expr | Stmt) -> None:
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
@@ -46,13 +48,16 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
|
|||||||
if name.lexeme in scope:
|
if name.lexeme in scope:
|
||||||
self.interpreter.resolve(expr, i)
|
self.interpreter.resolve(expr, i)
|
||||||
|
|
||||||
def resolve_function(self, function: FunctionStmt) -> None:
|
def resolve_function(self, function: FunctionStmt, type: FunctionType) -> None:
|
||||||
|
enclosing_func: FunctionType = self.current_func
|
||||||
|
self.current_func = type
|
||||||
self.begin_scope()
|
self.begin_scope()
|
||||||
for param in function.params:
|
for param in function.params:
|
||||||
self.declare(param)
|
self.declare(param)
|
||||||
self.define(param)
|
self.define(param)
|
||||||
self.resolve(*function.body)
|
self.resolve(*function.body)
|
||||||
self.end_scope()
|
self.end_scope()
|
||||||
|
self.current_func = enclosing_func
|
||||||
|
|
||||||
def visit_assign_expr(self, expr: AssignExpr) -> None:
|
def visit_assign_expr(self, expr: AssignExpr) -> None:
|
||||||
self.resolve(expr.value)
|
self.resolve(expr.value)
|
||||||
@@ -96,7 +101,7 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
|
|||||||
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
|
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
|
||||||
self.declare(stmt.name)
|
self.declare(stmt.name)
|
||||||
self.define(stmt.name)
|
self.define(stmt.name)
|
||||||
self.resolve_function(stmt)
|
self.resolve_function(stmt, FunctionType.FUNCTION)
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: IfStmt) -> None:
|
def visit_if_stmt(self, stmt: IfStmt) -> None:
|
||||||
self.resolve(stmt.condition)
|
self.resolve(stmt.condition)
|
||||||
@@ -108,6 +113,9 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
|
|||||||
self.resolve(stmt.expression)
|
self.resolve(stmt.expression)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
|
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
|
||||||
|
if self.current_func == FunctionType.NONE:
|
||||||
|
Pebble.token_error(stmt.keyword, "Cannot return from top-level scope.")
|
||||||
|
|
||||||
if stmt.value is not None:
|
if stmt.value is not None:
|
||||||
self.resolve(stmt.value)
|
self.resolve(stmt.value)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user