feat: add function definition
This commit is contained in:
8
examples/12_function_def.peb
Normal file
8
examples/12_function_def.peb
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
fun add(a, b) {
|
||||||
|
let c = a + b
|
||||||
|
print(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
add(1, 3)
|
||||||
|
add(4, 6)
|
||||||
|
add(-3, 9)
|
||||||
10
main.py
10
main.py
@@ -7,14 +7,8 @@ from src.token import Token
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
source: str = """() {} +- += / /= // sefs + {, )
|
path: str = "examples/12_function_def.peb"
|
||||||
}:: *
|
source: str = ""
|
||||||
"This is a string"
|
|
||||||
3.1415
|
|
||||||
123
|
|
||||||
"This is
|
|
||||||
another string" """
|
|
||||||
path: str = "examples/05_loop.peb"
|
|
||||||
with open(path, "r") as f:
|
with open(path, "r") as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
lexer: Lexer = Lexer()
|
lexer: Lexer = Lexer()
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ class Stmt(ABC):
|
|||||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_function_stmt(self, stmt: FunctionStmt) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_if_stmt(self, stmt: IfStmt) -> T:
|
def visit_if_stmt(self, stmt: IfStmt) -> T:
|
||||||
...
|
...
|
||||||
@@ -62,6 +66,16 @@ class ExpressionStmt(Stmt):
|
|||||||
return visitor.visit_expression_stmt(self)
|
return visitor.visit_expression_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FunctionStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
params: list[Token]
|
||||||
|
body: list[Stmt]
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_function_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class IfStmt(Stmt):
|
class IfStmt(Stmt):
|
||||||
condition: Expr
|
condition: Expr
|
||||||
|
|||||||
27
src/core/function.py
Normal file
27
src/core/function.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.ast.stmt import FunctionStmt
|
||||||
|
from src.core.callable import PebbleCallable
|
||||||
|
from src.interpreter.environment import Environment
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.interpreter.interpreter import Interpreter
|
||||||
|
|
||||||
|
|
||||||
|
class PebbleFunction(PebbleCallable):
|
||||||
|
def __init__(self, declaration: FunctionStmt):
|
||||||
|
self.declaration: FunctionStmt = declaration
|
||||||
|
|
||||||
|
def arity(self) -> int:
|
||||||
|
return len(self.declaration.params)
|
||||||
|
|
||||||
|
def call(self, interpreter: Interpreter, arguments: list[Any]) -> None:
|
||||||
|
env: Environment = Environment(interpreter.global_env)
|
||||||
|
for i, param in enumerate(self.declaration.params):
|
||||||
|
env.define(param.lexeme, arguments[i])
|
||||||
|
interpreter.execute_block(self.declaration.body, env)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<function {self.declaration.name.lexeme}>"
|
||||||
@@ -2,7 +2,7 @@ from typing import Any
|
|||||||
|
|
||||||
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
|
from src.ast.expr import Expr, VariableExpr, LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, AssignExpr, LogicalExpr, \
|
||||||
CallExpr
|
CallExpr
|
||||||
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt
|
from src.ast.stmt import Stmt, LetStmt, PrintStmt, IfStmt, ExpressionStmt, BlockStmt, WhileStmt, ForStmt, FunctionStmt
|
||||||
|
|
||||||
|
|
||||||
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
|
class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
|
||||||
@@ -72,6 +72,17 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
|
|||||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> str:
|
||||||
return self.indented(self.format(stmt.expression) + "\n")
|
return self.indented(self.format(stmt.expression) + "\n")
|
||||||
|
|
||||||
|
def visit_function_stmt(self, stmt: FunctionStmt) -> str:
|
||||||
|
res: str = self.indented(f"fun {stmt.name.lexeme}")
|
||||||
|
res += f"({', '.join(param.lexeme for param in stmt.params)}) "
|
||||||
|
res += "{\n"
|
||||||
|
self.level += 1
|
||||||
|
for sub_stmt in stmt.body:
|
||||||
|
res += self.format(sub_stmt)
|
||||||
|
self.level -= 1
|
||||||
|
res += self.indented("}\n")
|
||||||
|
return res
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: IfStmt) -> str:
|
def visit_if_stmt(self, stmt: IfStmt) -> str:
|
||||||
res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}")
|
res: str = self.indented(f"if {self.format(stmt.condition)} {self.format(stmt.then_branch)}")
|
||||||
res = res.rstrip("\n")
|
res = res.rstrip("\n")
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
|
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr, VariableExpr, AssignExpr, LogicalExpr, \
|
||||||
CallExpr
|
CallExpr
|
||||||
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt
|
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
|
||||||
from src.core.callable import PebbleCallable
|
from src.core.callable import PebbleCallable
|
||||||
|
from src.core.function import PebbleFunction
|
||||||
from src.interpreter.environment import Environment
|
from src.interpreter.environment import Environment
|
||||||
from src.interpreter.error import PebbleRuntimeError
|
from src.interpreter.error import PebbleRuntimeError
|
||||||
from src.interpreter.globals import GlobalEnvironment
|
from src.interpreter.globals import GlobalEnvironment
|
||||||
@@ -142,6 +143,10 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
|
|||||||
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
|
||||||
self.evaluate(stmt.expression)
|
self.evaluate(stmt.expression)
|
||||||
|
|
||||||
|
def visit_function_stmt(self, stmt: FunctionStmt) -> None:
|
||||||
|
function: PebbleFunction = PebbleFunction(stmt)
|
||||||
|
self.env.define(stmt.name.lexeme, function)
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: IfStmt) -> None:
|
def visit_if_stmt(self, stmt: IfStmt) -> None:
|
||||||
if self.is_truthy(self.evaluate(stmt.condition)):
|
if self.is_truthy(self.evaluate(stmt.condition)):
|
||||||
self.execute(stmt.then_branch)
|
self.execute(stmt.then_branch)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from src.token import TokenType
|
|||||||
|
|
||||||
KEYWORDS: dict[str, TokenType] = {
|
KEYWORDS: dict[str, TokenType] = {
|
||||||
"let": TokenType.LET,
|
"let": TokenType.LET,
|
||||||
|
"fun": TokenType.FUN,
|
||||||
"and": TokenType.AND,
|
"and": TokenType.AND,
|
||||||
"or": TokenType.OR,
|
"or": TokenType.OR,
|
||||||
"if": TokenType.IF,
|
"if": TokenType.IF,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
|
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
|
||||||
CallExpr
|
CallExpr
|
||||||
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt
|
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt
|
||||||
from src.consts import MAX_FUNCTION_ARGS
|
from src.consts import MAX_FUNCTION_ARGS
|
||||||
from src.parser.error import ParsingError
|
from src.parser.error import ParsingError
|
||||||
from src.pebble import Pebble
|
from src.pebble import Pebble
|
||||||
@@ -93,6 +93,8 @@ class Parser:
|
|||||||
|
|
||||||
def declaration(self) -> Optional[Stmt]:
|
def declaration(self) -> Optional[Stmt]:
|
||||||
try:
|
try:
|
||||||
|
if self.match(TokenType.FUN):
|
||||||
|
return self.function("function")
|
||||||
if self.match(TokenType.LET):
|
if self.match(TokenType.LET):
|
||||||
return self.var_declaration()
|
return self.var_declaration()
|
||||||
return self.statement()
|
return self.statement()
|
||||||
@@ -100,6 +102,24 @@ class Parser:
|
|||||||
self.synchronize()
|
self.synchronize()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def function(self, kind: str) -> Stmt:
|
||||||
|
# TODO: allow anonymous/lambda functions
|
||||||
|
name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.")
|
||||||
|
self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.")
|
||||||
|
parameters: list[Token] = []
|
||||||
|
if not self.check(TokenType.RIGHT_PAREN):
|
||||||
|
while True:
|
||||||
|
if len(parameters) >= MAX_FUNCTION_ARGS:
|
||||||
|
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.")
|
||||||
|
parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name."))
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.")
|
||||||
|
# TODO: allow single statement functions (without braces)
|
||||||
|
self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.")
|
||||||
|
body: list[Stmt] = self.block_stmt()
|
||||||
|
return FunctionStmt(name, parameters, body)
|
||||||
|
|
||||||
def var_declaration(self) -> Stmt:
|
def var_declaration(self) -> Stmt:
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
|
||||||
initializer: Optional[Expr] = None
|
initializer: Optional[Expr] = None
|
||||||
@@ -118,7 +138,7 @@ class Parser:
|
|||||||
if self.match(TokenType.WHILE):
|
if self.match(TokenType.WHILE):
|
||||||
return self.while_stmt()
|
return self.while_stmt()
|
||||||
if self.match(TokenType.LEFT_BRACE):
|
if self.match(TokenType.LEFT_BRACE):
|
||||||
return self.block_stmt()
|
return BlockStmt(statements=self.block_stmt())
|
||||||
return self.expression_stmt()
|
return self.expression_stmt()
|
||||||
|
|
||||||
def for_stmt(self) -> Stmt:
|
def for_stmt(self) -> Stmt:
|
||||||
@@ -196,7 +216,7 @@ class Parser:
|
|||||||
body: Stmt = self.statement()
|
body: Stmt = self.statement()
|
||||||
return WhileStmt(condition, body)
|
return WhileStmt(condition, body)
|
||||||
|
|
||||||
def block_stmt(self) -> Stmt:
|
def block_stmt(self) -> list[Stmt]:
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
self.skip_newlines()
|
self.skip_newlines()
|
||||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||||
@@ -204,7 +224,7 @@ class Parser:
|
|||||||
statements.append(self.declaration())
|
statements.append(self.declaration())
|
||||||
|
|
||||||
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
|
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
|
||||||
return BlockStmt(statements)
|
return statements
|
||||||
|
|
||||||
def expression_stmt(self) -> Stmt:
|
def expression_stmt(self) -> Stmt:
|
||||||
value: Expr = self.expression()
|
value: Expr = self.expression()
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
# Keywords
|
# Keywords
|
||||||
LET = auto()
|
LET = auto()
|
||||||
|
FUN = auto()
|
||||||
AND = auto()
|
AND = auto()
|
||||||
OR = auto()
|
OR = auto()
|
||||||
IF = auto()
|
IF = auto()
|
||||||
|
|||||||
Reference in New Issue
Block a user