Merge pull request 'Type aliases vs. Derived types' (#24) from feat/subtypes-and-aliases into main
Reviewed-on: #24
This commit was merged in pull request #24.
This commit is contained in:
@@ -44,6 +44,11 @@ class TypeStmt:
|
|||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
|
class AliasStmt:
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
class MemberStmt:
|
class MemberStmt:
|
||||||
name: Token
|
name: Token
|
||||||
type: Type
|
type: Type
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ class Stmt(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_alias_stmt(self, stmt: AliasStmt) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
||||||
|
|
||||||
@@ -71,6 +74,15 @@ class TypeStmt(Stmt):
|
|||||||
return visitor.visit_type_stmt(self)
|
return visitor.visit_type_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AliasStmt(Stmt):
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_alias_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MemberStmt(Stmt):
|
class MemberStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
|
|||||||
@@ -105,6 +105,14 @@ class MidasAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||||
|
self._write_line("AliasStmt")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_line("type", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
stmt.type.accept(self)
|
||||||
|
|
||||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||||
self._write_line("Param")
|
self._write_line("Param")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -371,6 +379,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||||
return self.indented(res)
|
return self.indented(res)
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: m.AliasStmt) -> str:
|
||||||
|
return self.indented(f"alias {stmt.name.lexeme} = {stmt.type.accept(self)}")
|
||||||
|
|
||||||
def _print_type_param(self, param: m.TypeParam) -> str:
|
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||||
res: str = param.name.lexeme
|
res: str = param.name.lexeme
|
||||||
if param.bound is not None:
|
if param.bound is not None:
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ from midas.checker.preamble import Preamble
|
|||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -152,11 +152,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
|
|||||||
if len(params) != 0:
|
if len(params) != 0:
|
||||||
type = GenericType(name=name, params=params, body=type)
|
type = GenericType(name=name, params=params, body=type)
|
||||||
else:
|
else:
|
||||||
type = AliasType(name=name, type=type)
|
type = DerivedType(name=name, type=type)
|
||||||
self.types.define_type(name, type)
|
self.types.define_type(name, type)
|
||||||
self._local_variables.clear()
|
self._local_variables.clear()
|
||||||
self._current_name = None
|
self._current_name = None
|
||||||
|
|
||||||
|
def visit_alias_stmt(self, stmt: m.AliasStmt) -> None:
|
||||||
|
name: str = stmt.name.lexeme
|
||||||
|
self._current_name = name
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
self.types.define_type(name, type)
|
||||||
|
self._current_name = None
|
||||||
|
|
||||||
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
|||||||
@@ -18,10 +18,10 @@ from midas.checker.registry import TypesRegistry
|
|||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.resolver import Resolver
|
from midas.checker.resolver import Resolver
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DerivedType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
@@ -740,7 +740,7 @@ class PythonTyper(
|
|||||||
case UnknownType():
|
case UnknownType():
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
return self._get_call_result(
|
return self._get_call_result(
|
||||||
location, base, positional, keywords, report_errors
|
location, base, positional, keywords, report_errors
|
||||||
)
|
)
|
||||||
@@ -1169,7 +1169,7 @@ class PythonTyper(
|
|||||||
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
||||||
) -> bool:
|
) -> bool:
|
||||||
match target_type:
|
match target_type:
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
return self._evaluate_cast_statically(
|
return self._evaluate_cast_statically(
|
||||||
expr, subject_type, base, lit_value
|
expr, subject_type, base, lit_value
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ from typing import Optional
|
|||||||
from midas.ast.midas import MemberKind
|
from midas.ast.midas import MemberKind
|
||||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -143,7 +143,7 @@ class TypesRegistry:
|
|||||||
return True
|
return True
|
||||||
return self.is_subtype(type1, bound)
|
return self.is_subtype(type1, bound)
|
||||||
|
|
||||||
case (AliasType(type=base1), _):
|
case (DerivedType(type=base1), _):
|
||||||
return self.is_subtype(base1, type2)
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
case (BaseType(name=name1), BaseType(name=name2)):
|
case (BaseType(name=name1), BaseType(name=name2)):
|
||||||
@@ -294,8 +294,8 @@ class TypesRegistry:
|
|||||||
|
|
||||||
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(name=name, type=base):
|
case DerivedType(name=name, type=base):
|
||||||
return AliasType(name=name, type=self.apply_generic(base, args))
|
return DerivedType(name=name, type=self.apply_generic(base, args))
|
||||||
|
|
||||||
case GenericType(name=name, params=type_vars, body=body):
|
case GenericType(name=name, params=type_vars, body=body):
|
||||||
n_args: int = len(args)
|
n_args: int = len(args)
|
||||||
@@ -362,7 +362,7 @@ class TypesRegistry:
|
|||||||
return self._members[name][member_name].type
|
return self._members[name][member_name].type
|
||||||
return None
|
return None
|
||||||
|
|
||||||
case AliasType(name=name, type=base):
|
case DerivedType(name=name, type=base):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name].type
|
return self._members[name][member_name].type
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class BaseType:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class AliasType:
|
class DerivedType:
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@@ -175,8 +175,10 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
case BaseType():
|
case BaseType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case AliasType(name=name, type=type2):
|
case DerivedType(name=name, type=type2):
|
||||||
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
return DerivedType(
|
||||||
|
name=name, type=substitute_typevars(type2, substitutions)
|
||||||
|
)
|
||||||
|
|
||||||
case Function(
|
case Function(
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
@@ -263,7 +265,7 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
|
|
||||||
def unfold_type(type: Type) -> Type:
|
def unfold_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(type=ref_type):
|
case DerivedType(type=ref_type):
|
||||||
return unfold_type(ref_type)
|
return unfold_type(ref_type)
|
||||||
case _:
|
case _:
|
||||||
return type
|
return type
|
||||||
@@ -286,7 +288,7 @@ def to_annotation(type: Type) -> str:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
return name
|
return name
|
||||||
|
|
||||||
case AliasType(name=name):
|
case DerivedType(name=name):
|
||||||
return name
|
return name
|
||||||
|
|
||||||
case UnknownType():
|
case UnknownType():
|
||||||
@@ -331,7 +333,7 @@ class Predicate:
|
|||||||
Type = (
|
Type = (
|
||||||
TopType
|
TopType
|
||||||
| BaseType
|
| BaseType
|
||||||
| AliasType
|
| DerivedType
|
||||||
| UnknownType
|
| UnknownType
|
||||||
| UnitType
|
| UnitType
|
||||||
| Function
|
| Function
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ import click
|
|||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.registry import Member
|
from midas.checker.registry import Member
|
||||||
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
from midas.checker.types import AppliedType, BaseType, DerivedType, GenericType, Type
|
||||||
|
|
||||||
|
|
||||||
def base_type(type: Type) -> Type:
|
def base_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
case BaseType():
|
case BaseType():
|
||||||
return type
|
return type
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
return base
|
return base
|
||||||
case AppliedType(body=body):
|
case AppliedType(body=body):
|
||||||
return body
|
return body
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ from midas.ast.location import Location
|
|||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -305,7 +305,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
self._make_cast_assert_message(src_location, expr, type),
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
)
|
)
|
||||||
|
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
self._make_cast_asserts(src_location, expr, base)
|
self._make_cast_asserts(src_location, expr, base)
|
||||||
|
|
||||||
case UnitType():
|
case UnitType():
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from typing import Optional, assert_never
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
from midas.checker.registry import Member, TypesRegistry
|
from midas.checker.registry import Member, TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
|
DerivedType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
@@ -96,7 +96,7 @@ class StubsGenerator:
|
|||||||
|
|
||||||
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(type=base):
|
case DerivedType(type=base):
|
||||||
return [self.dump_type(base)], {}
|
return [self.dump_type(base)], {}
|
||||||
|
|
||||||
case GenericType(params=params, body=body):
|
case GenericType(params=params, body=body):
|
||||||
@@ -161,7 +161,7 @@ class StubsGenerator:
|
|||||||
|
|
||||||
def dump_type(self, type: Type) -> ast.expr:
|
def dump_type(self, type: Type) -> ast.expr:
|
||||||
match type:
|
match type:
|
||||||
case AliasType(name=name) | GenericType(name=name) if (
|
case DerivedType(name=name) | GenericType(name=name) if (
|
||||||
name in self.substitutions
|
name in self.substitutions
|
||||||
):
|
):
|
||||||
type = substitute_typevars(type, self.substitutions[name])
|
type = substitute_typevars(type, self.substitutions[name])
|
||||||
@@ -174,7 +174,7 @@ class StubsGenerator:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
return ast.Name(id=name)
|
return ast.Name(id=name)
|
||||||
|
|
||||||
case AliasType(name=name):
|
case DerivedType(name=name):
|
||||||
return ast.Name(id=name)
|
return ast.Name(id=name)
|
||||||
|
|
||||||
case UnitType():
|
case UnitType():
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
# Keywords
|
# Keywords
|
||||||
TYPE = auto()
|
TYPE = auto()
|
||||||
|
ALIAS = auto()
|
||||||
PREDICATE = auto()
|
PREDICATE = auto()
|
||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
WHERE = auto()
|
WHERE = auto()
|
||||||
@@ -63,6 +64,7 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
KEYWORDS: dict[str, TokenType] = {
|
KEYWORDS: dict[str, TokenType] = {
|
||||||
"type": TokenType.TYPE,
|
"type": TokenType.TYPE,
|
||||||
|
"alias": TokenType.ALIAS,
|
||||||
"predicate": TokenType.PREDICATE,
|
"predicate": TokenType.PREDICATE,
|
||||||
"extend": TokenType.EXTEND,
|
"extend": TokenType.EXTEND,
|
||||||
"where": TokenType.WHERE,
|
"where": TokenType.WHERE,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
|
AliasStmt,
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
CallExpr,
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
@@ -79,6 +80,8 @@ class MidasParser(Parser):
|
|||||||
try:
|
try:
|
||||||
if self.match(TokenType.TYPE):
|
if self.match(TokenType.TYPE):
|
||||||
return self.type_declaration()
|
return self.type_declaration()
|
||||||
|
if self.match(TokenType.ALIAS):
|
||||||
|
return self.alias_declaration()
|
||||||
if self.match(TokenType.EXTEND):
|
if self.match(TokenType.EXTEND):
|
||||||
return self.extend_declaration()
|
return self.extend_declaration()
|
||||||
if self.match(TokenType.PREDICATE):
|
if self.match(TokenType.PREDICATE):
|
||||||
@@ -158,6 +161,25 @@ class MidasParser(Parser):
|
|||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def alias_declaration(self) -> AliasStmt:
|
||||||
|
"""Parse an alias declaration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AliasStmt: the parsed alias declaration statement
|
||||||
|
"""
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
|
|
||||||
|
self.consume(TokenType.EQUAL, "Expected '=' before alias definition")
|
||||||
|
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
|
||||||
|
return AliasStmt(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
)
|
||||||
|
|
||||||
def type_expr(self) -> Type:
|
def type_expr(self) -> Type:
|
||||||
"""Parse a type expression
|
"""Parse a type expression
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user