Merge pull request 'Type aliases vs. Derived types' (#24) from feat/subtypes-and-aliases into main

Reviewed-on: #24
This commit was merged in pull request #24.
This commit is contained in:
2026-07-01 08:09:13 +00:00
12 changed files with 85 additions and 24 deletions

View File

@@ -44,6 +44,11 @@ class TypeStmt:
type: Type type: Type
class AliasStmt:
name: Token
type: Type
class MemberStmt: class MemberStmt:
name: Token name: Token
type: Type type: Type

View File

@@ -51,6 +51,9 @@ class Stmt(ABC):
@abstractmethod @abstractmethod
def visit_type_stmt(self, stmt: TypeStmt) -> T: ... def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
@abstractmethod @abstractmethod
def visit_member_stmt(self, stmt: MemberStmt) -> T: ... def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@@ -71,6 +74,15 @@ class TypeStmt(Stmt):
return visitor.visit_type_stmt(self) return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class AliasStmt(Stmt):
name: Token
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_alias_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class MemberStmt(Stmt): class MemberStmt(Stmt):
name: Token name: Token

View File

@@ -105,6 +105,14 @@ class MidasAstPrinter(
with self._child_level(single=True): with self._child_level(single=True):
stmt.type.accept(self) stmt.type.accept(self)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
self._write_line("AliasStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None: def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param") self._write_line("Param")
with self._child_level(): with self._child_level():
@@ -371,6 +379,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}" res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res) return self.indented(res)
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
def _print_type_param(self, param: m.TypeParam) -> str: def _print_type_param(self, param: m.TypeParam) -> str:
res: str = param.name.lexeme res: str = param.name.lexeme
if param.bound is not None: if param.bound is not None:

View File

@@ -12,10 +12,10 @@ from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DerivedType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
@@ -152,11 +152,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
if len(params) != 0: if len(params) != 0:
type = GenericType(name=name, params=params, body=type) type = GenericType(name=name, params=params, body=type)
else: else:
type = AliasType(name=name, type=type) type = DerivedType(name=name, type=type)
self.types.define_type(name, type) self.types.define_type(name, type)
self._local_variables.clear() self._local_variables.clear()
self._current_name = None self._current_name = None
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
type: Type = stmt.type.accept(self)
self.types.define_type(name, type)
self._current_name = None
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ... def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:

View File

@@ -18,10 +18,10 @@ from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver from midas.checker.resolver import Resolver
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ConstraintType, ConstraintType,
DerivedType,
Function, Function,
GenericType, GenericType,
OverloadedFunction, OverloadedFunction,
@@ -740,7 +740,7 @@ class PythonTyper(
case UnknownType(): case UnknownType():
return UnknownType() return UnknownType()
case AliasType(type=base): case DerivedType(type=base):
return self._get_call_result( return self._get_call_result(
location, base, positional, keywords, report_errors location, base, positional, keywords, report_errors
) )
@@ -1169,7 +1169,7 @@ class PythonTyper(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool: ) -> bool:
match target_type: match target_type:
case AliasType(type=base): case DerivedType(type=base):
return self._evaluate_cast_statically( return self._evaluate_cast_statically(
expr, subject_type, base, lit_value expr, subject_type, base, lit_value
) )

View File

@@ -5,11 +5,11 @@ from typing import Optional
from midas.ast.midas import MemberKind from midas.ast.midas import MemberKind
from midas.checker.builtins import BUILTIN_SUBTYPES from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DerivedType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
@@ -143,7 +143,7 @@ class TypesRegistry:
return True return True
return self.is_subtype(type1, bound) return self.is_subtype(type1, bound)
case (AliasType(type=base1), _): case (DerivedType(type=base1), _):
return self.is_subtype(base1, type2) return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)): case (BaseType(name=name1), BaseType(name=name2)):
@@ -294,8 +294,8 @@ class TypesRegistry:
def apply_generic(self, type: Type, args: list[Type]) -> Type: def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type: match type:
case AliasType(name=name, type=base): case DerivedType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, args)) return DerivedType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body): case GenericType(name=name, params=type_vars, body=body):
n_args: int = len(args) n_args: int = len(args)
@@ -362,7 +362,7 @@ class TypesRegistry:
return self._members[name][member_name].type return self._members[name][member_name].type
return None return None
case AliasType(name=name, type=base): case DerivedType(name=name, type=base):
if name in self._members: if name in self._members:
if member_name in self._members[name]: if member_name in self._members[name]:
return self._members[name][member_name].type return self._members[name][member_name].type

View File

