diff --git a/examples/13_return.peb b/examples/13_return.peb new file mode 100644 index 0000000..be8d556 --- /dev/null +++ b/examples/13_return.peb @@ -0,0 +1,10 @@ +fun add(a, b) { + return a + b +} + +fun double(a) { + return add(a, a) +} + +print(add(1, 3)) +print(double(16)) \ No newline at end of file diff --git a/main.py b/main.py index affb7fa..e68794b 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from src.token import Token def main(): - path: str = "examples/12_function_def.peb" + path: str = "examples/13_return.peb" source: str = "" with open(path, "r") as f: source = f.read() diff --git a/src/ast/stmt.py b/src/ast/stmt.py index fe196c6..0356a7e 100644 --- a/src/ast/stmt.py +++ b/src/ast/stmt.py @@ -37,6 +37,10 @@ class Stmt(ABC): def visit_print_stmt(self, stmt: PrintStmt) -> T: ... + @abstractmethod + def visit_return_stmt(self, stmt: ReturnStmt) -> T: + ... + @abstractmethod def visit_let_stmt(self, stmt: LetStmt) -> T: ... @@ -94,6 +98,15 @@ class PrintStmt(Stmt): return visitor.visit_print_stmt(self) +@dataclass(frozen=True) +class ReturnStmt(Stmt): + keyword: Token + value: Optional[Expr] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_return_stmt(self) + + @dataclass(frozen=True) class LetStmt(Stmt): name: Token diff --git a/src/core/function.py b/src/core/function.py index 9f95fe0..594cd08 100644 --- a/src/core/function.py +++ b/src/core/function.py @@ -5,6 +5,7 @@ from typing import Any, TYPE_CHECKING from src.ast.stmt import FunctionStmt from src.core.callable import PebbleCallable from src.interpreter.environment import Environment +from src.interpreter.exceptions import ReturnException if TYPE_CHECKING: from src.interpreter.interpreter import Interpreter @@ -17,11 +18,16 @@ class PebbleFunction(PebbleCallable): def arity(self) -> int: return len(self.declaration.params) - def call(self, interpreter: Interpreter, arguments: list[Any]) -> None: + def call(self, interpreter: Interpreter, arguments: list[Any]) -> Any: env: Environment = Environment(interpreter.global_env) for i, param in enumerate(self.declaration.params): env.define(param.lexeme, arguments[i]) - interpreter.execute_block(self.declaration.body, env) + + try: + interpreter.execute_block(self.declaration.body, env) + except ReturnException as ret: + return ret.value + return None def __str__(self): return f"" diff --git a/src/formatter.py b/src/formatter.py index c9f4f8f..35ad1dd 100644 --- a/src/formatter.py +++ b/src/formatter.py @@ -2,7 +2,8 @@ from typing import Any from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \ CallExpr -from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt +from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt, \ + ReturnStmt class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): @@ -95,6 +96,14 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): def visit_print_stmt(self, stmt: PrintStmt) -> str: return self.indented(f"print({self.format(stmt.expression)})\n") + def visit_return_stmt(self, stmt: ReturnStmt) -> str: + res: str = self.indented("return") + if stmt.value is not None: + res += " " + res += self.format(stmt.value).rstrip("\n") + res += "\n" + return res + def visit_while_stmt(self, stmt: WhileStmt) -> str: return self.indented(f"while {self.format(stmt.condition)} " + self.format(stmt.body).rstrip("\n") + "\n") diff --git a/src/interpreter/exceptions.py b/src/interpreter/exceptions.py new file mode 100644 index 0000000..e0b65e3 --- /dev/null +++ b/src/interpreter/exceptions.py @@ -0,0 +1,7 @@ +from typing import Any + + +class ReturnException(RuntimeError): + def __init__(self, value: Any): + super().__init__() + self.value: Any = value diff --git a/src/interpreter/interpreter.py b/src/interpreter/interpreter.py index 4000210..771750b 100644 --- a/src/interpreter/interpreter.py +++ b/src/interpreter/interpreter.py @@ -3,11 +3,13 @@ from typing import Any, Optional from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \ CallExpr -from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt +from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \ + ReturnStmt from src.core.callable import PebbleCallable from src.core.function import PebbleFunction from src.interpreter.environment import Environment from src.interpreter.error import PebbleRuntimeError +from src.interpreter.exceptions import ReturnException from src.interpreter.globals import GlobalEnvironment from src.pebble import Pebble from src.token import TokenType, Token @@ -157,6 +159,12 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): value: Any = self.evaluate(stmt.expression) print(value) + def visit_return_stmt(self, stmt: ReturnStmt) -> None: + value: Any = None + if stmt.value is not None: + value = self.evaluate(stmt.value) + raise ReturnException(value) + def visit_while_stmt(self, stmt: WhileStmt) -> None: while self.is_truthy(self.evaluate(stmt.condition)): self.execute(stmt.body) diff --git a/src/keyword.py b/src/keyword.py index f7e4b23..c6f46e9 100644 --- a/src/keyword.py +++ b/src/keyword.py @@ -17,4 +17,5 @@ KEYWORDS: dict[str, TokenType] = { "true": TokenType.TRUE, "null": TokenType.NULL, "print": TokenType.PRINT, + "return": TokenType.RETURN, } diff --git a/src/parser/parser.py b/src/parser/parser.py index f94770e..e2f6595 100644 --- a/src/parser/parser.py +++ b/src/parser/parser.py @@ -2,7 +2,8 @@ from typing import Optional from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \ CallExpr -from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt +from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \ + ReturnStmt from src.consts import MAX_FUNCTION_ARGS from src.parser.error import ParsingError from src.pebble import Pebble @@ -135,6 +136,8 @@ class Parser: return self.if_stmt() if self.match(TokenType.PRINT): return self.print_stmt() + if self.match(TokenType.RETURN): + return self.return_stmt() if self.match(TokenType.WHILE): return self.while_stmt() if self.match(TokenType.LEFT_BRACE): @@ -211,6 +214,14 @@ class Parser: self.expect_eol("Expected end of line after statement") return PrintStmt(value) + def return_stmt(self) -> Stmt: + keyword: Token = self.previous() + value: Optional[Expr] = None + if not self.check(TokenType.NEWLINE) and not self.is_at_end(): + value = self.expression() + self.expect_eol("Expected end of line after return statement.") + return ReturnStmt(keyword, value) + def while_stmt(self) -> Stmt: condition: Expr = self.expression() body: Stmt = self.statement() diff --git a/src/token.py b/src/token.py index 8dba9e9..b8f446f 100644 --- a/src/token.py +++ b/src/token.py @@ -55,6 +55,7 @@ class TokenType(Enum): TO = auto() UNTIL = auto() BY = auto() + RETURN = auto() # Misc PRINT = auto()