feat: add basic inheritance

This commit is contained in:
2026-02-06 23:51:23 +01:00
parent 91437e30bc
commit 9e4855598f
9 changed files with 72 additions and 8 deletions

View File

@@ -61,7 +61,7 @@ root ::= declaration* <<eof>> ;
declaration ::= classDecl | funDecl | varDecl | statement ;
classDecl ::= KW_CLASS IDENTIFIER PUNC_LBRACE function* PUNC_RBRACE ;
classDecl ::= KW_CLASS IDENTIFIER ( OP_LESS IDENTIFIER )? PUNC_LBRACE function* PUNC_RBRACE ;
funDecl ::= KW_FUN function ;
varDecl ::= KW_LET IDENTIFIER ( OP_EQUAL expression )? ;

View File

@@ -0,0 +1,37 @@
class Employee {
get_salary() {
return 100
}
has_responsibilities() {
return false
}
}
class Manager < Employee {
get_salary() {
return 300
}
has_responsibilities() {
return true
}
}
class Boss < Manager {
get_salary() {
return 500
}
}
let employee = Employee()
let manager = Manager()
let boss = Boss()
print(employee, employee.get_salary())
print(manager, manager.get_salary())
print(boss, boss.get_salary())
print(employee, employee.has_responsibilities())
print(manager, manager.has_responsibilities())
print(boss, boss.has_responsibilities())

View File

@@ -4,7 +4,7 @@ from src.pebble import Pebble
def main():
path: Path = Path("examples/basic/20_init.peb")
path: Path = Path("examples/basic/21_super.peb")
Pebble.run_file(path)

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeVar, Generic, Optional
from src.ast.expr import Expr
from src.ast.expr import Expr, VariableExpr
from src.token import Token
T = TypeVar("T")
@@ -73,6 +73,7 @@ class BlockStmt(Stmt):
@dataclass(frozen=True)
class ClassStmt(Stmt):
name: Token
superclass: Optional[VariableExpr]
methods: list[FunctionStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:

View File

@@ -12,8 +12,9 @@ if TYPE_CHECKING:
class PebbleClass(PebbleCallable):
def __init__(self, name: str, methods: dict[str, PebbleFunction]):
def __init__(self, name: str, superclass: Optional[PebbleClass], methods: dict[str, PebbleFunction]):
self.name: str = name
self.superclass: Optional[PebbleClass] = superclass
self.methods: dict[str, PebbleFunction] = methods
def __str__(self):
@@ -33,4 +34,10 @@ class PebbleClass(PebbleCallable):
return instance
def find_method(self, name: str) -> Optional[PebbleFunction]:
return self.methods.get(name)
if name in self.methods:
return self.methods[name]
if self.superclass is not None:
return self.superclass.find_method(name)
return None

View File

@@ -87,7 +87,11 @@ class Formatter(Expr.Visitor[str], Stmt.Visitor[str]):
return res
def visit_class_stmt(self, stmt: ClassStmt) -> str:
res: str = self.indented(f"class {stmt.name.lexeme} {{\n")
res: str = self.indented("")
res += f"class {stmt.name.lexeme} "
if stmt.superclass is not None:
res += f"< {stmt.superclass.name.lexeme} "
res += "{\n"
self.level += 1
enclosing_class: ClassType = self.current_class
self.current_class = ClassType.CLASS

View File

@@ -177,12 +177,18 @@ class Interpreter(Expr.Visitor[Any], Stmt.Visitor[None]):
self.execute_block(stmt.statements, Environment(self.env))
def visit_class_stmt(self, stmt: ClassStmt) -> None:
superclass: Any = None
if stmt.superclass is not None:
superclass = self.evaluate(stmt.superclass)
if not isinstance(superclass, PebbleClass):
raise PebbleRuntimeError(stmt.superclass.name, "Superclass must be a class.")
self.env.define(stmt.name.lexeme, None)
methods: dict[str, PebbleFunction] = {}
for method in stmt.methods:
func: PebbleFunction = PebbleFunction(method, self.env, method.name.lexeme == CONSTRUCTOR_NAME)
methods[method.name.lexeme] = func
klass: PebbleClass = PebbleClass(stmt.name.lexeme, methods)
klass: PebbleClass = PebbleClass(stmt.name.lexeme, superclass, methods)
self.env.assign(stmt.name, klass)
def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:

View File

@@ -134,6 +134,11 @@ class Resolver(Expr.Visitor[None], Stmt.Visitor[None]):
self.declare(stmt.name)
self.define(stmt.name)
if stmt.superclass is not None:
if stmt.name.lexeme == stmt.superclass.name.lexeme:
Pebble.token_error(stmt.superclass.name, "A class cannot inherit from itself.")
self.resolve(stmt.superclass)
self.begin_scope()
self.scopes[-1]["this"] = True

View File

@@ -97,13 +97,17 @@ class Parser:
def class_declaration(self) -> Stmt:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected class name.")
superclass: Optional[VariableExpr] = None
if self.match(TokenType.LESS):
self.consume(TokenType.IDENTIFIER, "Expected superclass name.")
superclass = VariableExpr(self.previous())
self.consume(TokenType.LEFT_BRACE, "Expected '{' before class body.")
methods: list[FunctionStmt] = []
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
methods.append(self.function("method"))
self.consume(TokenType.RIGHT_BRACE, "Expected '}' after class body.")
return ClassStmt(name, methods)
return ClassStmt(name, superclass, methods)
def function(self, kind: str) -> FunctionStmt:
# TODO: allow anonymous/lambda functions