From 671b91426bbe2c42c7ffde520d712b2803fa32ae Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Fri, 6 Feb 2026 13:48:59 +0100 Subject: [PATCH] feat: add function definition --- examples/12_function_def.peb | 8 ++++++++ main.py | 10 ++-------- src/ast/stmt.py | 14 ++++++++++++++ src/core/function.py | 27 +++++++++++++++++++++++++++ src/formatter.py | 13 ++++++++++++- src/interpreter/interpreter.py | 7 ++++++- src/keyword.py | 1 + src/parser/parser.py | 28 ++++++++++++++++++++++++---- src/token.py | 1 + 9 files changed, 95 insertions(+), 14 deletions(-) create mode 100644 examples/12_function_def.peb create mode 100644 src/core/function.py diff --git a/examples/12_function_def.peb b/examples/12_function_def.peb new file mode 100644 index 0000000..7c3278f --- /dev/null +++ b/examples/12_function_def.peb @@ -0,0 +1,8 @@ +fun add(a, b) { + let c = a + b + print(c) +} + +add(1, 3) +add(4, 6) +add(-3, 9) \ No newline at end of file diff --git a/main.py b/main.py index ab744f7..affb7fa 100644 --- a/main.py +++ b/main.py @@ -7,14 +7,8 @@ from src.token import Token def main(): - source: str = """() {} +- += / /= // sefs + {, ) - }:: * - "This is a string" - 3.1415 - 123 - "This is - another string" """ - path: str = "examples/05_loop.peb" + path: str = "examples/12_function_def.peb" + source: str = "" with open(path, "r") as f: source = f.read() lexer: Lexer = Lexer() diff --git a/src/ast/stmt.py b/src/ast/stmt.py index a3e9551..fe196c6 100644 --- a/src/ast/stmt.py +++ b/src/ast/stmt.py @@ -25,6 +25,10 @@ class Stmt(ABC): def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: ... + @abstractmethod + def visit_function_stmt(self, stmt: FunctionStmt) -> T: + ... + @abstractmethod def visit_if_stmt(self, stmt: IfStmt) -> T: ... @@ -62,6 +66,16 @@ class ExpressionStmt(Stmt): return visitor.visit_expression_stmt(self) +@dataclass(frozen=True) +class FunctionStmt(Stmt): + name: Token + params: list[Token] + body: list[Stmt] + + def accept(self, visitor: Stmt.Visitor[T]) -> T: + return visitor.visit_function_stmt(self) + + @dataclass(frozen=True) class IfStmt(Stmt): condition: Expr diff --git a/src/core/function.py b/src/core/function.py new file mode 100644 index 0000000..9f95fe0 --- /dev/null +++ b/src/core/function.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from src.ast.stmt import FunctionStmt +from src.core.callable import PebbleCallable +from src.interpreter.environment import Environment + +if TYPE_CHECKING: + from src.interpreter.interpreter import Interpreter + + +class PebbleFunction(PebbleCallable): + def __init__(self, declaration: FunctionStmt): + self.declaration: FunctionStmt = declaration + + def arity(self) -> int: + return len(self.declaration.params) + + def call(self, interpreter: Interpreter, arguments: list[Any]) -> None: + env: Environment = Environment(interpreter.global_env) + for i, param in enumerate(self.declaration.params): + env.define(param.lexeme, arguments[i]) + interpreter.execute_block(self.declaration.body, env) + + def __str__(self): + return f"" diff --git a/src/formatter.py b/src/formatter.py index ee9e742..c9f4f8f 100644 --- a/src/formatter.py +++ b/src/formatter.py @@ -2,7 +2,7 @@ from typing import Any from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \ CallExpr -from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt +from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): @@ -72,6 +72,17 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): def visit_expression_stmt(self, stmt: ExpressionStmt) -> str: return self.indented(self.format(stmt.expression) + "\n") + def visit_function_stmt(self, stmt: FunctionStmt) -> str: + res: str = self.indented(f"fun {stmt.name.lexeme}") + res += f"({', '.join(param.lexeme for param in stmt.params)}) " + res += "{\n" + self.level += 1 + for sub_stmt in stmt.body: + res += self.format(sub_stmt) + self.level -= 1 + res += self.indented("}\n") + return res + def visit_if_stmt(self, stmt: IfStmt) -> str: res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}") res = res.rstrip("\n") diff --git a/src/interpreter/interpreter.py b/src/interpreter/interpreter.py index 808e191..4000210 100644 --- a/src/interpreter/interpreter.py +++ b/src/interpreter/interpreter.py @@ -3,8 +3,9 @@ from typing import Any, Optional from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \ CallExpr -from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt +from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt from src.core.callable import PebbleCallable +from src.core.function import PebbleFunction from src.interpreter.environment import Environment from src.interpreter.error import PebbleRuntimeError from src.interpreter.globals import GlobalEnvironment @@ -142,6 +143,10 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]): def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: self.evaluate(stmt.expression) + def visit_function_stmt(self, stmt: FunctionStmt) -> None: + function: PebbleFunction = PebbleFunction(stmt) + self.env.define(stmt.name.lexeme, function) + def visit_if_stmt(self, stmt: IfStmt) -> None: if self.is_truthy(self.evaluate(stmt.condition)): self.execute(stmt.then_branch) diff --git a/src/keyword.py b/src/keyword.py index 2b39d4e..f7e4b23 100644 --- a/src/keyword.py +++ b/src/keyword.py @@ -2,6 +2,7 @@ from src.token import TokenType KEYWORDS: dict[str, TokenType] = { "let": TokenType.LET, + "fun": TokenType.FUN, "and": TokenType.AND, "or": TokenType.OR, "if": TokenType.IF, diff --git a/src/parser/parser.py b/src/parser/parser.py index 1bc4e70..f94770e 100644 --- a/src/parser/parser.py +++ b/src/parser/parser.py @@ -2,7 +2,7 @@ from typing import Optional from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \ CallExpr -from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt +from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt from src.consts import MAX_FUNCTION_ARGS from src.parser.error import ParsingError from src.pebble import Pebble @@ -93,6 +93,8 @@ class Parser: def declaration(self) -> Optional[Stmt]: try: + if self.match(TokenType.FUN): + return self.function("function") if self.match(TokenType.LET): return self.var_declaration() return self.statement() @@ -100,6 +102,24 @@ class Parser: self.synchronize() return None + def function(self, kind: str) -> Stmt: + # TODO: allow anonymous/lambda functions + name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.") + self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.") + parameters: list[Token] = [] + if not self.check(TokenType.RIGHT_PAREN): + while True: + if len(parameters) >= MAX_FUNCTION_ARGS: + self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.") + parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name.")) + if not self.match(TokenType.COMMA): + break + self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.") + # TODO: allow single statement functions (without braces) + self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.") + body: list[Stmt] = self.block_stmt() + return FunctionStmt(name, parameters, body) + def var_declaration(self) -> Stmt: name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.") initializer: Optional[Expr] = None @@ -118,7 +138,7 @@ class Parser: if self.match(TokenType.WHILE): return self.while_stmt() if self.match(TokenType.LEFT_BRACE): - return self.block_stmt() + return BlockStmt(statements=self.block_stmt()) return self.expression_stmt() def for_stmt(self) -> Stmt: @@ -196,7 +216,7 @@ class Parser: body: Stmt = self.statement() return WhileStmt(condition, body) - def block_stmt(self) -> Stmt: + def block_stmt(self) -> list[Stmt]: statements: list[Stmt] = [] self.skip_newlines() while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): @@ -204,7 +224,7 @@ class Parser: statements.append(self.declaration()) self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.") - return BlockStmt(statements) + return statements def expression_stmt(self) -> Stmt: value: Expr = self.expression() diff --git a/src/token.py b/src/token.py index 99c04f5..8dba9e9 100644 --- a/src/token.py +++ b/src/token.py @@ -44,6 +44,7 @@ class TokenType(Enum): # Keywords LET = auto() + FUN = auto() AND = auto() OR = auto() IF = auto()