feat: add basic inheritance

This commit is contained in:
2026-02-06 23:51:23 +01:00
parent 91437e30bc
commit 9e4855598f
9 changed files with 72 additions and 8 deletions

View File

@@ -61,7 +61,7 @@ root ::= declaration* <<eof>> ;
declaration ::= classDecl | funDecl | varDecl | statement ; declaration ::= classDecl | funDecl | varDecl | statement ;
classDecl ::= KW_CLASS IDENTIFIER PUNC_LBRACE function* PUNC_RBRACE ; classDecl ::= KW_CLASS IDENTIFIER ( OP_LESS IDENTIFIER )? PUNC_LBRACE function* PUNC_RBRACE ;
funDecl ::= KW_FUN function ; funDecl ::= KW_FUN function ;
varDecl ::= KW_LET IDENTIFIER ( OP_EQUAL expression )? ; varDecl ::= KW_LET IDENTIFIER ( OP_EQUAL expression )? ;

View File

@@ -0,0 +1,37 @@
class Employee {
get_salary() {
return 100
}
has_responsibilities() {
return false
}
}
class Manager < Employee {
get_salary() {
return 300
}
has_responsibilities() {
return true
}
}
class Boss < Manager {
get_salary() {
return 500
}
}
let employee = Employee()
let manager = Manager()
let boss = Boss()
print(employee, employee.get_salary())
print(manager, manager.get_salary())
print(boss, boss.get_salary())
print(employee, employee.has_responsibilities())
print(manager, manager.has_responsibilities())
print(boss, boss.has_responsibilities())

View File

@@ -4,7 +4,7 @@ from src.pebble import Pebble
def main(): def main():
path: Path = Path("examples/basic/20_init.peb") path: Path = Path("examples/basic/21_super.peb")
Pebble.run_file(path) Pebble.run_file(path)

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Generic, Optional from typing import TypeVar, Generic, Optional
from src.ast.expr import Expr from src.ast.expr import Expr, VariableExpr
from src.token import Token from src.token import Token
T = TypeVar("T") T = TypeVar("T")
@@ -73,6 +73,7 @@ class BlockStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class ClassStmt(Stmt): class ClassStmt(Stmt):
name: Token name: Token
superclass: Optional[VariableExpr]
methods: list[FunctionStmt] methods: list[FunctionStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:

View File

@@ -12,8 +12,9 @@ if TYPE_CHECKING:
class PebbleClass(PebbleCallable): class PebbleClass(PebbleCallable):
def __init__(self, name: str, methods: dict[str, PebbleFunction]): def __init__(self, name: str, superclass: Optional[PebbleClass], methods: dict[str, PebbleFunction]):
self.name: str = name self.name: str = name
self.superclass: Optional[PebbleClass] = superclass
self.methods: dict[str, PebbleFunction] = methods self.methods: dict[str, PebbleFunction] = methods
def __str__(self): def __str__(self):
@@ -33,4 +34,10 @@ class PebbleClass(PebbleCallable):
return instance return instance
def find_method(self, name: str) -> Optional[PebbleFunction]: def find_method(self, name: str) -> Optional[PebbleFunction]:
return self.methods.get(name) if name in self.methods:
return self.methods[name]
if self.superclass is not None:
return self.superclass.find_method(name)
return None

View File

@@ -87,7 +87,11 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
return res return res
def visit_class_stmt(self, stmt: ClassStmt) -> str: def visit_class_stmt(self, stmt: ClassStmt) -> str:
res: str = self.indented(f"class {stmt.name.lexeme} {{\n") res: str = self.indented("")
res += f"class {stmt.name.lexeme} "
if stmt.superclass is not None:
res += f"< {stmt.superclass.name.lexeme} "
res += "{\n"
self.level += 1 self.level += 1
enclosing_class: ClassType = self.current_class enclosing_class: ClassType = self.current_class
self.current_class = ClassType.CLASS self.current_class = ClassType.CLASS

View File

@@ -177,12 +177,18 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
self.execute_block(stmt.statements, Environment(self.env)) self.execute_block(stmt.statements, Environment(self.env))
def visit_class_stmt(self, stmt: ClassStmt) -> None: def visit_class_stmt(self, stmt: ClassStmt) -> None:
superclass: Any = None
if stmt.superclass is not None:
superclass = self.evaluate(stmt.superclass)
if not isinstance(superclass, PebbleClass):
raise PebbleRuntimeError(stmt.superclass.name, "Superclass must be a class.")
self.env.define(stmt.name.lexeme, None) self.env.define(stmt.name.lexeme, None)
methods: dict[str, PebbleFunction] = {} methods: dict[str, PebbleFunction] = {}
for method in stmt.methods: for method in stmt.methods:
func: PebbleFunction = PebbleFunction(method, self.env, method.name.lexeme == CONSTRUCTOR_NAME) func: PebbleFunction = PebbleFunction(method, self.env, method.name.lexeme == CONSTRUCTOR_NAME)
methods[method.name.lexeme] = func methods[method.name.lexeme] = func
klass: PebbleClass = PebbleClass(stmt.name.lexeme, methods) klass: PebbleClass = PebbleClass(stmt.name.lexeme, superclass, methods)
self.env.assign(stmt.name, klass) self.env.assign(stmt.name, klass)
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:

View File

@@ -134,6 +134,11 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
self.declare(stmt.name) self.declare(stmt.name)
self.define(stmt.name) self.define(stmt.name)
if stmt.superclass is not None:
if stmt.name.lexeme == stmt.superclass.name.lexeme:
Pebble.token_error(stmt.superclass.name, "A class cannot inherit from itself.")
self.resolve(stmt.superclass)
self.begin_scope() self.begin_scope()
self.scopes[-1]["this"] = True self.scopes[-1]["this"] = True

View File

@@ -97,13 +97,17 @@ class Parser:
def class_declaration(self) -> Stmt: def class_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected class name.") name: Token = self.consume(TokenType.IDENTIFIER, "Expected class name.")
superclass: Optional[VariableExpr] = None
if self.match(TokenType.LESS):
self.consume(TokenType.IDENTIFIER, "Expected superclass name.")
superclass = VariableExpr(self.previous())
self.consume(TokenType.LEFT_BRACE, "Expected '{' before class body.") self.consume(TokenType.LEFT_BRACE, "Expected '{' before class body.")
methods: list[FunctionStmt] = [] methods: list[FunctionStmt] = []
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end(): while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
methods.append(self.function("method")) methods.append(self.function("method"))
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after class body.") self.consume(TokenType.RIGHT_BRACE, "Expected '}' after class body.")
return ClassStmt(name, methods) return ClassStmt(name, superclass, methods)
def function(self, kind: str) -> FunctionStmt: def function(self, kind: str) -> FunctionStmt:
# TODO: allow anonymous/lambda functions # TODO: allow anonymous/lambda functions