feat: add function calls

This commit is contained in:
2026-02-06 13:02:21 +01:00
parent bf750748e3
commit 9d5fbc8c45
6 changed files with 74 additions and 3 deletions

View File

@@ -29,6 +29,10 @@ class Expr(ABC):
def visit_unary_expr(self, expr: UnaryExpr) -> T: def visit_unary_expr(self, expr: UnaryExpr) -> T:
... ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T:
...
@abstractmethod @abstractmethod
def visit_grouping_expr(self, expr: GroupingExpr) -> T: def visit_grouping_expr(self, expr: GroupingExpr) -> T:
... ...
@@ -74,6 +78,16 @@ class UnaryExpr(Expr):
return visitor.visit_unary_expr(self) return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
paren: Token
arguments: list[Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class GroupingExpr(Expr): class GroupingExpr(Expr):
expression: Expr expression: Expr

1
src/consts.py Normal file
View File

@@ -0,0 +1 @@
MAX_FUNCTION_ARGS = 255

0
src/core/__init__.py Normal file
View File

17
src/core/callable.py Normal file
View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from src.interpreter.interpreter import Interpreter
class PebbleCallable(ABC):
@abstractmethod
def arity(self) -> int:
...
@abstractmethod
def call(self, interpreter: Interpreter, arguments: list[Any]) -> Any:
...

View File

@@ -1,7 +1,9 @@
from typing import Any, Optional from typing import Any, Optional
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt
from src.core.callable import PebbleCallable
from src.interpreter.environment import Environment from src.interpreter.environment import Environment
from src.interpreter.error import PebbleRuntimeError from src.interpreter.error import PebbleRuntimeError
from src.pebble import Pebble from src.pebble import Pebble
@@ -108,6 +110,19 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
# Unreachable # Unreachable
return None return None
def visit_call_expr(self, expr: CallExpr) -> Any:
callee: Any = self.evaluate(expr.callee)
arguments: list[Any] = [
self.evaluate(arg)
for arg in expr.arguments
]
if not isinstance(callee, PebbleCallable):
raise PebbleRuntimeError(expr.paren, "Can only call functions and classes.")
function: PebbleCallable = callee
if len(arguments) != function.arity():
raise PebbleRuntimeError(expr.paren, f"Expected {function.arity()} arguments but got {len(arguments)}.")
return function.call(self, arguments)
def visit_grouping_expr(self, expr: GroupingExpr) -> Any: def visit_grouping_expr(self, expr: GroupingExpr) -> Any:
return self.evaluate(expr.expression) return self.evaluate(expr.expression)

View File

@@ -1,7 +1,9 @@
from typing import Optional from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt
from src.consts import MAX_FUNCTION_ARGS
from src.parser.error import ParsingError from src.parser.error import ParsingError
from src.pebble import Pebble from src.pebble import Pebble
from src.token import Token, TokenType from src.token import Token, TokenType
@@ -286,7 +288,29 @@ class Parser:
operator: Token = self.previous() operator: Token = self.previous()
right: Expr = self.unary() right: Expr = self.unary()
return UnaryExpr(operator, right) return UnaryExpr(operator, right)
return self.primary() return self.call()
def call(self) -> Expr:
expr: Expr = self.primary()
while True:
if self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
else:
break
return expr
def finish_call(self, callee: Expr) -> Expr:
arguments: list[Expr] = []
if not self.check(TokenType.RIGHT_PAREN):
while True:
if len(arguments) >= MAX_FUNCTION_ARGS:
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} arguments.")
arguments.append(self.expression())
if not self.match(TokenType.COMMA):
break
paren: Token = self.consume(TokenType.RIGHT_PAREN, "Expected ')' after arguments.")
return CallExpr(callee, paren, arguments)
def primary(self) -> Expr: def primary(self) -> Expr:
if self.match(TokenType.FALSE): if self.match(TokenType.FALSE):