Compare commits

...

3 Commits

Author SHA1 Message Date
822a74acce refactor(checker): rename methods
improve a couple methods names, namely evaluate → type_of and evaluate_block → process_block
2026-06-03 13:03:41 +02:00
9a934fabfd tests: remove union type 2026-06-02 17:22:19 +02:00
828ec9a3fa fix!: remove union type 2026-06-02 17:19:17 +02:00
14 changed files with 48 additions and 136 deletions

View File

@@ -111,10 +111,6 @@ class ConstraintType:
constraint: Expr constraint: Expr
class UnionType:
types: list[Type]
class ComplexType: class ComplexType:
properties: list[PropertyStmt] properties: list[PropertyStmt]

View File

@@ -228,9 +228,6 @@ class Type(ABC):
@abstractmethod @abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ... def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_union_type(self, type: UnionType) -> T: ...
@abstractmethod @abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ... def visit_complex_type(self, type: ComplexType) -> T: ...
@@ -261,14 +258,6 @@ class ConstraintType(Type):
return visitor.visit_constraint_type(self) return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class UnionType(Type):
types: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_union_type(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class ComplexType(Type): class ComplexType(Type):
properties: list[PropertyStmt] properties: list[PropertyStmt]

View File

@@ -252,17 +252,6 @@ class MidasAstPrinter(
with self._child_level(single=True): with self._child_level(single=True):
type.constraint.accept(self) type.constraint.accept(self)
def visit_union_type(self, type: m.UnionType) -> None:
self._write_line("UnionType")
with self._child_level():
self._write_line("types", last=True)
with self._child_level():
for i, type_ in enumerate(type.types):
self._idx = i
if i == len(type.types) - 1:
self._mark_last()
type_.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None: def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType") self._write_line("ComplexType")
with self._child_level(): with self._child_level():
@@ -379,10 +368,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += " where " + type.constraint.accept(self) res += " where " + type.constraint.accept(self)
return res return res
def visit_union_type(self, type: m.UnionType) -> str:
types: list[str] = [type_.accept(self) for type_ in type.types]
return " | ".join(types)
def visit_complex_type(self, type: m.ComplexType) -> str: def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n" res: str = "{\n"
self.level += 1 self.level += 1

View File

@@ -74,7 +74,7 @@ class Checker(
message=message, message=message,
) )
def evaluate(self, expr: p.Expr) -> Type: def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression """Evaluate the type of an expression
Args: Args:
@@ -85,7 +85,7 @@ class Checker(
""" """
return expr.accept(self) return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool: def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements """Evaluate a sequence of statements
Args: Args:
@@ -181,7 +181,7 @@ class Checker(
self.logger.debug(f"Midas operations: {self.ctx._operations}") self.logger.debug(f"Midas operations: {self.ctx._operations}")
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.evaluate(stmt.expr) self.type_of(stmt.expr)
def visit_function(self, stmt: p.Function) -> None: def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env) env: Environment = Environment(self.env)
@@ -237,7 +237,7 @@ class Checker(
) )
self.env.define(stmt.name, inside_function) self.env.define(stmt.name, inside_function)
returned: bool = self.evaluate_block(stmt.body, env) returned: bool = self.process_block(stmt.body, env)
inferred_return: Type = UnknownType() inferred_return: Type = UnknownType()
if not returned: if not returned:
env.return_types.append(UnitType()) env.return_types.append(UnitType())
@@ -278,7 +278,7 @@ class Checker(
self.env.define(stmt.name, type) self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value: Type = self.evaluate(stmt.value) value: Type = self.type_of(stmt.value)
for target in stmt.targets: for target in stmt.targets:
if not isinstance(target, p.VariableExpr): if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}") self.logger.warning(f"Unsupported assignment to {target}")
@@ -317,8 +317,8 @@ class Checker(
) )
env: Environment = Environment(self.env) env: Environment = Environment(self.env)
body_returned: bool = self.evaluate_block(stmt.body, env) body_returned: bool = self.process_block(stmt.body, env)
else_returned: bool = self.evaluate_block(stmt.orelse, env) else_returned: bool = self.process_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types) self.env.return_types.extend(env.return_types)
if body_returned and else_returned: if body_returned and else_returned:
raise ReturnException() raise ReturnException()
@@ -329,8 +329,8 @@ class Checker(
self.logger.warning(f"Unsupported operator {expr.operator}") self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}") self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType() return UnknownType()
left: Type = self.evaluate(expr.left) left: Type = self.type_of(expr.left)
right: Type = self.evaluate(expr.right) right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right) result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None: if result is None:
@@ -347,8 +347,8 @@ class Checker(
self.logger.warning(f"Unsupported operator {expr.operator}") self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}") self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType() return UnknownType()
left: Type = self.evaluate(expr.left) left: Type = self.type_of(expr.left)
right: Type = self.evaluate(expr.right) right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right) result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None: if result is None:
@@ -365,7 +365,7 @@ class Checker(
if path := self.parse_midas_import(expr): if path := self.parse_midas_import(expr):
self.import_midas(path) self.import_midas(path)
return UnknownType() return UnknownType()
callee: Type = self.evaluate(expr.callee) callee: Type = self.type_of(expr.callee)
if not isinstance(callee, Function): if not isinstance(callee, Function):
self.error(expr.callee.location, "Callee is not a function") self.error(expr.callee.location, "Callee is not a function")
return UnknownType() return UnknownType()
@@ -460,10 +460,10 @@ class Checker(
list[MappedArgument]: the list of mapped arguments list[MappedArgument]: the list of mapped arguments
""" """
positional: list[tuple[p.Expr, Type]] = [ positional: list[tuple[p.Expr, Type]] = [
(arg, self.evaluate(arg)) for arg in call.arguments (arg, self.type_of(arg)) for arg in call.arguments
] ]
keywords: dict[str, tuple[p.Expr, Type]] = { keywords: dict[str, tuple[p.Expr, Type]] = {
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items() name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
} }
set_args: set[str] = set() set_args: set[str] = set()

View File

@@ -44,11 +44,4 @@ class ComplexType:
properties: dict[str, Type] properties: dict[str, Type]
@dataclass(frozen=True, kw_only=True) Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType
class UnionType:
alternatives: list[Type]
Type = (
BaseType | AliasType | UnknownType | UnitType | Function | ComplexType | UnionType
)

View File

@@ -294,11 +294,6 @@ class MidasHighlighter(
type.type.accept(self) type.type.accept(self)
type.constraint.accept(self) type.constraint.accept(self)
def visit_union_type(self, type: m.UnionType) -> None:
self.wrap(type, "union-type")
for type_ in type.types:
type_.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None: def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type") self.wrap(type, "complex-type")
for prop in type.properties: for prop in type.properties:

View File

@@ -8,7 +8,6 @@ span {
&.named-type, &.named-type,
&.generic-type, &.generic-type,
&.constraint-type, &.constraint-type,
&.union-type,
&.complex-type { &.complex-type {
--col: 150, 150, 150; --col: 150, 150, 150;
} }

View File

@@ -18,8 +18,6 @@ class MidasLexer(Lexer):
self.add_token(TokenType.LEFT_BRACE) self.add_token(TokenType.LEFT_BRACE)
case "}": case "}":
self.add_token(TokenType.RIGHT_BRACE) self.add_token(TokenType.RIGHT_BRACE)
case "|":
self.add_token(TokenType.PIPE)
case "<": case "<":
self.add_token( self.add_token(
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS

View File

@@ -23,7 +23,6 @@ class TokenType(Enum):
AND = auto() AND = auto()
QMARK = auto() QMARK = auto()
DOT = auto() DOT = auto()
PIPE = auto()
# Operators # Operators
# PLUS = auto() # PLUS = auto()

View File

@@ -20,7 +20,6 @@ from midas.ast.midas import (
Type, Type,
TypeStmt, TypeStmt,
UnaryExpr, UnaryExpr,
UnionType,
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
) )
@@ -161,18 +160,7 @@ class MidasParser(Parser):
Returns: Returns:
TypeExpr: the parsed type expression TypeExpr: the parsed type expression
""" """
return self.union_type() return self.constraint_type()
def union_type(self) -> Type:
types: list[Type] = [self.constraint_type()]
while self.match(TokenType.PIPE):
types.append(self.constraint_type())
if len(types) == 1:
return types[0]
return UnionType(
location=Location.span(types[0].location, types[-1].location),
types=types,
)
def constraint_type(self) -> Type: def constraint_type(self) -> Type:
type: Type = self.base_type() type: Type = self.base_type()

View File

@@ -4,7 +4,6 @@ import midas.ast.midas as m
from midas.checker.types import ( from midas.checker.types import (
AliasType, AliasType,
Type, Type,
UnionType,
UnknownType, UnknownType,
) )
from midas.resolver.builtin import define_builtins from midas.resolver.builtin import define_builtins
@@ -157,10 +156,6 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
# TODO # TODO
return UnknownType() return UnknownType()
def visit_union_type(self, type: m.UnionType) -> Type:
types: list[Type] = [type_.accept(self) for type_ in type.types]
return UnionType(alternatives=types)
def visit_complex_type(self, type: m.ComplexType) -> Type: def visit_complex_type(self, type: m.ComplexType) -> Type:
for prop in type.properties: for prop in type.properties:
prop.accept(self) prop.accept(self)

View File

@@ -48,7 +48,7 @@ type Person = {
name: str name: str
// Property with an inline constraint // Property with an inline constraint
age: None | (int where (0 <= _ < 150)) age: Optional[int where (0 <= _ < 150)]
// Property referencing a predicate // Property referencing a predicate
height: float where StrictlyPositive height: float where StrictlyPositive

View File

@@ -1982,123 +1982,99 @@
}, },
{ {
"type": "IDENTIFIER", "type": "IDENTIFIER",
"lexeme": "None", "lexeme": "Optional",
"line": 51, "line": 51,
"column": 10 "column": 10
}, },
{ {
"type": "WHITESPACE", "type": "LEFT_BRACKET",
"lexeme": " ", "lexeme": "[",
"line": 51, "line": 51,
"column": 14 "column": 18
},
{
"type": "PIPE",
"lexeme": "|",
"line": 51,
"column": 15
},
{
"type": "WHITESPACE",
"lexeme": " ",
"line": 51,
"column": 16
},
{
"type": "LEFT_PAREN",
"lexeme": "(",
"line": 51,
"column": 17
}, },
{ {
"type": "IDENTIFIER", "type": "IDENTIFIER",
"lexeme": "int", "lexeme": "int",
"line": 51, "line": 51,
"column": 18 "column": 19
}, },
{ {
"type": "WHITESPACE", "type": "WHITESPACE",
"lexeme": " ", "lexeme": " ",
"line": 51, "line": 51,
"column": 21 "column": 22
}, },
{ {
"type": "WHERE", "type": "WHERE",
"lexeme": "where", "lexeme": "where",
"line": 51, "line": 51,
"column": 22 "column": 23
}, },
{ {
"type": "WHITESPACE", "type": "WHITESPACE",
"lexeme": " ", "lexeme": " ",
"line": 51, "line": 51,
"column": 27 "column": 28
}, },
{ {
"type": "LEFT_PAREN", "type": "LEFT_PAREN",
"lexeme": "(", "lexeme": "(",
"line": 51, "line": 51,
"column": 28 "column": 29
}, },
{ {
"type": "NUMBER", "type": "NUMBER",
"lexeme": "0", "lexeme": "0",
"line": 51, "line": 51,
"column": 29 "column": 30
}, },
{ {
"type": "WHITESPACE", "type": "WHITESPACE",
"lexeme": " ", "lexeme": " ",
"line": 51, "line": 51,
"column": 30 "column": 31
}, },
{ {
"type": "LESS_EQUAL", "type": "LESS_EQUAL",
"lexeme": "<=", "lexeme": "<=",
"line": 51, "line": 51,
"column": 31 "column": 32
}, },
{ {
"type": "WHITESPACE", "type": "WHITESPACE",
"lexeme": " ", "lexeme": " ",
"line": 51, "line": 51,
"column": 33 "column": 34
}, },
{ {
"type": "UNDERSCORE", "type": "UNDERSCORE",
"lexeme": "_", "lexeme": "_",
"line": 51, "line": 51,
"column": 34 "column": 35
}, },
{ {
"type": "WHITESPACE", "type": "WHITESPACE",
"lexeme": " ", "lexeme": " ",
"line": 51, "line": 51,
"column": 35 "column": 36
}, },
{ {
"type": "LESS", "type": "LESS",
"lexeme": "<", "lexeme": "<",
"line": 51, "line": 51,
"column": 36 "column": 37
}, },
{ {
"type": "WHITESPACE", "type": "WHITESPACE",
"lexeme": " ", "lexeme": " ",
"line": 51, "line": 51,
"column": 37 "column": 38
}, },
{ {
"type": "NUMBER", "type": "NUMBER",
"lexeme": "150", "lexeme": "150",
"line": 51, "line": 51,
"column": 38 "column": 39
},
{
"type": "RIGHT_PAREN",
"lexeme": ")",
"line": 51,
"column": 41
}, },
{ {
"type": "RIGHT_PAREN", "type": "RIGHT_PAREN",
@@ -2106,11 +2082,17 @@
"line": 51, "line": 51,
"column": 42 "column": 42
}, },
{
"type": "RIGHT_BRACKET",
"lexeme": "]",
"line": 51,
"column": 43
},
{ {
"type": "NEWLINE", "type": "NEWLINE",
"lexeme": "\n", "lexeme": "\n",
"line": 51, "line": 51,
"column": 43 "column": 44
}, },
{ {
"type": "NEWLINE", "type": "NEWLINE",
@@ -2651,12 +2633,12 @@
"_type": "PropertyStmt", "_type": "PropertyStmt",
"name": "age", "name": "age",
"type": { "type": {
"_type": "UnionType", "_type": "GenericType",
"types": [ "type": {
{
"_type": "NamedType", "_type": "NamedType",
"name": "None" "name": "Optional"
}, },
"params": [
{ {
"_type": "ConstraintType", "_type": "ConstraintType",
"type": { "type": {

View File

@@ -19,7 +19,6 @@ from midas.ast.midas import (
Type, Type,
TypeStmt, TypeStmt,
UnaryExpr, UnaryExpr,
UnionType,
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
) )
@@ -161,12 +160,6 @@ class MidasAstJsonSerializer(
"constraint": type.constraint.accept(self), "constraint": type.constraint.accept(self),
} }
def visit_union_type(self, type: UnionType) -> dict:
return {
"_type": "UnionType",
"types": self._serialize_list(type.types),
}
def visit_complex_type(self, type: ComplexType) -> dict: def visit_complex_type(self, type: ComplexType) -> dict:
return { return {
"_type": "ComplexType", "_type": "ComplexType",