feat: add function definition

This commit is contained in:
2026-02-06 13:48:59 +01:00
parent 6552eaff90
commit 671b91426b
9 changed files with 95 additions and 14 deletions

View File

@@ -0,0 +1,8 @@
fun add(a, b) {
let c = a + b
print(c)
}
add(1, 3)
add(4, 6)
add(-3, 9)

10
main.py
View File

@@ -7,14 +7,8 @@ from src.token import Token
def main(): def main():
source: str = """() {} +- += / /= // sefs + {, ) path: str = "examples/12_function_def.peb"
}:: * source: str = ""
"This is a string"
3.1415
123
"This is
another string" """
path: str = "examples/05_loop.peb"
with open(path, "r") as f: with open(path, "r") as f:
source = f.read() source = f.read()
lexer: Lexer = Lexer() lexer: Lexer = Lexer()

View File

@@ -25,6 +25,10 @@ class Stmt(ABC):
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T: def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
... ...
@abstractmethod
def visit_function_stmt(self, stmt: FunctionStmt) -> T:
...
@abstractmethod @abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: def visit_if_stmt(self, stmt: IfStmt) -> T:
... ...
@@ -62,6 +66,16 @@ class ExpressionStmt(Stmt):
return visitor.visit_expression_stmt(self) return visitor.visit_expression_stmt(self)
@dataclass(frozen=True)
class FunctionStmt(Stmt):
name: Token
params: list[Token]
body: list[Stmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_function_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class IfStmt(Stmt): class IfStmt(Stmt):
condition: Expr condition: Expr

27
src/core/function.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from src.ast.stmt import FunctionStmt
from src.core.callable import PebbleCallable
from src.interpreter.environment import Environment
if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter
class PebbleFunction(PebbleCallable):
def __init__(self, declaration: FunctionStmt):
self.declaration: FunctionStmt = declaration
def arity(self) -> int:
return len(self.declaration.params)
def call(self, interpreter: Interpreter, arguments: list[Any]) -> None:
env: Environment = Environment(interpreter.global_env)
for i, param in enumerate(self.declaration.params):
env.define(param.lexeme, arguments[i])
interpreter.execute_block(self.declaration.body, env)
def __str__(self):
return f"<function {self.declaration.name.lexeme}>"

View File

@@ -2,7 +2,7 @@ from typing import Any
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \ from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
CallExpr CallExpr
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]): class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
@@ -72,6 +72,17 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
def visit_expression_stmt(self, stmt: ExpressionStmt) -> str: def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
return self.indented(self.format(stmt.expression) + "\n") return self.indented(self.format(stmt.expression) + "\n")
def visit_function_stmt(self, stmt: FunctionStmt) -> str:
res: str = self.indented(f"fun {stmt.name.lexeme}")
res += f"({', '.join(param.lexeme for param in stmt.params)}) "
res += "{\n"
self.level += 1
for sub_stmt in stmt.body:
res += self.format(sub_stmt)
self.level -= 1
res += self.indented("}\n")
return res
def visit_if_stmt(self, stmt: IfStmt) -> str: def visit_if_stmt(self, stmt: IfStmt) -> str:
res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}") res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}")
res = res.rstrip("\n") res = res.rstrip("\n")

View File

@@ -3,8 +3,9 @@ from typing import Any, Optional
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \ from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
from src.core.callable import PebbleCallable from src.core.callable import PebbleCallable
from src.core.function import PebbleFunction
from src.interpreter.environment import Environment from src.interpreter.environment import Environment
from src.interpreter.error import PebbleRuntimeError from src.interpreter.error import PebbleRuntimeError
from src.interpreter.globals import GlobalEnvironment from src.interpreter.globals import GlobalEnvironment
@@ -142,6 +143,10 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
self.evaluate(stmt.expression) self.evaluate(stmt.expression)
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
function: PebbleFunction = PebbleFunction(stmt)
self.env.define(stmt.name.lexeme, function)
def visit_if_stmt(self, stmt: IfStmt) -> None: def visit_if_stmt(self, stmt: IfStmt) -> None:
if self.is_truthy(self.evaluate(stmt.condition)): if self.is_truthy(self.evaluate(stmt.condition)):
self.execute(stmt.then_branch) self.execute(stmt.then_branch)

View File

@@ -2,6 +2,7 @@ from src.token import TokenType
KEYWORDS: dict[str, TokenType] = { KEYWORDS: dict[str, TokenType] = {
"let": TokenType.LET, "let": TokenType.LET,
"fun": TokenType.FUN,
"and": TokenType.AND, "and": TokenType.AND,
"or": TokenType.OR, "or": TokenType.OR,
"if": TokenType.IF, "if": TokenType.IF,

View File

@@ -2,7 +2,7 @@ from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \ from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
from src.consts import MAX_FUNCTION_ARGS from src.consts import MAX_FUNCTION_ARGS
from src.parser.error import ParsingError from src.parser.error import ParsingError
from src.pebble import Pebble from src.pebble import Pebble
@@ -93,6 +93,8 @@ class Parser:
def declaration(self) -> Optional[Stmt]: def declaration(self) -> Optional[Stmt]:
try: try:
if self.match(TokenType.FUN):
return self.function("function")
if self.match(TokenType.LET): if self.match(TokenType.LET):
return self.var_declaration() return self.var_declaration()
return self.statement() return self.statement()
@@ -100,6 +102,24 @@ class Parser:
self.synchronize() self.synchronize()
return None return None
def function(self, kind: str) -> Stmt:
# TODO: allow anonymous/lambda functions
name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.")
self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.")
parameters: list[Token] = []
if not self.check(TokenType.RIGHT_PAREN):
while True:
if len(parameters) >= MAX_FUNCTION_ARGS:
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.")
parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name."))
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.")
# TODO: allow single statement functions (without braces)
self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.")
body: list[Stmt] = self.block_stmt()
return FunctionStmt(name, parameters, body)
def var_declaration(self) -> Stmt: def var_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.") name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
initializer: Optional[Expr] = None initializer: Optional[Expr] = None
@@ -118,7 +138,7 @@ class Parser:
if self.match(TokenType.WHILE): if self.match(TokenType.WHILE):
return self.while_stmt() return self.while_stmt()
if self.match(TokenType.LEFT_BRACE): if self.match(TokenType.LEFT_BRACE):
return self.block_stmt() return BlockStmt(statements=self.block_stmt())
return self.expression_stmt() return self.expression_stmt()
def for_stmt(self) -> Stmt: def for_stmt(self) -> Stmt:
@@ -196,7 +216,7 @@ class Parser:
body: Stmt = self.statement() body: Stmt = self.statement()
return WhileStmt(condition, body) return WhileStmt(condition, body)
def block_stmt(self) -> Stmt: def block_stmt(self) -> list[Stmt]:
statements: list[Stmt] = [] statements: list[Stmt] = []
self.skip_newlines() self.skip_newlines()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
@@ -204,7 +224,7 @@ class Parser:
statements.append(self.declaration()) statements.append(self.declaration())
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.") self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
return BlockStmt(statements) return statements
def expression_stmt(self) -> Stmt: def expression_stmt(self) -> Stmt:
value: Expr = self.expression() value: Expr = self.expression()

View File

@@ -44,6 +44,7 @@ class TokenType(Enum):
# Keywords # Keywords
LET = auto() LET = auto()
FUN = auto()
AND = auto() AND = auto()
OR = auto() OR = auto()
IF = auto() IF = auto()