feat: add statements
This commit is contained in:
@@ -1 +1,2 @@
|
|||||||
3 - (4 / 9 + 1) == 12
|
print(3 - (4 / 9 + 1) == 12)
|
||||||
|
print(40 + (8 - -1) / 3 - 0.5 * 2)
|
||||||
29
main.py
29
main.py
@@ -1,11 +1,8 @@
|
|||||||
from typing import Any
|
from src.ast.stmt import Stmt
|
||||||
|
|
||||||
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
|
|
||||||
from src.ast.printer import AstPrinter
|
|
||||||
from src.interpreter.interpreter import Interpreter
|
from src.interpreter.interpreter import Interpreter
|
||||||
from src.lexer import Lexer
|
from src.lexer import Lexer
|
||||||
from src.parser.parser import Parser
|
from src.parser.parser import Parser
|
||||||
from src.token import Token, TokenType
|
from src.token import Token
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -21,29 +18,13 @@ def main():
|
|||||||
source = f.read()
|
source = f.read()
|
||||||
lexer: Lexer = Lexer()
|
lexer: Lexer = Lexer()
|
||||||
tokens: list[Token] = lexer.process(source, path)
|
tokens: list[Token] = lexer.process(source, path)
|
||||||
print(tokens)
|
print(list(filter(lambda t: t.type not in Parser.IGNORE, tokens)))
|
||||||
|
|
||||||
printer: AstPrinter = AstPrinter()
|
|
||||||
ast: Expr = BinaryExpr(
|
|
||||||
UnaryExpr(
|
|
||||||
Token(TokenType.MINUS, "-", None, None),
|
|
||||||
LiteralExpr(123)
|
|
||||||
),
|
|
||||||
Token(TokenType.STAR, "*", None, None),
|
|
||||||
GroupingExpr(LiteralExpr(45.67))
|
|
||||||
)
|
|
||||||
print(printer.print(ast))
|
|
||||||
|
|
||||||
parser: Parser = Parser()
|
parser: Parser = Parser()
|
||||||
ast = parser.parse(tokens)
|
program: list[Stmt] = parser.parse(tokens)
|
||||||
|
|
||||||
if ast is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
print(printer.print(ast))
|
|
||||||
interpreter: Interpreter = Interpreter()
|
interpreter: Interpreter = Interpreter()
|
||||||
result: Any = interpreter.interpret(ast)
|
interpreter.interpret(program)
|
||||||
print(f"Result: {result}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
41
src/ast/stmt.py
Normal file
41
src/ast/stmt.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TypeVar, Generic
|
||||||
|
|
||||||
|
from src.ast.expr import Expr
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Stmt(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def accept(self, visitor: Visitor[T]) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
class Visitor(ABC, Generic[T]):
|
||||||
|
@abstractmethod
|
||||||
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_print_stmt(self, stmt: PrintStmt) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExpressionStmt(Stmt):
|
||||||
|
expression: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_expression_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PrintStmt(Stmt):
|
||||||
|
expression: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_print_stmt(self)
|
||||||
@@ -1,21 +1,26 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
|
from src.ast.expr import LiteralExpr, GroupingExpr, UnaryExpr, BinaryExpr, Expr
|
||||||
|
from src.ast.stmt import Stmt, PrintStmt, T, ExpressionStmt
|
||||||
from src.interpreter.error import PebbleRuntimeError
|
from src.interpreter.error import PebbleRuntimeError
|
||||||
from src.pebble import Pebble
|
from src.pebble import Pebble
|
||||||
from src.token import TokenType, Token
|
from src.token import TokenType, Token
|
||||||
|
|
||||||
|
|
||||||
class Interpreter(Expr.Visitor[Any]):
|
class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
|
||||||
def interpret(self, expr: Expr) -> Any:
|
def interpret(self, statements: list[Stmt]) -> None:
|
||||||
try:
|
try:
|
||||||
return self.evaluate(expr)
|
for stmt in statements:
|
||||||
|
self.execute(stmt)
|
||||||
except PebbleRuntimeError as e:
|
except PebbleRuntimeError as e:
|
||||||
Pebble.runtime_error(e)
|
Pebble.runtime_error(e)
|
||||||
|
|
||||||
def evaluate(self, expr: Expr) -> Any:
|
def evaluate(self, expr: Expr) -> Any:
|
||||||
return expr.accept(self)
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def execute(self, stmt: Stmt) -> None:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: BinaryExpr) -> Any:
|
def visit_binary_expr(self, expr: BinaryExpr) -> Any:
|
||||||
left: Any = self.evaluate(expr.left)
|
left: Any = self.evaluate(expr.left)
|
||||||
right: Any = self.evaluate(expr.right)
|
right: Any = self.evaluate(expr.right)
|
||||||
@@ -75,6 +80,13 @@ class Interpreter(Expr.Visitor[Any]):
|
|||||||
def visit_literal_expr(self, expr: LiteralExpr) -> Any:
|
def visit_literal_expr(self, expr: LiteralExpr) -> Any:
|
||||||
return expr.value
|
return expr.value
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:
|
||||||
|
self.evaluate(stmt.expression)
|
||||||
|
|
||||||
|
def visit_print_stmt(self, stmt: PrintStmt) -> None:
|
||||||
|
value: Any = self.evaluate(stmt.expression)
|
||||||
|
print(value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_truthy(value: Any) -> bool:
|
def is_truthy(value: Any) -> bool:
|
||||||
if value is None or value is False:
|
if value is None or value is False:
|
||||||
|
|||||||
@@ -14,4 +14,5 @@ KEYWORDS: dict[str, TokenType] = {
|
|||||||
"false": TokenType.FALSE,
|
"false": TokenType.FALSE,
|
||||||
"true": TokenType.TRUE,
|
"true": TokenType.TRUE,
|
||||||
"null": TokenType.NULL,
|
"null": TokenType.NULL,
|
||||||
|
"print": TokenType.PRINT,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
|
from src.ast.expr import Expr, BinaryExpr, UnaryExpr, LiteralExpr, GroupingExpr
|
||||||
|
from src.ast.stmt import Stmt, PrintStmt, ExpressionStmt
|
||||||
from src.parser.error import ParsingError
|
from src.parser.error import ParsingError
|
||||||
from src.pebble import Pebble
|
from src.pebble import Pebble
|
||||||
from src.token import Token, TokenType
|
from src.token import Token, TokenType
|
||||||
@@ -25,15 +24,15 @@ class Parser:
|
|||||||
Pebble.token_error(token, msg)
|
Pebble.token_error(token, msg)
|
||||||
return ParsingError()
|
return ParsingError()
|
||||||
|
|
||||||
def parse(self, tokens: list[Token]) -> Optional[Expr]:
|
def parse(self, tokens: list[Token]) -> list[Stmt]:
|
||||||
self.tokens = list(filter(lambda t: t.type not in self.IGNORE, tokens))
|
self.tokens = list(filter(lambda t: t.type not in self.IGNORE, tokens))
|
||||||
self.current = 0
|
self.current = 0
|
||||||
self.length = len(self.tokens)
|
self.length = len(self.tokens)
|
||||||
|
|
||||||
try:
|
statements: list[Stmt] = []
|
||||||
return self.expression()
|
while not self.is_at_end():
|
||||||
except ParsingError:
|
statements.append(self.statement())
|
||||||
return None
|
return statements
|
||||||
|
|
||||||
def is_at_end(self) -> bool:
|
def is_at_end(self) -> bool:
|
||||||
return self.current >= self.length
|
return self.current >= self.length
|
||||||
@@ -65,6 +64,12 @@ class Parser:
|
|||||||
if not self.match(token_type):
|
if not self.match(token_type):
|
||||||
raise self.error(self.peek(), error_msg)
|
raise self.error(self.peek(), error_msg)
|
||||||
|
|
||||||
|
def expect_eol(self, error_msg: str):
|
||||||
|
if self.is_at_end():
|
||||||
|
return
|
||||||
|
if not self.match(TokenType.NEWLINE) and not self.match(TokenType.EOF):
|
||||||
|
raise self.error(self.peek(), error_msg)
|
||||||
|
|
||||||
# Parsing
|
# Parsing
|
||||||
def synchronize(self):
|
def synchronize(self):
|
||||||
self.advance()
|
self.advance()
|
||||||
@@ -75,6 +80,23 @@ class Parser:
|
|||||||
return
|
return
|
||||||
self.advance()
|
self.advance()
|
||||||
|
|
||||||
|
def statement(self) -> Stmt:
|
||||||
|
if self.match(TokenType.PRINT):
|
||||||
|
return self.print_stmt()
|
||||||
|
return self.expression_stmt()
|
||||||
|
|
||||||
|
def print_stmt(self) -> Stmt:
|
||||||
|
self.consume(TokenType.LEFT_PAREN, "Missing parentheses")
|
||||||
|
value: Expr = self.expression()
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||||
|
self.expect_eol("Expected end of line after statement")
|
||||||
|
return PrintStmt(value)
|
||||||
|
|
||||||
|
def expression_stmt(self) -> Stmt:
|
||||||
|
value: Expr = self.expression()
|
||||||
|
self.expect_eol("Expected end of line after expression")
|
||||||
|
return ExpressionStmt(value)
|
||||||
|
|
||||||
def expression(self) -> Expr:
|
def expression(self) -> Expr:
|
||||||
return self.equality()
|
return self.equality()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user