434 lines
17 KiB
Python
434 lines
17 KiB
Python
from typing import Optional
|
|
|
|
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr, VariableExpr, AssignExpr, LogicalExpr, \
|
|
CallExpr, GetExpr, SetExpr, ThisExpr, SuperExpr, FStringExpr, FStringEmbedExpr, ListExpr, SubscriptGetExpr
|
|
from src.ast.stmt import Stmt, ExpressionStmt, LetStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, FunctionStmt, \
|
|
ReturnStmt, BreakStmt, ContinueStmt, ClassStmt
|
|
from src.consts import MAX_FUNCTION_ARGS
|
|
from src.core.format_spec.parser import FormatSpecParser
|
|
from src.core.format_spec.spec import FormatSpec
|
|
from src.parser.error import ParsingError
|
|
from src.pebble import Pebble
|
|
from src.token.token import Token, TokenType
|
|
|
|
|
|
class Parser:
|
|
IGNORE: set[TokenType] = {
|
|
TokenType.WHITESPACE, TokenType.COMMENT, TokenType.NEWLINE
|
|
}
|
|
|
|
STATEMENT_BOUNDARY: set[TokenType] = {
|
|
TokenType.FOR, TokenType.WHILE, TokenType.IF, TokenType.CLASS, TokenType.FUN, TokenType.LET, TokenType.RETURN
|
|
}
|
|
|
|
def __init__(self, tokens: list[Token]):
|
|
self.tokens: list[Token] = list(filter(lambda t: t.type not in self.IGNORE, tokens))
|
|
self.current: int = 0
|
|
self.length: int = len(self.tokens)
|
|
|
|
@staticmethod
|
|
def error(token: Token, msg: str):
|
|
Pebble.token_error(token, msg)
|
|
return ParsingError()
|
|
|
|
def parse(self) -> list[Stmt]:
|
|
statements: list[Stmt] = []
|
|
while not self.is_at_end():
|
|
statements.append(self.declaration())
|
|
return statements
|
|
|
|
def is_at_end(self) -> bool:
|
|
return self.peek().type == TokenType.EOF
|
|
|
|
def peek(self) -> Token:
|
|
return self.tokens[self.current]
|
|
|
|
def previous(self) -> Token:
|
|
return self.tokens[self.current - 1]
|
|
|
|
def check(self, token_type: TokenType) -> bool:
|
|
if self.is_at_end():
|
|
return False
|
|
return self.peek().type == token_type
|
|
|
|
def advance(self) -> Token:
|
|
if not self.is_at_end():
|
|
self.current += 1
|
|
return self.previous()
|
|
|
|
def match(self, *types: TokenType) -> bool:
|
|
for token_type in types:
|
|
if self.check(token_type):
|
|
self.advance()
|
|
return True
|
|
return False
|
|
|
|
def consume(self, token_type: TokenType, error_msg: str) -> Token:
|
|
if self.check(token_type):
|
|
return self.advance()
|
|
raise self.error(self.peek(), error_msg)
|
|
|
|
def expect_eol(self, error_msg: str):
|
|
if self.is_at_end():
|
|
return
|
|
if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF):
|
|
raise self.error(self.peek(), error_msg)
|
|
|
|
# Parsing
|
|
def synchronize(self):
|
|
self.advance()
|
|
while not self.is_at_end():
|
|
if self.previous().type == TokenType.NEWLINE:
|
|
return
|
|
if self.peek().type in self.STATEMENT_BOUNDARY:
|
|
return
|
|
self.advance()
|
|
|
|
def declaration(self) -> Optional[Stmt]:
|
|
try:
|
|
if self.match(TokenType.CLASS):
|
|
return self.class_declaration()
|
|
if self.match(TokenType.FUN):
|
|
return self.function("function")
|
|
if self.match(TokenType.LET):
|
|
return self.var_declaration()
|
|
return self.statement()
|
|
except ParsingError:
|
|
self.synchronize()
|
|
return None
|
|
|
|
def class_declaration(self) -> Stmt:
|
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected class name.")
|
|
superclass: Optional[VariableExpr] = None
|
|
if self.match(TokenType.LESS):
|
|
self.consume(TokenType.IDENTIFIER, "Expected superclass name.")
|
|
superclass = VariableExpr(self.previous())
|
|
self.consume(TokenType.LEFT_BRACE, "Expected '{' before class body.")
|
|
methods: list[FunctionStmt] = []
|
|
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
|
methods.append(self.function("method"))
|
|
|
|
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after class body.")
|
|
return ClassStmt(name, superclass, methods)
|
|
|
|
def function(self, kind: str) -> FunctionStmt:
|
|
# TODO: allow anonymous/lambda functions
|
|
name: Token = self.consume(TokenType.IDENTIFIER, f"Expected {kind} name.")
|
|
self.consume(TokenType.LEFT_PAREN, f"Expected '(' after {kind} name.")
|
|
parameters: list[Token] = []
|
|
if not self.check(TokenType.RIGHT_PAREN):
|
|
while True:
|
|
if len(parameters) >= MAX_FUNCTION_ARGS:
|
|
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} parameters.")
|
|
parameters.append(self.consume(TokenType.IDENTIFIER, "Expected parameter name."))
|
|
if not self.match(TokenType.COMMA):
|
|
break
|
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after parameters.")
|
|
# TODO: allow single statement functions (without braces)
|
|
self.consume(TokenType.LEFT_BRACE, f"Expected '{{' before {kind} body.")
|
|
body: list[Stmt] = self.block_stmt()
|
|
return FunctionStmt(name, parameters, body)
|
|
|
|
def var_declaration(self) -> Stmt:
|
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected variable name.")
|
|
initializer: Optional[Expr] = None
|
|
if self.match(TokenType.EQUAL):
|
|
initializer = self.expression()
|
|
return LetStmt(name, initializer)
|
|
|
|
def statement(self) -> Stmt:
|
|
if self.match(TokenType.FOR):
|
|
return self.for_stmt()
|
|
if self.match(TokenType.IF):
|
|
return self.if_stmt()
|
|
if self.match(TokenType.RETURN):
|
|
return self.return_stmt()
|
|
if self.match(TokenType.BREAK):
|
|
return self.break_stmt()
|
|
if self.match(TokenType.CONTINUE):
|
|
return self.continue_stmt()
|
|
if self.match(TokenType.WHILE):
|
|
return self.while_stmt()
|
|
if self.match(TokenType.LEFT_BRACE):
|
|
return BlockStmt(statements=self.block_stmt())
|
|
return self.expression_stmt()
|
|
|
|
def for_stmt(self) -> Stmt:
|
|
var: Token = self.consume(TokenType.IDENTIFIER, "Missing loop variable.")
|
|
end: Optional[Expr] = None
|
|
from_clause: Optional[Expr] = None
|
|
to_clause: Optional[Expr] = None
|
|
until_clause: Optional[Expr] = None
|
|
by_clause: Optional[Expr] = None
|
|
from_token: Optional[Token] = None
|
|
end_token: Optional[Token] = None
|
|
by_token: Optional[Token] = None
|
|
|
|
while self.match(TokenType.FROM, TokenType.TO, TokenType.UNTIL, TokenType.BY):
|
|
previous: Token = self.previous()
|
|
match previous.type:
|
|
case TokenType.FROM:
|
|
if from_clause is not None:
|
|
raise self.error(previous, "From clause already defined.")
|
|
from_clause = self.expression()
|
|
from_token = previous
|
|
case TokenType.TO:
|
|
if to_clause is not None:
|
|
raise self.error(previous, "To clause already defined.")
|
|
if until_clause is not None:
|
|
raise self.error(previous, "Until clause already defined.")
|
|
to_clause = self.expression()
|
|
end = to_clause
|
|
end_token = previous
|
|
case TokenType.UNTIL:
|
|
if until_clause is not None:
|
|
raise self.error(previous, "Until clause already defined.")
|
|
if to_clause is not None:
|
|
raise self.error(previous, "To clause already defined.")
|
|
until_clause = self.expression()
|
|
end = until_clause
|
|
end_token = previous
|
|
case TokenType.BY:
|
|
if by_clause is not None:
|
|
raise self.error(previous, "By clause already defined.")
|
|
by_clause = self.expression()
|
|
by_token = previous
|
|
|
|
body: Stmt = self.statement()
|
|
|
|
loop: Stmt = ForStmt(
|
|
variable=var,
|
|
start_token=from_token,
|
|
start=from_clause,
|
|
end_token=end_token,
|
|
end=end,
|
|
step_token=by_token,
|
|
step=by_clause,
|
|
body=body
|
|
)
|
|
return loop
|
|
|
|
def if_stmt(self) -> Stmt:
|
|
condition: Expr = self.expression()
|
|
then_branch: Stmt = self.statement()
|
|
else_branch: Optional[Stmt] = None
|
|
if self.match(TokenType.ELSE):
|
|
else_branch = self.statement()
|
|
return IfStmt(condition, then_branch, else_branch)
|
|
|
|
def return_stmt(self) -> Stmt:
|
|
keyword: Token = self.previous()
|
|
value: Optional[Expr] = None
|
|
if not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
|
value = self.expression()
|
|
if not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
|
self.error(keyword, "Return must be the last statement in a function.")
|
|
return ReturnStmt(keyword, value)
|
|
|
|
def break_stmt(self) -> Stmt:
|
|
keyword: Token = self.previous()
|
|
if not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
|
self.error(keyword, "Break must be the last statement in a block.")
|
|
return BreakStmt(keyword)
|
|
|
|
def continue_stmt(self) -> Stmt:
|
|
keyword: Token = self.previous()
|
|
if not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
|
self.error(keyword, "Continue must be the last statement in a block.")
|
|
return ContinueStmt(keyword)
|
|
|
|
def while_stmt(self) -> Stmt:
|
|
condition: Expr = self.expression()
|
|
body: Stmt = self.statement()
|
|
return WhileStmt(condition, body)
|
|
|
|
def block_stmt(self) -> list[Stmt]:
|
|
statements: list[Stmt] = []
|
|
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
|
statements.append(self.declaration())
|
|
|
|
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after block.")
|
|
return statements
|
|
|
|
def expression_stmt(self) -> Stmt:
|
|
value: Expr = self.expression()
|
|
return ExpressionStmt(value)
|
|
|
|
def expression(self) -> Expr:
|
|
return self.assignment()
|
|
|
|
def assignment(self) -> Expr:
|
|
expr: Expr = self.or_()
|
|
if self.match(TokenType.EQUAL, TokenType.PLUS_EQUAL, TokenType.MINUS_EQUAL, TokenType.STAR_EQUAL,
|
|
TokenType.SLASH_EQUAL):
|
|
operator: Token = self.previous()
|
|
value: Expr = self.assignment()
|
|
if operator.type != TokenType.EQUAL:
|
|
value = BinaryExpr(
|
|
expr,
|
|
operator,
|
|
value
|
|
)
|
|
if isinstance(expr, VariableExpr):
|
|
name: Token = expr.name
|
|
return AssignExpr(name, value)
|
|
|
|
elif isinstance(expr, GetExpr):
|
|
return SetExpr(expr.object, expr.name, value)
|
|
self.error(operator, "Invalid assignment target.")
|
|
return expr
|
|
|
|
def or_(self) -> Expr:
|
|
expr: Expr = self.and_()
|
|
while self.match(TokenType.OR):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.and_()
|
|
expr = LogicalExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def and_(self) -> Expr:
|
|
expr: Expr = self.equality()
|
|
while self.match(TokenType.AND):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.equality()
|
|
expr = LogicalExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def equality(self) -> Expr:
|
|
expr: Expr = self.comparison()
|
|
while self.match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.comparison()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def comparison(self) -> Expr:
|
|
expr: Expr = self.term()
|
|
while self.match(TokenType.LESS, TokenType.LESS_EQUAL, TokenType.GREATER, TokenType.GREATER_EQUAL):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.term()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def term(self) -> Expr:
|
|
expr: Expr = self.factor()
|
|
while self.match(TokenType.PLUS, TokenType.MINUS):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.factor()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def factor(self) -> Expr:
|
|
expr: Expr = self.unary()
|
|
while self.match(TokenType.STAR, TokenType.SLASH):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.unary()
|
|
expr = BinaryExpr(expr, operator, right)
|
|
return expr
|
|
|
|
def unary(self) -> Expr:
|
|
if self.match(TokenType.BANG, TokenType.MINUS):
|
|
operator: Token = self.previous()
|
|
right: Expr = self.unary()
|
|
return UnaryExpr(operator, right)
|
|
return self.call()
|
|
|
|
def call(self) -> Expr:
|
|
expr: Expr = self.subscript()
|
|
while True:
|
|
if self.match(TokenType.LEFT_PAREN):
|
|
expr = self.finish_call(expr)
|
|
elif self.match(TokenType.DOT):
|
|
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name after '.'.")
|
|
expr = GetExpr(expr, name)
|
|
else:
|
|
break
|
|
return expr
|
|
|
|
def finish_call(self, callee: Expr) -> Expr:
|
|
arguments: list[Expr] = []
|
|
if not self.check(TokenType.RIGHT_PAREN):
|
|
while True:
|
|
if len(arguments) >= MAX_FUNCTION_ARGS:
|
|
self.error(self.peek(), f"Cannot have more than {MAX_FUNCTION_ARGS} arguments.")
|
|
arguments.append(self.expression())
|
|
if not self.match(TokenType.COMMA):
|
|
break
|
|
|
|
paren: Token = self.consume(TokenType.RIGHT_PAREN, "Expected ')' after arguments.")
|
|
return CallExpr(callee, paren, arguments)
|
|
|
|
def subscript(self) -> Expr:
|
|
expr: Expr = self.primary()
|
|
while self.match(TokenType.LEFT_BRACKET):
|
|
idx: Expr = self.expression()
|
|
bracket: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed list index")
|
|
expr = SubscriptGetExpr(expr, bracket, idx)
|
|
return expr
|
|
|
|
def primary(self) -> Expr:
|
|
if self.match(TokenType.FALSE):
|
|
return LiteralExpr(False)
|
|
if self.match(TokenType.TRUE):
|
|
return LiteralExpr(True)
|
|
if self.match(TokenType.NULL):
|
|
return LiteralExpr(None)
|
|
|
|
if self.match(TokenType.FSTRING_START):
|
|
return self.fstring()
|
|
|
|
if self.match(TokenType.LEFT_BRACKET):
|
|
return self.list()
|
|
|
|
if self.match(TokenType.NUMBER, TokenType.STRING):
|
|
return LiteralExpr(self.previous().value)
|
|
|
|
if self.match(TokenType.SUPER):
|
|
keyword: Token = self.previous()
|
|
self.consume(TokenType.DOT, "Expected '.' after 'super'.")
|
|
method: Token = self.consume(TokenType.IDENTIFIER, "Expected superclass method name.")
|
|
return SuperExpr(keyword, method)
|
|
|
|
if self.match(TokenType.THIS):
|
|
return ThisExpr(self.previous())
|
|
|
|
if self.match(TokenType.IDENTIFIER):
|
|
return VariableExpr(self.previous())
|
|
|
|
if self.match(TokenType.LEFT_PAREN):
|
|
expr: Expr = self.expression()
|
|
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
|
return GroupingExpr(expr)
|
|
|
|
raise self.error(self.peek(), "Expected expression")
|
|
|
|
def fstring(self) -> Expr:
|
|
start: Token = self.previous()
|
|
parts: list[LiteralExpr | FStringEmbedExpr] = []
|
|
|
|
while not self.check(TokenType.FSTRING_END) and not self.is_at_end():
|
|
if self.match(TokenType.LEFT_BRACE):
|
|
brace: Token = self.previous()
|
|
expr: Expr = self.expression()
|
|
spec: Optional[FormatSpec] = None
|
|
if self.match(TokenType.FORMAT_SPEC):
|
|
spec = FormatSpecParser(self.previous().value).parse()
|
|
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after f-string embed")
|
|
parts.append(FStringEmbedExpr(brace, expr, spec, self.previous()))
|
|
else:
|
|
self.consume(TokenType.FSTRING_TEXT, "Unexpected token")
|
|
parts.append(LiteralExpr(self.previous().value))
|
|
|
|
self.consume(TokenType.FSTRING_END, "Unclosed f-string")
|
|
return FStringExpr(start, parts, self.previous())
|
|
|
|
def list(self) -> Expr:
|
|
items: list[Expr] = []
|
|
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
|
items.append(self.expression())
|
|
if not self.check(TokenType.RIGHT_BRACKET):
|
|
self.consume(TokenType.COMMA, "Expected ',' between list items")
|
|
|
|
bracket: Token = self.consume(TokenType.RIGHT_BRACKET, "Unclosed list")
|
|
return ListExpr(bracket, items)
|