Compare commits
4 Commits
bea3f399ad
...
7f3d74ee49
| Author | SHA1 | Date | |
|---|---|---|---|
|
7f3d74ee49
|
|||
|
b9f378de6f
|
|||
|
ccb17c7290
|
|||
|
505779310a
|
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
33
examples/00_syntax_prototype/05_custom_types_v3.midas
Normal file
@@ -0,0 +1,33 @@
|
||||
type Foo1 = float
|
||||
type Foo2 = float where (_ > 3)
|
||||
type Foo3 = int | float
|
||||
type Foo4 = int where (_ > 3) | float where (_ > 3)
|
||||
type Foo5 = (int | float) where (_ > 3)
|
||||
type Foo6 = {
|
||||
foo: float
|
||||
bar: float where (_ > 3)
|
||||
}
|
||||
|
||||
type Foo7[T] = T where (_ > 3)
|
||||
type Foo8[A, B<:int] = {
|
||||
a: A
|
||||
b: B
|
||||
}
|
||||
|
||||
type Complex = {
|
||||
a: int
|
||||
b: int
|
||||
}
|
||||
type Complex2 = Complex where (_.a > 3 & _.b < 5)
|
||||
|
||||
predicate Positive(n: int) = n >= 0
|
||||
|
||||
extend Foo1 {
|
||||
op __add__(Foo1) -> Foo1
|
||||
}
|
||||
|
||||
extend Foo7[T] {
|
||||
op __add__(Foo7[T]) -> Foo7[T]
|
||||
}
|
||||
|
||||
type Optional[T] = None | T
|
||||
58
gen/midas.py
58
gen/midas.py
@@ -13,40 +13,38 @@ from midas.lexer.token import Token
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class SimpleTypeStmt:
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
base: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
|
||||
class ComplexTypeStmt:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
properties: list[PropertyStmt]
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
type: Type
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
|
||||
class OpStmt:
|
||||
name: Token
|
||||
operand: TypeExpr
|
||||
result: TypeExpr
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
|
||||
class PredicateStmt:
|
||||
name: Token
|
||||
subject: Token
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
|
||||
@@ -54,9 +52,6 @@ class PredicateStmt:
|
||||
|
||||
|
||||
###> Expr | Expressions
|
||||
class SimpleTypeExpr:
|
||||
name: Token
|
||||
optional: bool
|
||||
|
||||
|
||||
class LogicalExpr:
|
||||
@@ -97,14 +92,31 @@ class WildcardExpr:
|
||||
token: Token
|
||||
|
||||
|
||||
class TemplateExpr:
|
||||
type: TypeExpr
|
||||
###<
|
||||
|
||||
###> Type | Types
|
||||
|
||||
|
||||
class TypeExpr:
|
||||
class NamedType:
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
optional: bool
|
||||
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
params: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
|
||||
class UnionType:
|
||||
types: list[Type]
|
||||
|
||||
|
||||
class ComplexType:
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -28,10 +28,7 @@ class Stmt(ABC):
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
@@ -47,31 +44,25 @@ class Stmt(ABC):
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimpleTypeStmt(Stmt):
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
base: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
params: list[Param]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_simple_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexTypeStmt(Stmt):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type_stmt(self)
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
name: Token
|
||||
type: TypeExpr
|
||||
constraint: Optional[Expr]
|
||||
type: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
@@ -79,7 +70,7 @@ class PropertyStmt(Stmt):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
@@ -89,8 +80,8 @@ class ExtendStmt(Stmt):
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
name: Token
|
||||
operand: TypeExpr
|
||||
result: TypeExpr
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
@@ -100,7 +91,7 @@ class OpStmt(Stmt):
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
subject: Token
|
||||
type: TypeExpr
|
||||
type: Type
|
||||
condition: Expr
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
@@ -120,9 +111,6 @@ class Expr(ABC):
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@@ -147,21 +135,6 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_type_expr(self, expr: TypeExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimpleTypeExpr(Expr):
|
||||
name: Token
|
||||
optional: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_simple_type_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LogicalExpr(Expr):
|
||||
@@ -233,19 +206,72 @@ class WildcardExpr(Expr):
|
||||
return visitor.visit_wildcard_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateExpr(Expr):
|
||||
type: TypeExpr
|
||||
#########
|
||||
# Types #
|
||||
#########
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_template_expr(self)
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Type(ABC):
|
||||
location: Location
|
||||
|
||||
@abstractmethod
|
||||
def accept(self, visitor: Visitor[T]) -> T: ...
|
||||
|
||||
class Visitor(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def visit_named_type(self, type: NamedType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_generic_type(self, type: GenericType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_union_type(self, type: UnionType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeExpr(Expr):
|
||||
class NamedType(Type):
|
||||
name: Token
|
||||
template: Optional[TemplateExpr]
|
||||
optional: bool
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_type_expr(self)
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_named_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
params: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstraintType(Type):
|
||||
type: Type
|
||||
constraint: Expr
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnionType(Type):
|
||||
types: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_union_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexType(Type):
|
||||
properties: list[PropertyStmt]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
|
||||
@@ -85,40 +85,39 @@ class AstPrinter(Generic[T]):
|
||||
child.accept(self)
|
||||
|
||||
|
||||
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
class MidasAstPrinter(
|
||||
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
# Statements
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||
self._write_line("SimpleTypeStmt")
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self._write_line("TypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("template", stmt.template)
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
stmt.base.accept(self)
|
||||
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||
self._write_line("ComplexTypeStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_optional_child("template", stmt.template)
|
||||
self._write_line("properties", last=True)
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(stmt.properties):
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.properties) - 1:
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
self._print_type_stmt_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type")
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_optional_child("constraint", stmt.constraint, last=True)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
|
||||
# Expressions
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||
self._write_line("SimpleTypeExpr")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_line(f"optional: {expr.optional}", last=True)
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
self._write_line("LogicalExpr")
|
||||
with self._child_level():
|
||||
@@ -230,22 +223,59 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self._write_line("WildcardExpr")
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||
self._write_line("TemplateExpr")
|
||||
with self._child_level(single=True):
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self._write_line("NamedType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{type.name.lexeme}"', last=True)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self._write_line("GenericType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("params", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.params):
|
||||
self._idx = i
|
||||
if i == len(type.params) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self._write_line("ConstraintType")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
expr.type.accept(self)
|
||||
type.type.accept(self)
|
||||
self._write_line("constraint", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
self._write_line("TypeExpr")
|
||||
def visit_union_type(self, type: m.UnionType) -> None:
|
||||
self._write_line("UnionType")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{expr.name.lexeme}"')
|
||||
self._write_optional_child("template", expr.template)
|
||||
self._write_line(f"optional: {expr.optional}", last=True)
|
||||
self._write_line("types", last=True)
|
||||
with self._child_level():
|
||||
for i, type_ in enumerate(type.types):
|
||||
self._idx = i
|
||||
if i == len(type.types) - 1:
|
||||
self._mark_last()
|
||||
type_.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_line("properties", last=True)
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(type.properties):
|
||||
self._idx = i
|
||||
if i == len(type.properties) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
def __init__(self, indent: int = 4):
|
||||
self.indent: int = indent
|
||||
self.level: int = 0
|
||||
@@ -257,29 +287,24 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
|
||||
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
|
||||
if stmt.constraint is not None:
|
||||
res += " where " + stmt.constraint.accept(self)
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [
|
||||
self._print_type_template_param(param) for param in stmt.params
|
||||
]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
|
||||
template: str = stmt.template.accept(self) if stmt.template is not None else ""
|
||||
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for prop in stmt.properties:
|
||||
res += prop.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
if stmt.constraint is not None:
|
||||
res += " where " + stmt.constraint.accept(self)
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
@@ -304,9 +329,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
condition: str = stmt.condition.accept(self)
|
||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
|
||||
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||
left: str = expr.left.accept(self)
|
||||
operator: str = expr.operator.lexeme
|
||||
@@ -342,12 +364,34 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr):
|
||||
return "_"
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr):
|
||||
return f"[{expr.type.accept(self)}]"
|
||||
def visit_named_type(self, type: m.NamedType) -> str:
|
||||
return type.name.lexeme
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr):
|
||||
template: str = expr.template.accept(self) if expr.template is not None else ""
|
||||
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.params) != 0:
|
||||
params: list[str] = [param.accept(self) for param in type.params]
|
||||
res += f"[{', '.join(params)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
res += " where " + type.constraint.accept(self)
|
||||
return res
|
||||
|
||||
def visit_union_type(self, type: m.UnionType) -> str:
|
||||
types: list[str] = [type_.accept(self) for type_ in type.types]
|
||||
return " | ".join(types)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||
res: str = "{\n"
|
||||
self.level += 1
|
||||
for prop in type.properties:
|
||||
res += prop.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
@@ -600,11 +644,11 @@ class PythonAstPrinter(
|
||||
self._write_line("test")
|
||||
with self._child_level(single=True):
|
||||
expr.test.accept(self)
|
||||
|
||||
|
||||
self._write_line("if_true")
|
||||
with self._child_level(single=True):
|
||||
expr.if_true.accept(self)
|
||||
|
||||
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
@@ -9,9 +9,9 @@ class BaseType:
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class SimpleType:
|
||||
class AliasType:
|
||||
name: str
|
||||
base: BaseType | SimpleType
|
||||
type: Type
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -39,4 +39,16 @@ class Function:
|
||||
required: bool
|
||||
|
||||
|
||||
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ComplexType:
|
||||
properties: dict[str, Type]
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnionType:
|
||||
alternatives: list[Type]
|
||||
|
||||
|
||||
Type = (
|
||||
BaseType | AliasType | UnknownType | UnitType | Function | ComplexType | UnionType
|
||||
)
|
||||
|
||||
@@ -18,6 +18,8 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.LEFT_BRACE)
|
||||
case "}":
|
||||
self.add_token(TokenType.RIGHT_BRACE)
|
||||
case "|":
|
||||
self.add_token(TokenType.PIPE)
|
||||
case "<":
|
||||
self.add_token(
|
||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
||||
@@ -40,8 +42,8 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.AND)
|
||||
case "?":
|
||||
self.add_token(TokenType.QMARK)
|
||||
# case ",":
|
||||
# self.add_token(TokenType.COMMA)
|
||||
case ",":
|
||||
self.add_token(TokenType.COMMA)
|
||||
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
|
||||
self.add_token(TokenType.UNDERSCORE)
|
||||
case "-" if self.match(">"):
|
||||
|
||||
@@ -17,12 +17,13 @@ class TokenType(Enum):
|
||||
LEFT_BRACE = auto()
|
||||
RIGHT_BRACE = auto()
|
||||
COLON = auto()
|
||||
# COMMA = auto()
|
||||
COMMA = auto()
|
||||
UNDERSCORE = auto()
|
||||
ARROW = auto()
|
||||
AND = auto()
|
||||
QMARK = auto()
|
||||
DOT = auto()
|
||||
PIPE = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
|
||||
@@ -3,22 +3,24 @@ from typing import Optional
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexTypeStmt,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
NamedType,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
SimpleTypeExpr,
|
||||
SimpleTypeStmt,
|
||||
Stmt,
|
||||
TemplateExpr,
|
||||
TypeExpr,
|
||||
Type,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
UnionType,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
@@ -81,7 +83,7 @@ class MidasParser(Parser):
|
||||
self.synchronize()
|
||||
return None
|
||||
|
||||
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
|
||||
def type_declaration(self) -> TypeStmt:
|
||||
"""Parse a type declaration
|
||||
|
||||
A type declaration can either be a simple type alias or a new complex type.
|
||||
@@ -107,33 +109,22 @@ class MidasParser(Parser):
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
template: Optional[TemplateExpr] = None
|
||||
params: list[TypeStmt.Param] = []
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
template = self.template_expr()
|
||||
params = self.type_stmt_params()
|
||||
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
base: TypeExpr = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
|
||||
constraint: Optional[Expr] = None
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint = self.constraint()
|
||||
return SimpleTypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
base=base,
|
||||
constraint=constraint,
|
||||
)
|
||||
else:
|
||||
properties: list[PropertyStmt] = self.type_properties()
|
||||
return ComplexTypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
properties=properties,
|
||||
)
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
|
||||
def template_expr(self) -> TemplateExpr:
|
||||
type: Type = self.type_expr()
|
||||
|
||||
return TypeStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
name=name,
|
||||
params=params,
|
||||
type=type,
|
||||
)
|
||||
|
||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
||||
"""Parse a generic template expression
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
@@ -141,16 +132,27 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
|
||||
)
|
||||
type: TypeExpr = self.type_expr()
|
||||
right: Token = self.consume(
|
||||
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
|
||||
)
|
||||
return TemplateExpr(location=left.location_to(right), type=type)
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
||||
params: list[TypeStmt.Param] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
if self.match(TokenType.LESS):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
bound = self.type_expr()
|
||||
params.append(
|
||||
TypeStmt.Param(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
bound=bound,
|
||||
)
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
||||
return params
|
||||
|
||||
def type_expr(self) -> TypeExpr:
|
||||
def type_expr(self) -> Type:
|
||||
"""Parse a type expression
|
||||
|
||||
A type is an identifier, optionally followed by a template expression.
|
||||
@@ -159,30 +161,93 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
template: Optional[TemplateExpr] = None
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
template = self.template_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
return TypeExpr(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
template=template,
|
||||
optional=optional,
|
||||
return self.union_type()
|
||||
|
||||
def union_type(self) -> Type:
|
||||
types: list[Type] = [self.constraint_type()]
|
||||
while self.match(TokenType.PIPE):
|
||||
types.append(self.constraint_type())
|
||||
if len(types) == 1:
|
||||
return types[0]
|
||||
return UnionType(
|
||||
location=Location.span(types[0].location, types[-1].location),
|
||||
types=types,
|
||||
)
|
||||
|
||||
def simple_type_expr(self) -> SimpleTypeExpr:
|
||||
"""Parse a simple type expression
|
||||
def constraint_type(self) -> Type:
|
||||
type: Type = self.base_type()
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint: Expr = self.constraint()
|
||||
return ConstraintType(
|
||||
location=Location.span(type.location, constraint.location),
|
||||
type=type,
|
||||
constraint=constraint,
|
||||
)
|
||||
return type
|
||||
|
||||
A simple type is just an identifier optionally followed by a '?'
|
||||
def base_type(self) -> Type:
|
||||
if self.match(TokenType.LEFT_PAREN):
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
|
||||
return type
|
||||
|
||||
if self.check(TokenType.LEFT_BRACE):
|
||||
return self.complex_type()
|
||||
|
||||
return self.generic_type()
|
||||
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params: list[Type] = self.type_params()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
type=type,
|
||||
params=params,
|
||||
)
|
||||
return type
|
||||
|
||||
def type_params(self) -> list[Type]:
|
||||
params: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
params.append(self.type_expr())
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
||||
return params
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
name=name,
|
||||
)
|
||||
|
||||
def complex_type(self) -> Type:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
SimpleTypeExpr: the parsed simple type expression
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
return SimpleTypeExpr(
|
||||
location=name.location_to(self.previous()), name=name, optional=optional
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||
)
|
||||
properties: list[PropertyStmt] = []
|
||||
names: set[str] = set()
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
prop: PropertyStmt = self.property_stmt()
|
||||
if prop.name.lexeme in names:
|
||||
raise self.error(prop.name, "Duplicate property")
|
||||
names.add(prop.name.lexeme)
|
||||
properties.append(prop)
|
||||
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return ComplexType(
|
||||
location=left.location_to(right),
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
@@ -308,27 +373,6 @@ class MidasParser(Parser):
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def type_properties(self) -> list[PropertyStmt]:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
|
||||
properties: list[PropertyStmt] = []
|
||||
names: set[str] = set()
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
prop: PropertyStmt = self.property_stmt()
|
||||
if prop.name.lexeme in names:
|
||||
raise self.error(prop.name, "Duplicate property")
|
||||
names.add(prop.name.lexeme)
|
||||
properties.append(prop)
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return properties
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
|
||||
@@ -339,15 +383,11 @@ class MidasParser(Parser):
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
constraint: Optional[Expr] = None
|
||||
if self.match(TokenType.WHERE):
|
||||
constraint = self.constraint()
|
||||
type: Type = self.type_expr()
|
||||
return PropertyStmt(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
constraint=constraint,
|
||||
)
|
||||
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
@@ -359,7 +399,7 @@ class MidasParser(Parser):
|
||||
ExtendStmt: the parsed extension statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
type: TypeExpr = self.type_expr()
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||
operations: list[OpStmt] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||
@@ -380,11 +420,11 @@ class MidasParser(Parser):
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
||||
operand: TypeExpr = self.type_expr()
|
||||
operand: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: TypeExpr = self.type_expr()
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return OpStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
@@ -406,7 +446,7 @@ class MidasParser(Parser):
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: TypeExpr = self.type_expr()
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||
condition: Expr = self.constraint()
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import BaseType, SimpleType, Type
|
||||
from midas.checker.types import (
|
||||
Type,
|
||||
UnionType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.resolver.builtin import define_builtins
|
||||
|
||||
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -94,20 +98,12 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||
# TODO generics, optional, constraint
|
||||
base: Type = self.get_type(stmt.base.name.lexeme)
|
||||
match base:
|
||||
case BaseType() | SimpleType():
|
||||
type = SimpleType(
|
||||
name=stmt.name.lexeme,
|
||||
base=base,
|
||||
)
|
||||
self.define_type(type.name, type)
|
||||
case _:
|
||||
raise TypeError(f"Invalid base {base} for simple type")
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
type: Type = stmt.type.accept(self)
|
||||
for param in stmt.params:
|
||||
if param.bound is not None:
|
||||
param.bound.accept(self)
|
||||
self.define_type(stmt.name.lexeme, type)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
@@ -127,27 +123,44 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type:
|
||||
return self.get_type(expr.name.lexeme)
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ...
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ...
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ...
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> Type: ...
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ...
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ...
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ...
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
return self.get_type(type.name.lexeme)
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr) -> Type:
|
||||
return self.get_type(expr.name.lexeme)
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
params: list[Type] = [param.accept(self) for param in type.params]
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_union_type(self, type: m.UnionType) -> Type:
|
||||
types: list[Type] = [type_.accept(self) for type_ in type.types]
|
||||
return UnionType(alternatives=types)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
Reference in New Issue
Block a user