Merge pull request 'Type aliases vs. Derived types' (#24) from feat/subtypes-and-aliases into main

Reviewed-on: #24
This commit was merged in pull request #24.
This commit is contained in:
2026-07-01 08:09:13 +00:00
12 changed files with 85 additions and 24 deletions

View File

@@ -44,6 +44,11 @@ class TypeStmt:
type: Type
class AliasStmt:
name: Token
type: Type
class MemberStmt:
name: Token
type: Type

View File

@@ -51,6 +51,9 @@ class Stmt(ABC):
@abstractmethod
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
@abstractmethod
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@@ -71,6 +74,15 @@ class TypeStmt(Stmt):
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class AliasStmt(Stmt):
name: Token
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_alias_stmt(self)
@dataclass(frozen=True)
class MemberStmt(Stmt):
name: Token

View File

@@ -105,6 +105,14 @@ class MidasAstPrinter(
with self._child_level(single=True):
stmt.type.accept(self)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
self._write_line("AliasStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
@@ -371,6 +379,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme
if param.bound is not None:

View File

@@ -12,10 +12,10 @@ from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ComplexType,
ConstraintType,
DerivedType,
ExtensionType,
Function,
GenericType,
@@ -152,11 +152,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
if len(params) != 0:
type = GenericType(name=name, params=params, body=type)
else:
type = AliasType(name=name, type=type)
type = DerivedType(name=name, type=type)
self.types.define_type(name, type)
self._local_variables.clear()
self._current_name = None
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
type: Type = stmt.type.accept(self)
self.types.define_type(name, type)
self._current_name = None
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:

View File

@@ -18,10 +18,10 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ConstraintType,
DerivedType,
Function,
GenericType,
OverloadedFunction,
@@ -740,7 +740,7 @@ class PythonTyper(
case UnknownType():
return UnknownType()
case AliasType(type=base):
case DerivedType(type=base):
return self._get_call_result(
location, base, positional, keywords, report_errors
)
@@ -1169,7 +1169,7 @@ class PythonTyper(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool:
match target_type:
case AliasType(type=base):
case DerivedType(type=base):
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value
)

View File

@@ -5,11 +5,11 @@ from typing import Optional
from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
DerivedType,
ExtensionType,
Function,
GenericType,
@@ -143,7 +143,7 @@ class TypesRegistry:
return True
return self.is_subtype(type1, bound)
case (AliasType(type=base1), _):
case (DerivedType(type=base1), _):
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
@@ -294,8 +294,8 @@ class TypesRegistry:
def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type:
case AliasType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, args))
case DerivedType(name=name, type=base):
return DerivedType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body):
n_args: int = len(args)
@@ -362,7 +362,7 @@ class TypesRegistry:
return self._members[name][member_name].type
return None
case AliasType(name=name, type=base):
case DerivedType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name].type

View File

@@ -23,7 +23,7 @@ class BaseType:
@dataclass(frozen=True, kw_only=True)
class AliasType:
class DerivedType:
name: str
type: Type
@@ -175,8 +175,10 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case BaseType():
return type
case AliasType(name=name, type=type2):
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
case DerivedType(name=name, type=type2):
return DerivedType(
name=name, type=substitute_typevars(type2, substitutions)
)
case Function(
pos_args=pos_args,
@@ -263,7 +265,7 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def unfold_type(type: Type) -> Type:
match type:
case AliasType(type=ref_type):
case DerivedType(type=ref_type):
return unfold_type(ref_type)
case _:
return type
@@ -286,7 +288,7 @@ def to_annotation(type: Type) -> str:
case BaseType(name=name):
return name
case AliasType(name=name):
case DerivedType(name=name):
return name
case UnknownType():
@@ -331,7 +333,7 @@ class Predicate:
Type = (
TopType
| BaseType
| AliasType
| DerivedType
| UnknownType
| UnitType
| Function

View File

@@ -11,14 +11,14 @@ import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
def base_type(type: Type) -> Type:
match type:
case BaseType():
return type
case AliasType(type=base):
case DerivedType(type=base):
return base
case AppliedType(body=body):
return body

View File

@@ -10,11 +10,11 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
DerivedType,
ExtensionType,
Function,
GenericType,
@@ -305,7 +305,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type),
)
case AliasType(type=base):
case DerivedType(type=base):
self._make_cast_asserts(src_location, expr, base)
case UnitType():

View File

@@ -4,11 +4,11 @@ from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
DerivedType,
ExtensionType,
Function,
GenericType,
@@ -96,7 +96,7 @@ class StubsGenerator:
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type:
case AliasType(type=base):
case DerivedType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
@@ -161,7 +161,7 @@ class StubsGenerator:
def dump_type(self, type: Type) -> ast.expr:
match type:
case AliasType(name=name) | GenericType(name=name) if (
case DerivedType(name=name) | GenericType(name=name) if (
name in self.substitutions
):
type = substitute_typevars(type, self.substitutions[name])
@@ -174,7 +174,7 @@ class StubsGenerator:
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
case DerivedType(name=name):
return ast.Name(id=name)
case UnitType():

View File

@@ -47,6 +47,7 @@ class TokenType(Enum):
# Keywords
TYPE = auto()
ALIAS = auto()
PREDICATE = auto()
EXTEND = auto()
WHERE = auto()
@@ -63,6 +64,7 @@ class TokenType(Enum):
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"alias": TokenType.ALIAS,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,

View File

@@ -2,6 +2,7 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
AliasStmt,
BinaryExpr,
CallExpr,
ComplexType,
@@ -79,6 +80,8 @@ class MidasParser(Parser):
try:
if self.match(TokenType.TYPE):
return self.type_declaration()
if self.match(TokenType.ALIAS):
return self.alias_declaration()
if self.match(TokenType.EXTEND):
return self.extend_declaration()
if self.match(TokenType.PREDICATE):
@@ -158,6 +161,25 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def alias_declaration(self) -> AliasStmt:
"""Parse an alias declaration
Returns:
AliasStmt: the parsed alias declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name")
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
type: Type = self.type_expr()
return AliasStmt(
location=keyword.location_to(self.previous()),
name=name,
type=type,
)
def type_expr(self) -> Type:
"""Parse a type expression