feat: add return statements

This commit is contained in:
2026-02-06 14:06:55 +01:00
parent 671b91426b
commit fa13a88478
10 changed files with 72 additions and 6 deletions

10
examples/13_return.peb Normal file
View File

@@ -0,0 +1,10 @@
fun add(a, b) {
return a + b
}
fun double(a) {
return add(a, a)
}
print(add(1, 3))
print(double(16))

View File

@@ -7,7 +7,7 @@ from src.token import Token
def main(): def main():
path: str = "examples/12_function_def.peb" path: str = "examples/13_return.peb"
source: str = "" source: str = ""
with open(path, "r") as f: with open(path, "r") as f:
source = f.read() source = f.read()

View File

@@ -37,6 +37,10 @@ class Stmt(ABC):
def visit_print_stmt(self, stmt: PrintStmt) -> T: def visit_print_stmt(self, stmt: PrintStmt) -> T:
... ...
@abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T:
...
@abstractmethod @abstractmethod
def visit_let_stmt(self, stmt: LetStmt) -> T: def visit_let_stmt(self, stmt: LetStmt) -> T:
... ...
@@ -94,6 +98,15 @@ class PrintStmt(Stmt):
return visitor.visit_print_stmt(self) return visitor.visit_print_stmt(self)
@dataclass(frozen=True)
class ReturnStmt(Stmt):
keyword: Token
value: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_return_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class LetStmt(Stmt): class LetStmt(Stmt):
name: Token name: Token

View File

@@ -5,6 +5,7 @@ from typing import Any, TYPE_CHECKING
from src.ast.stmt import FunctionStmt from src.ast.stmt import FunctionStmt
from src.core.callable import PebbleCallable from src.core.callable import PebbleCallable
from src.interpreter.environment import Environment from src.interpreter.environment import Environment
from src.interpreter.exceptions import ReturnException
if TYPE_CHECKING: if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter from src.interpreter.interpreter import Interpreter
@@ -17,11 +18,16 @@ class PebbleFunction(PebbleCallable):
def arity(self) -> int: def arity(self) -> int:
return len(self.declaration.params) return len(self.declaration.params)
def call(self, interpreter: Interpreter, arguments: list[Any]) -> None: def call(self, interpreter: Interpreter, arguments: list[Any]) -> Any:
env: Environment = Environment(interpreter.global_env) env: Environment = Environment(interpreter.global_env)
for i, param in enumerate(self.declaration.params): for i, param in enumerate(self.declaration.params):
env.define(param.lexeme, arguments[i]) env.define(param.lexeme, arguments[i])
interpreter.execute_block(self.declaration.body, env)
try:
interpreter.execute_block(self.declaration.body, env)
except ReturnException as ret:
return ret.value
return None
def __str__(self): def __str__(self):
return f"<function {self.declaration.name.lexeme}>" return f"<function {self.declaration.name.lexeme}>"

View File

@@ -2,7 +2,8 @@ from typing import Any
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \ from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
CallExpr CallExpr
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
@@ -95,6 +96,14 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
def visit_print_stmt(self, stmt: PrintStmt) -> str: def visit_print_stmt(self, stmt: PrintStmt) -> str:
return self.indented(f"print({self.format(stmt.expression)})\n") return self.indented(f"print({self.format(stmt.expression)})\n")
def visit_return_stmt(self, stmt: ReturnStmt) -> str:
res: str = self.indented("return")
if stmt.value is not None:
res += " "
res += self.format(stmt.value).rstrip("\n")
res += "\n"
return res
def visit_while_stmt(self, stmt: WhileStmt) -> str: def visit_while_stmt(self, stmt: WhileStmt) -> str:
return self.indented(f"while {self.format(stmt.condition)} " + self.format(stmt.body).rstrip("\n") + "\n") return self.indented(f"while {self.format(stmt.condition)} " + self.format(stmt.body).rstrip("\n") + "\n")

View File

@@ -0,0 +1,7 @@
from typing import Any
class ReturnException(RuntimeError):
def __init__(self, value: Any):
super().__init__()
self.value: Any = value

View File

@@ -3,11 +3,13 @@ from typing import Any, Optional
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \ from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt
from src.core.callable import PebbleCallable from src.core.callable import PebbleCallable
from src.core.function import PebbleFunction from src.core.function import PebbleFunction
from src.interpreter.environment import Environment from src.interpreter.environment import Environment
from src.interpreter.error import PebbleRuntimeError from src.interpreter.error import PebbleRuntimeError
from src.interpreter.exceptions import ReturnException
from src.interpreter.globals import GlobalEnvironment from src.interpreter.globals import GlobalEnvironment
from src.pebble import Pebble from src.pebble import Pebble
from src.token import TokenType, Token from src.token import TokenType, Token
@@ -157,6 +159,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
value: Any = self.evaluate(stmt.expression) value: Any = self.evaluate(stmt.expression)
print(value) print(value)
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
value: Any = None
if stmt.value is not None:
value = self.evaluate(stmt.value)
raise ReturnException(value)
def visit_while_stmt(self, stmt: WhileStmt) -> None: def visit_while_stmt(self, stmt: WhileStmt) -> None:
while self.is_truthy(self.evaluate(stmt.condition)): while self.is_truthy(self.evaluate(stmt.condition)):
self.execute(stmt.body) self.execute(stmt.body)

View File

@@ -17,4 +17,5 @@ KEYWORDS: dict[str, TokenType] = {
"true": TokenType.TRUE, "true": TokenType.TRUE,
"null": TokenType.NULL, "null": TokenType.NULL,
"print": TokenType.PRINT, "print": TokenType.PRINT,
"return": TokenType.RETURN,
} }

View File

@@ -2,7 +2,8 @@ from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \ from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt
from src.consts import MAX_FUNCTION_ARGS from src.consts import MAX_FUNCTION_ARGS
from src.parser.error import ParsingError from src.parser.error import ParsingError
from src.pebble import Pebble from src.pebble import Pebble
@@ -135,6 +136,8 @@ class Parser:
return self.if_stmt() return self.if_stmt()
if self.match(TokenType.PRINT): if self.match(TokenType.PRINT):
return self.print_stmt() return self.print_stmt()
if self.match(TokenType.RETURN):
return self.return_stmt()
if self.match(TokenType.WHILE): if self.match(TokenType.WHILE):
return self.while_stmt() return self.while_stmt()
if self.match(TokenType.LEFT_BRACE): if self.match(TokenType.LEFT_BRACE):
@@ -211,6 +214,14 @@ class Parser:
self.expect_eol("Expected end of line after statement") self.expect_eol("Expected end of line after statement")
return PrintStmt(value) return PrintStmt(value)
def return_stmt(self) -> Stmt:
keyword: Token = self.previous()
value: Optional[Expr] = None
if not self.check(TokenType.NEWLINE) and not self.is_at_end():
value = self.expression()
self.expect_eol("Expected end of line after return statement.")
return ReturnStmt(keyword, value)
def while_stmt(self) -> Stmt: def while_stmt(self) -> Stmt:
condition: Expr = self.expression() condition: Expr = self.expression()
body: Stmt = self.statement() body: Stmt = self.statement()

View File

@@ -55,6 +55,7 @@ class TokenType(Enum):
TO = auto() TO = auto()
UNTIL = auto() UNTIL = auto()
BY = auto() BY = auto()
RETURN = auto()
# Misc # Misc
PRINT = auto() PRINT = auto()