@@ -23,7 +23,7 @@ class BaseType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class AliasType: class DerivedType:
name: str name: str
type: Type type: Type
@@ -175,8 +175,10 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case BaseType(): case BaseType():
return type return type
case AliasType(name=name, type=type2): case DerivedType(name=name, type=type2):
return AliasType(name=name, type=substitute_typevars(type2, substitutions)) return DerivedType(
name=name, type=substitute_typevars(type2, substitutions)
)
case Function( case Function(
pos_args=pos_args, pos_args=pos_args,
@@ -263,7 +265,7 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def unfold_type(type: Type) -> Type: def unfold_type(type: Type) -> Type:
match type: match type:
case AliasType(type=ref_type): case DerivedType(type=ref_type):
return unfold_type(ref_type) return unfold_type(ref_type)
case _: case _:
return type return type
@@ -286,7 +288,7 @@ def to_annotation(type: Type) -> str:
case BaseType(name=name): case BaseType(name=name):
return name return name
case AliasType(name=name): case DerivedType(name=name):
return name return name
case UnknownType(): case UnknownType():
@@ -331,7 +333,7 @@ class Predicate:
Type = ( Type = (
TopType TopType
| BaseType | BaseType
| AliasType | DerivedType
| UnknownType | UnknownType
| UnitType | UnitType
| Function | Function

View File

@@ -11,14 +11,14 @@ import click
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker from midas.checker.checker import TypeChecker
from midas.checker.registry import Member from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
def base_type(type: Type) -> Type: def base_type(type: Type) -> Type:
match type: match type:
case BaseType(): case BaseType():
return type return type
case AliasType(type=base): case DerivedType(type=base):
return base return base
case AppliedType(body=body): case AppliedType(body=body):
return body return body

View File

@@ -10,11 +10,11 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DerivedType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
@@ -305,7 +305,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type), self._make_cast_assert_message(src_location, expr, type),
) )
case AliasType(type=base): case DerivedType(type=base):
self._make_cast_asserts(src_location, expr, base) self._make_cast_asserts(src_location, expr, base)
case UnitType(): case UnitType():

View File

@@ -4,11 +4,11 @@ from typing import Optional, assert_never
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import ( from midas.checker.types import (
AliasType,
AppliedType, AppliedType,
BaseType, BaseType,
ComplexType, ComplexType,
ConstraintType, ConstraintType,
DerivedType,
ExtensionType, ExtensionType,
Function, Function,
GenericType, GenericType,
@@ -96,7 +96,7 @@ class StubsGenerator:
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]: def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
match type: match type:
case AliasType(type=base): case DerivedType(type=base):
return [self.dump_type(base)], {} return [self.dump_type(base)], {}
case GenericType(params=params, body=body): case GenericType(params=params, body=body):
@@ -161,7 +161,7 @@ class StubsGenerator:
def dump_type(self, type: Type) -> ast.expr: def dump_type(self, type: Type) -> ast.expr:
match type: match type:
case AliasType(name=name) | GenericType(name=name) if ( case DerivedType(name=name) | GenericType(name=name) if (
name in self.substitutions name in self.substitutions
): ):
type = substitute_typevars(type, self.substitutions[name]) type = substitute_typevars(type, self.substitutions[name])
@@ -174,7 +174,7 @@ class StubsGenerator:
case BaseType(name=name): case BaseType(name=name):
return ast.Name(id=name) return ast.Name(id=name)
case AliasType(name=name): case DerivedType(name=name):
return ast.Name(id=name) return ast.Name(id=name)
case UnitType(): case UnitType():

View File

@@ -47,6 +47,7 @@ class TokenType(Enum):
# Keywords # Keywords
TYPE = auto() TYPE = auto()
ALIAS = auto()
PREDICATE = auto() PREDICATE = auto()
EXTEND = auto() EXTEND = auto()
WHERE = auto() WHERE = auto()
@@ -63,6 +64,7 @@ class TokenType(Enum):
KEYWORDS: dict[str, TokenType] = { KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE, "type": TokenType.TYPE,
"alias": TokenType.ALIAS,
"predicate": TokenType.PREDICATE, "predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND, "extend": TokenType.EXTEND,
"where": TokenType.WHERE, "where": TokenType.WHERE,

View File

@@ -2,6 +2,7 @@ from typing import Optional
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.midas import ( from midas.ast.midas import (
AliasStmt,
BinaryExpr, BinaryExpr,
CallExpr, CallExpr,
ComplexType, ComplexType,
@@ -79,6 +80,8 @@ class MidasParser(Parser):
try: try:
if self.match(TokenType.TYPE): if self.match(TokenType.TYPE):
return self.type_declaration() return self.type_declaration()
if self.match(TokenType.ALIAS):
return self.alias_declaration()
if self.match(TokenType.EXTEND): if self.match(TokenType.EXTEND):
return self.extend_declaration() return self.extend_declaration()
if self.match(TokenType.PREDICATE): if self.match(TokenType.PREDICATE):
@@ -158,6 +161,25 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters") self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params return params
def alias_declaration(self) -> AliasStmt:
"""Parse an alias declaration
Returns:
AliasStmt: the parsed alias declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected type name")
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
type: Type = self.type_expr()
return AliasStmt(
location=keyword.location_to(self.previous()),
name=name,
type=type,
)
def type_expr(self) -> Type: def type_expr(self) -> Type:
"""Parse a type expression """Parse a type expression