feat: add function definition

This commit is contained in:
2026-02-06 13:48:59 +01:00
parent 6552eaff90
commit 671b91426b
9 changed files with 95 additions and 14 deletions

View File

@@ -0,0 +1,8 @@
fun add(a, b) {
let c = a + b
print(c)
}
add(1, 3)
add(4, 6)
add(-3, 9)

10
main.py
View File

@@ -7,14 +7,8 @@ from src.token import Token
def main():
source: str = """() {} +- += / /= // sefs + {, )
}:: *
"This is a string"
3.1415
123
"This is
another string" """
path: str = "examples/05_loop.peb"
path: str = "examples/12_function_def.peb"
source: str = ""
with open(path, "r") as f:
source = f.read()
lexer: Lexer = Lexer()

View File

@@ -25,6 +25,10 @@ class Stmt(ABC):
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
...
@abstractmethod
def visit_function_stmt(self, stmt: FunctionStmt) -> T:
...
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T:
...
@@ -62,6 +66,16 @@ class ExpressionStmt(Stmt):
return visitor.visit_expression_stmt(self)
@dataclass(frozen=True)
class FunctionStmt(Stmt):
name: Token
params: list[Token]
body: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function_stmt(self)
@dataclass(frozen=True)
class IfStmt(Stmt):
condition: Expr

27
src/core/function.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from src.ast.stmt import FunctionStmt
from src.core.callable import PebbleCallable
from src.interpreter.environment import Environment
if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter
class PebbleFunction(PebbleCallable):
def __init__(self, declaration: FunctionStmt):
self.declaration: FunctionStmt = declaration
def arity(self) -> int:
return len(self.declaration.params)
def call(self, interpreter: Interpreter, arguments: list[Any]) -> None:
env: Environment = Environment(interpreter.global_env)
for i, param in enumerate(self.declaration.params):
env.define(param.lexeme, arguments[i])
interpreter.execute_block(self.declaration.body, env)
def __str__(self):
return f"<function {self.declaration.name.lexeme}>"

View File

@@ -2,7 +2,7 @@ from typing import Any
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
@@ -72,6 +72,17 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
return self.indented(self.format(stmt.expression) + "\n")
def visit_function_stmt(self, stmt: FunctionStmt) -> str:
res: str = self.indented(f"fun {stmt.name.lexeme}")
res += f"({', '.join(param.lexeme for param in stmt.params)}) "
res += "{\n"
self.level += 1
for sub_stmt in stmt.body:
res += self.format(sub_stmt)
self.level -= 1
res += self.indented("}\n")
return res
def visit_if_stmt(self, stmt: IfStmt) -> str:
res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}")
res = res.rstrip("\n")

View File

@@ -3,8 +3,9 @@ from typing import Any, Optional
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
from src.core.callable import PebbleCallable
from src.core.function import PebbleFunction
from src.interpreter.environment import Environment
from src.interpreter.error import PebbleRuntimeError
from src.interpreter.globals import GlobalEnvironment
@@ -142,6 +143,10 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.evaluate(stmt.expression)
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
function: PebbleFunction = PebbleFunction(stmt)
self.env.define(stmt.name.lexeme, function)
def visit_if_stmt(self, stmt: IfStmt) -> None:
if self.is_truthy(self.evaluate(stmt.condition)):
self.execute(stmt.then_branch)

View File

@@ -2,6 +2,7 @@ from src.token import TokenType
KEYWORDS: dict[str, TokenType] = {
"let": TokenType.LET,
"fun": TokenType.FUN,
"and": TokenType.AND,
"or": TokenType.OR,
"if": TokenType.IF,

View File

@@ -2,7 +2,7 @@ from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
from src.consts import MAX_FUNCTION_ARGS
from src.parser.error import ParsingError
from src.pebble import Pebble
@@ -93,6 +93,8 @@ class Parser:
def declaration(self) -> Optional[Stmt]:
try:
if self.match(TokenType.FUN):
return self.function("function")
if self.match(TokenType.LET):
return self.var_declaration()
return self.statement()
@@ -100,6 +102,24 @@ class Parser:
self.synchronize()
return None
def function(self, kind: str) -> Stmt:
# TODO: allow anonymous/lambda functions
name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.")
self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.")
parameters: list[Token] = []
if not self.check(TokenType.RIGHT_PAREN):
while True:
if len(parameters) >= MAX_FUNCTION_ARGS:
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.")
parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name."))
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.")
# TODO: allow single statement functions (without braces)
self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.")
body: list[Stmt] = self.block_stmt()
return FunctionStmt(name, parameters, body)
def var_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
initializer: Optional[Expr] = None
@@ -118,7 +138,7 @@ class Parser:
if self.match(TokenType.WHILE):
return self.while_stmt()
if self.match(TokenType.LEFT_BRACE):
return self.block_stmt()
return BlockStmt(statements=self.block_stmt())
return self.expression_stmt()
def for_stmt(self) -> Stmt:
@@ -196,7 +216,7 @@ class Parser:
body: Stmt = self.statement()
return WhileStmt(condition, body)
def block_stmt(self) -> Stmt:
def block_stmt(self) -> list[Stmt]:
statements: list[Stmt] = []
self.skip_newlines()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
@@ -204,7 +224,7 @@ class Parser:
statements.append(self.declaration())
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
return BlockStmt(statements)
return statements
def expression_stmt(self) -> Stmt:
value: Expr = self.expression()

View File

@@ -44,6 +44,7 @@ class TokenType(Enum):
# Keywords
LET = auto()
FUN = auto()
AND = auto()
OR = auto()
IF = auto()