Files
pebble/src/parser/parser.py

388 lines
15 KiB
Python

from typing import Optional
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
CallExpr, GetExpr, SetExpr, ThisExpr, SuperExpr
from src.ast.stmt import Stmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
ReturnStmt, BreakStmt, ContinueStmt, ClassStmt
from src.consts import MAX_FUNCTION_ARGS
from src.parser.error import ParsingError
from src.pebble import Pebble
from src.token import Token, TokenType
class Parser:
IGNORE: set[TokenType] = {
TokenType.WHITESPACE, TokenType.COMMENT, TokenType.NEWLINE
}
STATEMENT_BOUNDARY: set[TokenType] = {
TokenType.FOR, TokenType.WHILE, TokenType.IF
}
def __init__(self, tokens: list[Token]):
self.tokens: list[Token] = list(filter(lambda t: t.type not in self.IGNORE, tokens))
self.current: int = 0
self.length: int = len(self.tokens)
@staticmethod
def error(token: Token, msg: str):
Pebble.token_error(token, msg)
return ParsingError()
def parse(self) -> list[Stmt]:
statements: list[Stmt] = []
while not self.is_at_end():
statements.append(self.declaration())
return statements
def is_at_end(self) -> bool:
return self.peek().type == TokenType.EOF
def peek(self) -> Token:
return self.tokens[self.current]
def previous(self) -> Token:
return self.tokens[self.current - 1]
def check(self, token_type: TokenType) -> bool:
if self.is_at_end():
return False
return self.peek().type == token_type
def advance(self):
token: Token = self.peek()
self.current += 1
return token
def match(self, *types: TokenType) -> bool:
for token_type in types:
if self.check(token_type):
self.advance()
return True
return False
def consume(self, token_type: TokenType, error_msg: str) -> Token:
if self.check(token_type):
return self.advance()
raise self.error(self.peek(), error_msg)
def expect_eol(self, error_msg: str):
if self.is_at_end():
return
if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF):
raise self.error(self.peek(), error_msg)
# Parsing
def synchronize(self):
self.advance()
while not self.is_at_end():
if self.previous().type == TokenType.NEWLINE:
return
if self.peek().type in self.STATEMENT_BOUNDARY:
return
self.advance()
def declaration(self) -> Optional[Stmt]:
try:
if self.match(TokenType.CLASS):
return self.class_declaration()
if self.match(TokenType.FUN):
return self.function("function")
if self.match(TokenType.LET):
return self.var_declaration()
return self.statement()
except ParsingError:
self.synchronize()
return None
def class_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected class name.")
superclass: Optional[VariableExpr] = None
if self.match(TokenType.LESS):
self.consume(TokenType.IDENTIFIER, "Expected superclass name.")
superclass = VariableExpr(self.previous())
self.consume(TokenType.LEFT_BRACE, "Expected '{' before class body.")
methods: list[FunctionStmt] = []
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
methods.append(self.function("method"))
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after class body.")
return ClassStmt(name, superclass, methods)
def function(self, kind: str) -> FunctionStmt:
# TODO: allow anonymous/lambda functions
name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.")
self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.")
parameters: list[Token] = []
if not self.check(TokenType.RIGHT_PAREN):
while True:
if len(parameters) >= MAX_FUNCTION_ARGS:
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.")
parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name."))
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.")
# TODO: allow single statement functions (without braces)
self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.")
body: list[Stmt] = self.block_stmt()
return FunctionStmt(name, parameters, body)
def var_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
initializer: Optional[Expr] = None
if self.match(TokenType.EQUAL):
initializer = self.expression()
return LetStmt(name, initializer)
def statement(self) -> Stmt:
if self.match(TokenType.FOR):
return self.for_stmt()
if self.match(TokenType.IF):
return self.if_stmt()
if self.match(TokenType.RETURN):
return self.return_stmt()
if self.match(TokenType.BREAK):
return self.break_stmt()
if self.match(TokenType.CONTINUE):
return self.continue_stmt()
if self.match(TokenType.WHILE):
return self.while_stmt()
if self.match(TokenType.LEFT_BRACE):
return BlockStmt(statements=self.block_stmt())
return self.expression_stmt()
def for_stmt(self) -> Stmt:
var: Token = self.consume(TokenType.IDENTIFIER, "Missing loop variable.")
end: Optional[Expr] = None
from_clause: Optional[Expr] = None
to_clause: Optional[Expr] = None
until_clause: Optional[Expr] = None
by_clause: Optional[Expr] = None
from_token: Optional[Token] = None
end_token: Optional[Token] = None
by_token: Optional[Token] = None
while self.match(TokenType.FROM, TokenType.TO, TokenType.UNTIL, TokenType.BY):
previous: Token = self.previous()
match previous.type:
case TokenType.FROM:
if from_clause is not None:
raise self.error(previous, "From clause already defined.")
from_clause = self.expression()
from_token = previous
case TokenType.TO:
if to_clause is not None:
raise self.error(previous, "To clause already defined.")
if until_clause is not None:
raise self.error(previous, "Until clause already defined.")
to_clause = self.expression()
end = to_clause
end_token = previous
case TokenType.UNTIL:
if until_clause is not None:
raise self.error(previous, "Until clause already defined.")
if to_clause is not None:
raise self.error(previous, "To clause already defined.")
until_clause = self.expression()
end = until_clause
end_token = previous
case TokenType.BY:
if by_clause is not None:
raise self.error(previous, "By clause already defined.")
by_clause = self.expression()
by_token = previous
body: Stmt = self.statement()
loop: Stmt = ForStmt(
variable=var,
start_token=from_token,
start=from_clause,
end_token=end_token,
end=end,
step_token=by_token,
step=by_clause,
body=body
)
return loop
def if_stmt(self) -> Stmt:
condition: Expr = self.expression()
then_branch: Stmt = self.statement()
else_branch: Optional[Stmt] = None
if self.match(TokenType.ELSE):
else_branch = self.statement()
return IfStmt(condition, then_branch, else_branch)
def return_stmt(self) -> Stmt:
keyword: Token = self.previous()
value: Optional[Expr] = None
if not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
value = self.expression()
if not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
self.error(keyword, "Return must be the last statement in a function.")
return ReturnStmt(keyword, value)
def break_stmt(self) -> Stmt:
keyword: Token = self.previous()
if not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
self.error(keyword, "Break must be the last statement in a block.")
return BreakStmt(keyword)
def continue_stmt(self) -> Stmt:
keyword: Token = self.previous()
if not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
self.error(keyword, "Continue must be the last statement in a block.")
return ContinueStmt(keyword)
def while_stmt(self) -> Stmt:
condition: Expr = self.expression()
body: Stmt = self.statement()
return WhileStmt(condition, body)
def block_stmt(self) -> list[Stmt]:
statements: list[Stmt] = []
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
statements.append(self.declaration())
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
return statements
def expression_stmt(self) -> Stmt:
value: Expr = self.expression()
return ExpressionStmt(value)
def expression(self) -> Expr:
return self.assignment()
def assignment(self) -> Expr:
expr: Expr = self.or_()
if self.match(TokenType.EQUAL, TokenType.PLUS_EQUAL, TokenType.MINUS_EQUAL, TokenType.STAR_EQUAL,
TokenType.SLASH_EQUAL):
operator: Token = self.previous()
value: Expr = self.assignment()
if operator.type != TokenType.EQUAL:
value = BinaryExpr(
expr,
operator,
value
)
if isinstance(expr, VariableExpr):
name: Token = expr.name
return AssignExpr(name, value)
elif isinstance(expr, GetExpr):
return SetExpr(expr.object, expr.name, value)
self.error(operator, "Invalid assignment target.")
return expr
def or_(self) -> Expr:
expr: Expr = self.and_()
while self.match(TokenType.OR):
operator: Token = self.previous()
right: Expr = self.and_()
expr = LogicalExpr(expr, operator, right)
return expr
def and_(self) -> Expr:
expr: Expr = self.equality()
while self.match(TokenType.AND):
operator: Token = self.previous()
right: Expr = self.equality()
expr = LogicalExpr(expr, operator, right)
return expr
def equality(self) -> Expr:
expr: Expr = self.comparison()
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
operator: Token = self.previous()
right: Expr = self.comparison()
expr = BinaryExpr(expr, operator, right)
return expr
def comparison(self) -> Expr:
expr: Expr = self.term()
while self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL):
operator: Token = self.previous()
right: Expr = self.term()
expr = BinaryExpr(expr, operator, right)
return expr
def term(self) -> Expr:
expr: Expr = self.factor()
while self.match(TokenType.PLUS, TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.factor()
expr = BinaryExpr(expr, operator, right)
return expr
def factor(self) -> Expr:
expr: Expr = self.unary()
while self.match(TokenType.STAR, TokenType.SLASH):
operator: Token = self.previous()
right: Expr = self.unary()
expr = BinaryExpr(expr, operator, right)
return expr
def unary(self) -> Expr:
if self.match(TokenType.BANG, TokenType.MINUS):
operator: Token = self.previous()
right: Expr = self.unary()
return UnaryExpr(operator, right)
return self.call()
def call(self) -> Expr:
expr: Expr = self.primary()
while True:
if self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
elif self.match(TokenType.DOT):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name after '.'.")
expr = GetExpr(expr, name)
else:
break
return expr
def finish_call(self, callee: Expr) -> Expr:
arguments: list[Expr] = []
if not self.check(TokenType.RIGHT_PAREN):
while True:
if len(arguments) >= MAX_FUNCTION_ARGS:
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} arguments.")
arguments.append(self.expression())
if not self.match(TokenType.COMMA):
break
paren: Token = self.consume(TokenType.RIGHT_PAREN, "Expected ')' after arguments.")
return CallExpr(callee, paren, arguments)
def primary(self) -> Expr:
if self.match(TokenType.FALSE):
return LiteralExpr(False)
if self.match(TokenType.TRUE):
return LiteralExpr(True)
if self.match(TokenType.NULL):
return LiteralExpr(None)
if self.match(TokenType.NUMBER, TokenType.STRING):
return LiteralExpr(self.previous().value)
if self.match(TokenType.SUPER):
keyword: Token = self.previous()
self.consume(TokenType.DOT, "Expected '.' after 'super'.")
method: Token = self.consume(TokenType.IDENTIFIER, "Expected superclass method name.")
return SuperExpr(keyword, method)
if self.match(TokenType.THIS):
return ThisExpr(self.previous())
if self.match(TokenType.IDENTIFIER):
return VariableExpr(self.previous())
if self.match(TokenType.LEFT_PAREN):
expr: Expr = self.expression()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return GroupingExpr(expr)
raise self.error(self.peek(), "Expected expression")