369 lines
13 KiB
Python
369 lines
13 KiB
Python
from typing import Optional
|
|
|
|
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
|
|
CallExpr
|
|
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
|
|
ReturnStmt, BreakStmt
|
|
from src.consts import MAX_FUNCTION_ARGS
|
|
from src.parser.error import ParsingError
|
|
from src.pebble import Pebble
|
|
from src.token import Token, TokenType
|
|
|
|
|
|
class Parser:
|
|
IGNORE: set[TokenType] = {
|
|
TokenType.WHITESPACE, TokenType.COMMENT
|
|
}
|
|
|
|
STATEMENT_BOUNDARY: set[TokenType] = {
|
|
TokenType.FOR, TokenType.WHILE, TokenType.IF, TokenType.PRINT
|
|
}
|
|
|
|
def __init__(self, tokens: list[Token]):
|
|
self.tokens: list[Token] = list(filter(lambda t: t.type not in self.IGNORE, tokens))
|
|
self.current: int = 0
|
|
self.length: int = len(self.tokens)
|
|
|
|
@staticmethod
|
|
def error(token: Token, msg: str):
|
|
Pebble.token_error(token, msg)
|
|
return ParsingError()
|
|
|
|
def parse(self) -> list[Stmt]:
|
|
statements: list[Stmt] = []
|
|
self.skip_newlines()
|
|
while not self.is_at_end():
|
|
statements.append(self.declaration())
|
|
self.skip_newlines()
|
|
return statements
|
|
|
|
def skip_newlines(self):
|
|
while self.check(TokenType.NEWLINE):
|
|
self.advance()
|
|
|
|
def is_at_end(self) -> bool:
|
|
return self.peek().type == TokenType.EOF
|
|
|
|
def peek(self) -> Token:
|
|
return self.tokens[self.current]
|
|
|
|
def previous(self) -> Token:
|
|
return self.tokens[self.current - 1]
|
|
|
|
def check(self, token_type: TokenType) -> bool:
|
|
if self.is_at_end():
|
|
return False
|
|
return self.peek().type == token_type
|
|
|
|
def advance(self):
|
|
token: Token = self.peek()
|
|
self.current += 1
|
|
return token
|
|
|
|
def match(self, *types: TokenType) -> bool:
|
|
for token_type in types:
|
|
if self.check(token_type):
|
|
self.advance()
|
|
return True
|
|
return False
|
|
|
|
def consume(self, token_type: TokenType, error_msg: str) -> Token:
|
|
if self.check(token_type):
|
|
return self.advance()
|
|
raise self.error(self.peek(), error_msg)
|
|
|
|
def expect_eol(self, error_msg: str):
|
|
if self.is_at_end():
|
|
return
|
|
if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF):
|
|
raise self.error(self.peek(), error_msg)
|
|
|
|
# Parsing
|
|
def synchronize(self):
|
|
self.advance()
|
|
while not self.is_at_end():
|
|
if self.previous().type == TokenType.NEWLINE:
|
|
return
|
|
if self.peek().type in self.STATEMENT_BOUNDARY:
|
|
return
|
|
self.advance()
|
|
|
|
def declaration(self) -> Optional[Stmt]:
|
|
try:
|
|
if self.match(TokenType.FUN):
|
|
return self.function("function")
|
|
if self.match(TokenType.LET):
|
|
return self.var_declaration()
|
|
return self.statement()
|
|
except ParsingError:
|
|
self.synchronize()
|
|
return None
|
|
|
|
def function(self, kind: str) -> Stmt:
|
|
# TODO: allow anonymous/lambda functions
|
|
name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.")
|
|
self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.")
|
|
parameters: list[Token] = []
|
|
if not self.check(TokenType.RIGHT_PAREN):
|
|
while True:
|
|
if len(parameters) >= MAX_FUNCTION_ARGS:
|
|
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.")
|
|
parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name."))
|
|
if not self.match(TokenType.COMMA):
|
|
break
|
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.")
|
|
# TODO: allow single statement functions (without braces)
|
|
self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.")
|
|
body: list[Stmt] = self.block_stmt()
|
|
return FunctionStmt(name, parameters, body)
|
|
|
|
def var_declaration(self) -> Stmt:
|
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
|
|
initializer: Optional[Expr] = None
|
|
if self.match(TokenType.EQUAL):
|
|
initializer = self.expression()
|
|
self.expect_eol("Expected end of line after variable initialization")
|
|
return LetStmt(name, initializer)
|
|
|
|
def statement(self) -> Stmt:
|
|
if self.match(TokenType.FOR):
|
|
return self.for_stmt()
|
|
if self.match(TokenType.IF):
|
|
return self.if_stmt()
|
|
if self.match(TokenType.PRINT):
|
|
return self.print_stmt()
|
|
if self.match(TokenType.RETURN):
|
|
return self.return_stmt()
|
|
if self.match(TokenType.BREAK):
|
|
return self.break_stmt()
|
|
if self.match(TokenType.WHILE):
|
|
return self.while_stmt()
|
|
if self.match(TokenType.LEFT_BRACE):
|
|
return BlockStmt(statements=self.block_stmt())
|
|
return self.expression_stmt()
|
|
|
|
def for_stmt(self) -> Stmt:
|
|
var: Token = self.consume(TokenType.IDENTIFIER, "Missing loop variable.")
|
|
end: Optional[Expr] = None
|
|
from_clause: Optional[Expr] = None
|
|
to_clause: Optional[Expr] = None
|
|
until_clause: Optional[Expr] = None
|
|
by_clause: Optional[Expr] = None
|
|
from_token: Optional[Token] = None
|
|
end_token: Optional[Token] = None
|
|
by_token: Optional[Token] = None
|
|
|
|
while self.match(TokenType.FROM, TokenType.TO, TokenType.UNTIL, TokenType.BY):
|
|
previous: Token = self.previous()
|
|
match previous.type:
|
|
case TokenType.FROM:
|
|
if from_clause is not None:
|
|
raise self.error(previous, "From clause already defined.")
|
|
from_clause = self.expression()
|
|
from_token = previous
|
|
case TokenType.TO:
|
|
if to_clause is not None:
|
|
raise self.error(previous, "To clause already defined.")
|
|
if until_clause is not None:
|
|
raise self.error(previous, "Until clause already defined.")
|
|
to_clause = self.expression()
|
|
end = to_clause
|
|
end_token = previous
|
|
case TokenType.UNTIL:
|
|
if until_clause is not None:
|
|
raise self.error(previous, "Until clause already defined.")
|
|
if to_clause is not None:
|
|
raise self.error(previous, "To clause already defined.")
|
|
until_clause = self.expression()
|
|
end = until_clause
|
|
end_token = previous
|
|
case TokenType.BY:
|
|
if by_clause is not None:
|
|
raise self.error(previous, "By clause already defined.")
|
|
by_clause = self.expression()
|
|
by_token = previous
|
|
|
|
body: Stmt = self.statement()
|
|
|
|
loop: Stmt = ForStmt(
|
|
variable=var,
|
|
start_token=from_token,
|
|
start=from_clause,
|
|
end_token=end_token,
|
|
end=end,
|
|
step_token=by_token,
|
|
step=by_clause,
|
|
body=body
|
|
)
|
|
return loop
|
|
|
|
def if_stmt(self) -> Stmt:
|
|
condition: Expr = self.expression()
|
|
then_branch: Stmt = self.statement()
|
|
else_branch: Optional[Stmt] = None
|
|
if self.match(TokenType.ELSE):
|
|
else_branch = self.statement()
|
|
return IfStmt(condition, then_branch, else_branch)
|
|
|
|
def print_stmt(self) -> Stmt:
|
|
self.consume(TokenType.LEFT_PAREN, "Missing parentheses")
|
|
value: Expr = self.expression()
|
|
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
|
self.expect_eol("Expected end of line after statement")
|
|
return PrintStmt(value)
|
|
|
|
def return_stmt(self) -> Stmt:
|
|
keyword: Token = self.previous()
|
|
value: Optional[Expr] = None
|
|
if not self.check(TokenType.NEWLINE) and not self.is_at_end():
|
|
value = self.expression()
|
|
self.expect_eol("Expected end of line after return statement.")
|
|
return ReturnStmt(keyword, value)
|
|
|
|
def break_stmt(self) -> Stmt:
|
|
keyword: Token = self.previous()
|
|
self.expect_eol("Expected end of line after break statement.")
|
|
return BreakStmt(keyword)
|
|
|
|
def while_stmt(self) -> Stmt:
|
|
condition: Expr = self.expression()
|
|
body: Stmt = self.statement()
|
|
return WhileStmt(condition, body)
|
|
|
|
def block_stmt(self) -> list[Stmt]:
|
|
statements: list[Stmt] = []
|
|
self.skip_newlines()
|
|
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
|
self.skip_newlines()
|
|
statements.append(self.declaration())
|
|
|
|
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
|
|
return statements
|
|
|
|
def expression_stmt(self) -> Stmt:
|
|
value: Expr = self.expression()
|
|
self.expect_eol("Expected end of line after expression")
|
|
return ExpressionStmt(value)
|
|
|
|
def expression(self) -> Expr:
|
|
return self.assignment()
|
|
|
|
def assignment(self) -> Expr:
|
|
expr: Expr = self.or_()
|
|
if self.match(TokenType.EQUAL, TokenType.PLUS_EQUAL, TokenType.MINUS_EQUAL, TokenType.STAR_EQUAL, TokenType.SLASH_EQUAL):
|
|
operator: Token = self.previous()
|
|
value: Expr = self.assignment()
|
|
if isinstance(expr, VariableExpr):
|
|
name: Token = expr.name
|
|
if operator.type == TokenType.EQUAL:
|
|
return AssignExpr(name, value)
|
|
else:
|
|
return AssignExpr(
|
|
name,
|
|
BinaryExpr(
|
|
VariableExpr(name),
|
|
operator,
|
|
value
|
|
)
|
|
)
|
|
self.error(operator, "Invalid assignment target.")
|
|
return expr
|
|
|
|
def or_(self) -> Expr:
|
|
expr: Expr = self.and_()
|
|
while self.match(TokenType.OR):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.and_()
|
|
expr = LogicalExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def and_(self) -> Expr:
|
|
expr: Expr = self.equality()
|
|
while self.match(TokenType.AND):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.equality()
|
|
expr = LogicalExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def equality(self) -> Expr:
|
|
expr: Expr = self.comparison()
|
|
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.comparison()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def comparison(self) -> Expr:
|
|
expr: Expr = self.term()
|
|
while self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.term()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def term(self) -> Expr:
|
|
expr: Expr = self.factor()
|
|
while self.match(TokenType.PLUS, TokenType.MINUS):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.factor()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def factor(self) -> Expr:
|
|
expr: Expr = self.unary()
|
|
while self.match(TokenType.STAR, TokenType.SLASH):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.unary()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def unary(self) -> Expr:
|
|
if self.match(TokenType.BANG, TokenType.MINUS):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.unary()
|
|
return UnaryExpr(operator, right)
|
|
return self.call()
|
|
|
|
def call(self) -> Expr:
|
|
expr: Expr = self.primary()
|
|
while True:
|
|
if self.match(TokenType.LEFT_PAREN):
|
|
expr = self.finish_call(expr)
|
|
else:
|
|
break
|
|
return expr
|
|
|
|
def finish_call(self, callee: Expr) -> Expr:
|
|
arguments: list[Expr] = []
|
|
if not self.check(TokenType.RIGHT_PAREN):
|
|
while True:
|
|
if len(arguments) >= MAX_FUNCTION_ARGS:
|
|
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} arguments.")
|
|
arguments.append(self.expression())
|
|
if not self.match(TokenType.COMMA):
|
|
break
|
|
|
|
paren: Token = self.consume(TokenType.RIGHT_PAREN, "Expected ')' after arguments.")
|
|
return CallExpr(callee, paren, arguments)
|
|
|
|
def primary(self) -> Expr:
|
|
if self.match(TokenType.FALSE):
|
|
return LiteralExpr(False)
|
|
if self.match(TokenType.TRUE):
|
|
return LiteralExpr(True)
|
|
if self.match(TokenType.NULL):
|
|
return LiteralExpr(None)
|
|
|
|
if self.match(TokenType.NUMBER, TokenType.STRING):
|
|
return LiteralExpr(self.previous().value)
|
|
|
|
if self.match(TokenType.IDENTIFIER):
|
|
return VariableExpr(self.previous())
|
|
|
|
if self.match(TokenType.LEFT_PAREN):
|
|
expr: Expr = self.expression()
|
|
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
|
return GroupingExpr(expr)
|
|
|
|
raise self.error(self.peek(), "Expected expression")
|