feat: add return statements

This commit is contained in:
2026-02-06 14:06:55 +01:00
parent 671b91426b
commit fa13a88478
10 changed files with 72 additions and 6 deletions

10
examples/13_return.peb Normal file
View File

@@ -0,0 +1,10 @@
fun add(a, b) {
return a + b
}
fun double(a) {
return add(a, a)
}
print(add(1, 3))
print(double(16))

View File

@@ -7,7 +7,7 @@ from src.token import Token
def main():
path: str = "examples/12_function_def.peb"
path: str = "examples/13_return.peb"
source: str = ""
with open(path, "r") as f:
source = f.read()

View File

@@ -37,6 +37,10 @@ class Stmt(ABC):
def visit_print_stmt(self, stmt: PrintStmt) -> T:
...
@abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T:
...
@abstractmethod
def visit_let_stmt(self, stmt: LetStmt) -> T:
...
@@ -94,6 +98,15 @@ class PrintStmt(Stmt):
return visitor.visit_print_stmt(self)
@dataclass(frozen=True)
class ReturnStmt(Stmt):
keyword: Token
value: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_return_stmt(self)
@dataclass(frozen=True)
class LetStmt(Stmt):
name: Token

View File

@@ -5,6 +5,7 @@ from typing import Any, TYPE_CHECKING
from src.ast.stmt import FunctionStmt
from src.core.callable import PebbleCallable
from src.interpreter.environment import Environment
from src.interpreter.exceptions import ReturnException
if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter
@@ -17,11 +18,16 @@ class PebbleFunction(PebbleCallable):
def arity(self) -> int:
return len(self.declaration.params)
def call(self, interpreter: Interpreter, arguments: list[Any]) -> None:
def call(self, interpreter: Interpreter, arguments: list[Any]) -> Any:
env: Environment = Environment(interpreter.global_env)
for i, param in enumerate(self.declaration.params):
env.define(param.lexeme, arguments[i])
try:
interpreter.execute_block(self.declaration.body, env)
except ReturnException as ret:
return ret.value
return None
def __str__(self):
return f"<function {self.declaration.name.lexeme}>"

View File

@@ -2,7 +2,8 @@ from typing import Any
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
@@ -95,6 +96,14 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
def visit_print_stmt(self, stmt: PrintStmt) -> str:
return self.indented(f"print({self.format(stmt.expression)})\n")
def visit_return_stmt(self, stmt: ReturnStmt) -> str:
res: str = self.indented("return")
if stmt.value is not None:
res += " "
res += self.format(stmt.value).rstrip("\n")
res += "\n"
return res
def visit_while_stmt(self, stmt: WhileStmt) -> str:
return self.indented(f"while {self.format(stmt.condition)} " + self.format(stmt.body).rstrip("\n") + "\n")

View File

@@ -0,0 +1,7 @@
from typing import Any
class ReturnException(RuntimeError):
def __init__(self, value: Any):
super().__init__()
self.value: Any = value

View File

@@ -3,11 +3,13 @@ from typing import Any, Optional
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt
from src.core.callable import PebbleCallable
from src.core.function import PebbleFunction
from src.interpreter.environment import Environment
from src.interpreter.error import PebbleRuntimeError
from src.interpreter.exceptions import ReturnException
from src.interpreter.globals import GlobalEnvironment
from src.pebble import Pebble
from src.token import TokenType, Token
@@ -157,6 +159,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
value: Any = self.evaluate(stmt.expression)
print(value)
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
value: Any = None
if stmt.value is not None:
value = self.evaluate(stmt.value)
raise ReturnException(value)
def visit_while_stmt(self, stmt: WhileStmt) -> None:
while self.is_truthy(self.evaluate(stmt.condition)):
self.execute(stmt.body)

View File

@@ -17,4 +17,5 @@ KEYWORDS: dict[str, TokenType] = {
"true": TokenType.TRUE,
"null": TokenType.NULL,
"print": TokenType.PRINT,
"return": TokenType.RETURN,
}

View File

@@ -2,7 +2,8 @@ from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt
from src.consts import MAX_FUNCTION_ARGS
from src.parser.error import ParsingError
from src.pebble import Pebble
@@ -135,6 +136,8 @@ class Parser:
return self.if_stmt()
if self.match(TokenType.PRINT):
return self.print_stmt()
if self.match(TokenType.RETURN):
return self.return_stmt()
if self.match(TokenType.WHILE):
return self.while_stmt()
if self.match(TokenType.LEFT_BRACE):
@@ -211,6 +214,14 @@ class Parser:
self.expect_eol("Expected end of line after statement")
return PrintStmt(value)
def return_stmt(self) -> Stmt:
keyword: Token = self.previous()
value: Optional[Expr] = None
if not self.check(TokenType.NEWLINE) and not self.is_at_end():
value = self.expression()
self.expect_eol("Expected end of line after return statement.")
return ReturnStmt(keyword, value)
def while_stmt(self) -> Stmt:
condition: Expr = self.expression()
body: Stmt = self.statement()

View File

@@ -55,6 +55,7 @@ class TokenType(Enum):
TO = auto()
UNTIL = auto()
BY = auto()
RETURN = auto()
# Misc
PRINT = auto()