Compare commits

...

4 Commits

9 changed files with 445 additions and 262 deletions

View File

@@ -0,0 +1,33 @@
type Foo1 = float
type Foo2 = float where (_ > 3)
type Foo3 = int | float
type Foo4 = int where (_ > 3) | float where (_ > 3)
type Foo5 = (int | float) where (_ > 3)
type Foo6 = {
foo: float
bar: float where (_ > 3)
}
type Foo7[T] = T where (_ > 3)
type Foo8[A, B<:int] = {
a: A
b: B
}
type Complex = {
a: int
b: int
}
type Complex2 = Complex where (_.a > 3 & _.b < 5)
predicate Positive(n: int) = n >= 0
extend Foo1 {
op __add__(Foo1) -> Foo1
}
extend Foo7[T] {
op __add__(Foo7[T]) -> Foo7[T]
}
type Optional[T] = None | T

View File

@@ -13,40 +13,38 @@ from midas.lexer.token import Token
###> Stmt | Statements ###> Stmt | Statements
class SimpleTypeStmt: class TypeStmt:
name: Token name: Token
template: Optional[TemplateExpr] params: list[Param]
base: TypeExpr type: Type
constraint: Optional[Expr]
@dataclass(frozen=True, kw_only=True)
class ComplexTypeStmt: class Param:
location: Location
name: Token name: Token
template: Optional[TemplateExpr] bound: Optional[Type]
properties: list[PropertyStmt]
class PropertyStmt: class PropertyStmt:
name: Token name: Token
type: TypeExpr type: Type
constraint: Optional[Expr]
class ExtendStmt: class ExtendStmt:
type: TypeExpr type: Type
operations: list[OpStmt] operations: list[OpStmt]
class OpStmt: class OpStmt:
name: Token name: Token
operand: TypeExpr operand: Type
result: TypeExpr result: Type
class PredicateStmt: class PredicateStmt:
name: Token name: Token
subject: Token subject: Token
type: TypeExpr type: Type
condition: Expr condition: Expr
@@ -54,9 +52,6 @@ class PredicateStmt:
###> Expr | Expressions ###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr: class LogicalExpr:
@@ -97,14 +92,31 @@ class WildcardExpr:
token: Token token: Token
class TemplateExpr: ###<
type: TypeExpr
###> Type | Types
class TypeExpr: class NamedType:
name: Token name: Token
template: Optional[TemplateExpr]
optional: bool
class GenericType:
type: Type
params: list[Type]
class ConstraintType:
type: Type
constraint: Expr
class UnionType:
types: list[Type]
class ComplexType:
properties: list[PropertyStmt]
###< ###<

View File

@@ -28,10 +28,7 @@ class Stmt(ABC):
class Visitor(ABC, Generic[T]): class Visitor(ABC, Generic[T]):
@abstractmethod @abstractmethod
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ... def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
@abstractmethod @abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
@@ -47,31 +44,25 @@ class Stmt(ABC):
@dataclass(frozen=True) @dataclass(frozen=True)
class SimpleTypeStmt(Stmt): class TypeStmt(Stmt):
name: Token name: Token
template: Optional[TemplateExpr] params: list[Param]
base: TypeExpr type: Type
constraint: Optional[Expr]
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_simple_type_stmt(self) return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class ComplexTypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_complex_type_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class PropertyStmt(Stmt): class PropertyStmt(Stmt):
name: Token name: Token
type: TypeExpr type: Type
constraint: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self) return visitor.visit_property_stmt(self)
@@ -79,7 +70,7 @@ class PropertyStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class ExtendStmt(Stmt): class ExtendStmt(Stmt):
type: TypeExpr type: Type
operations: list[OpStmt] operations: list[OpStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -89,8 +80,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class OpStmt(Stmt): class OpStmt(Stmt):
name: Token name: Token
operand: TypeExpr operand: Type
result: TypeExpr result: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_op_stmt(self) return visitor.visit_op_stmt(self)
@@ -100,7 +91,7 @@ class OpStmt(Stmt):
class PredicateStmt(Stmt): class PredicateStmt(Stmt):
name: Token name: Token
subject: Token subject: Token
type: TypeExpr type: Type
condition: Expr condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -120,9 +111,6 @@ class Expr(ABC):
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]): class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ... def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@@ -147,21 +135,6 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ... def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@dataclass(frozen=True)
class SimpleTypeExpr(Expr):
name: Token
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_simple_type_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class LogicalExpr(Expr): class LogicalExpr(Expr):
@@ -233,19 +206,72 @@ class WildcardExpr(Expr):
return visitor.visit_wildcard_expr(self) return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True) #########
class TemplateExpr(Expr): # Types #
type: TypeExpr #########
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_template_expr(self) @dataclass(frozen=True, kw_only=True)
class Type(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_named_type(self, type: NamedType) -> T: ...
@abstractmethod
def visit_generic_type(self, type: GenericType) -> T: ...
@abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_union_type(self, type: UnionType) -> T: ...
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class TypeExpr(Expr): class NamedType(Type):
name: Token name: Token
template: Optional[TemplateExpr]
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_type_expr(self) return visitor.visit_named_type(self)
@dataclass(frozen=True)
class GenericType(Type):
type: Type
params: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@dataclass(frozen=True)
class ConstraintType(Type):
type: Type
constraint: Expr
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class UnionType(Type):
types: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_union_type(self)
@dataclass(frozen=True)
class ComplexType(Type):
properties: list[PropertyStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)

View File

@@ -85,40 +85,39 @@ class AstPrinter(Generic[T]):
child.accept(self) child.accept(self)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): class MidasAstPrinter(
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
):
# Statements # Statements
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("SimpleTypeStmt") self._write_line("TypeStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template) self._write_line("params")
self._write_line("base")
with self._child_level(single=True):
stmt.base.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
self._write_line("ComplexTypeStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') for i, param in enumerate(stmt.params):
self._write_optional_child("template", stmt.template)
self._write_line("properties", last=True)
with self._child_level():
for i, prop in enumerate(stmt.properties):
self._idx = i self._idx = i
if i == len(stmt.properties) - 1: if i == len(stmt.params) - 1:
self._mark_last() self._mark_last()
prop.accept(self) self._print_type_stmt_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_property_stmt(self, stmt: m.PropertyStmt): def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt") self._write_line("PropertyStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type") self._write_line("type", last=True)
with self._child_level(single=True): with self._child_level(single=True):
stmt.type.accept(self) stmt.type.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt") self._write_line("ExtendStmt")
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
# Expressions # Expressions
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
self._write_line("SimpleTypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line(f"optional: {expr.optional}", last=True)
def visit_logical_expr(self, expr: m.LogicalExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr") self._write_line("LogicalExpr")
with self._child_level(): with self._child_level():
@@ -230,22 +223,59 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr") self._write_line("WildcardExpr")
def visit_template_expr(self, expr: m.TemplateExpr) -> None: def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("TemplateExpr") self._write_line("NamedType")
with self._child_level(single=True): with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level():
type.type.accept(self)
self._write_line("params", last=True)
with self._child_level():
for i, param in enumerate(type.params):
self._idx = i
if i == len(type.params) - 1:
self._mark_last()
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type") self._write_line("type")
with self._child_level(single=True): with self._child_level(single=True):
expr.type.accept(self) type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_type_expr(self, expr: m.TypeExpr): def visit_union_type(self, type: m.UnionType) -> None:
self._write_line("TypeExpr") self._write_line("UnionType")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"') self._write_line("types", last=True)
self._write_optional_child("template", expr.template) with self._child_level():
self._write_line(f"optional: {expr.optional}", last=True) for i, type_ in enumerate(type.types):
self._idx = i
if i == len(type.types) - 1:
self._mark_last()
type_.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_line("properties", last=True)
with self._child_level():
for i, prop in enumerate(type.properties):
self._idx = i
if i == len(type.properties) - 1:
self._mark_last()
prop.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4): def __init__(self, indent: int = 4):
self.indent: int = indent self.indent: int = indent
self.level: int = 0 self.level: int = 0
@@ -257,29 +287,24 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
self.level = 0 self.level = 0
return expr.accept(self) return expr.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = stmt.template.accept(self) if stmt.template is not None else "" template: str = ""
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})" if len(stmt.params) != 0:
if stmt.constraint is not None: params: list[str] = [
res += " where " + stmt.constraint.accept(self) self._print_type_template_param(param) for param in stmt.params
]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res) return self.indented(res)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt): def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
template: str = stmt.template.accept(self) if stmt.template is not None else "" res: str = param.name.lexeme
res: str = self.indented(f"type {stmt.name.lexeme}{template}") if param.bound is not None:
res += " {\n" res += "<:" + param.bound.accept(self)
self.level += 1
for prop in stmt.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res return res
def visit_property_stmt(self, stmt: m.PropertyStmt): def visit_property_stmt(self, stmt: m.PropertyStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
return self.indented(res) return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt): def visit_extend_stmt(self, stmt: m.ExtendStmt):
@@ -304,9 +329,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
condition: str = stmt.condition.accept(self) condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}") return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
def visit_logical_expr(self, expr: m.LogicalExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self) left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme operator: str = expr.operator.lexeme
@@ -342,12 +364,34 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_wildcard_expr(self, expr: m.WildcardExpr): def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_" return "_"
def visit_template_expr(self, expr: m.TemplateExpr): def visit_named_type(self, type: m.NamedType) -> str:
return f"[{expr.type.accept(self)}]" return type.name.lexeme
def visit_type_expr(self, expr: m.TypeExpr): def visit_generic_type(self, type: m.GenericType) -> str:
template: str = expr.template.accept(self) if expr.template is not None else "" res: str = type.type.accept(self)
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" if len(type.params) != 0:
params: list[str] = [param.accept(self) for param in type.params]
res += f"[{', '.join(params)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_union_type(self, type: m.UnionType) -> str:
types: list[str] = [type_.accept(self) for type_ in type.types]
return " | ".join(types)
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for prop in type.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
class PythonAstPrinter( class PythonAstPrinter(

View File

@@ -9,9 +9,9 @@ class BaseType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class SimpleType: class AliasType:
name: str name: str
base: BaseType | SimpleType type: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -39,4 +39,16 @@ class Function:
required: bool required: bool
Type = BaseType | SimpleType | UnknownType | UnitType | Function @dataclass(frozen=True, kw_only=True)
class ComplexType:
properties: dict[str, Type]
@dataclass(frozen=True, kw_only=True)
class UnionType:
alternatives: list[Type]
Type = (
BaseType | AliasType | UnknownType | UnitType | Function | ComplexType | UnionType
)

View File

@@ -18,6 +18,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.LEFT_BRACE) self.add_token(TokenType.LEFT_BRACE)
case "}": case "}":
self.add_token(TokenType.RIGHT_BRACE) self.add_token(TokenType.RIGHT_BRACE)
case "|":
self.add_token(TokenType.PIPE)
case "<": case "<":
self.add_token( self.add_token(
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
@@ -40,8 +42,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.AND) self.add_token(TokenType.AND)
case "?": case "?":
self.add_token(TokenType.QMARK) self.add_token(TokenType.QMARK)
# case ",": case ",":
# self.add_token(TokenType.COMMA) self.add_token(TokenType.COMMA)
case "_" if not self.is_identifier_char(self.peek_next(), start=False): case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE) self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"): case "-" if self.match(">"):

View File

@@ -17,12 +17,13 @@ class TokenType(Enum):
LEFT_BRACE = auto() LEFT_BRACE = auto()
RIGHT_BRACE = auto() RIGHT_BRACE = auto()
COLON = auto() COLON = auto()
# COMMA = auto() COMMA = auto()
UNDERSCORE = auto() UNDERSCORE = auto()
ARROW = auto() ARROW = auto()
AND = auto() AND = auto()
QMARK = auto() QMARK = auto()
DOT = auto() DOT = auto()
PIPE = auto()
# Operators # Operators
# PLUS = auto() # PLUS = auto()

View File

@@ -3,22 +3,24 @@ from typing import Optional
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexType,
ConstraintType,
Expr, Expr,
ExtendStmt, ExtendStmt,
GenericType,
GetExpr, GetExpr,
GroupingExpr, GroupingExpr,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
NamedType,
OpStmt, OpStmt,
PredicateStmt, PredicateStmt,
PropertyStmt, PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt, Stmt,
TemplateExpr, Type,
TypeExpr, TypeStmt,
UnaryExpr, UnaryExpr,
UnionType,
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
) )
@@ -81,7 +83,7 @@ class MidasParser(Parser):
self.synchronize() self.synchronize()
return None return None
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt: def type_declaration(self) -> TypeStmt:
"""Parse a type declaration """Parse a type declaration
A type declaration can either be a simple type alias or a new complex type. A type declaration can either be a simple type alias or a new complex type.
@@ -107,33 +109,22 @@ class MidasParser(Parser):
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None params: list[TypeStmt.Param] = []
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr() params = self.type_stmt_params()
if self.match(TokenType.LEFT_PAREN): self.consume(TokenType.EQUAL, "Expected '=' before type definition")
base: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis") type: Type = self.type_expr()
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE): return TypeStmt(
constraint = self.constraint()
return SimpleTypeStmt(
location=keyword.location_to(self.previous()), location=keyword.location_to(self.previous()),
name=name, name=name,
template=template, params=params,
base=base, type=type,
constraint=constraint,
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
) )
def template_expr(self) -> TemplateExpr: def type_stmt_params(self) -> list[TypeStmt.Param]:
"""Parse a generic template expression """Parse a generic template expression
A template is written `[TypeExpr]` A template is written `[TypeExpr]`
@@ -141,16 +132,27 @@ class MidasParser(Parser):
Returns: Returns:
TemplateExpr: the parsed template expression TemplateExpr: the parsed template expression
""" """
left: Token = self.consume( self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
TokenType.LEFT_BRACKET, "Missing '[' before template expression" params: list[TypeStmt.Param] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeStmt.Param(
location=name.location_to(self.previous()),
name=name,
bound=bound,
) )
type: TypeExpr = self.type_expr()
right: Token = self.consume(
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
) )
return TemplateExpr(location=left.location_to(right), type=type) if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
return params
def type_expr(self) -> TypeExpr: def type_expr(self) -> Type:
"""Parse a type expression """Parse a type expression
A type is an identifier, optionally followed by a template expression. A type is an identifier, optionally followed by a template expression.
@@ -159,30 +161,93 @@ class MidasParser(Parser):
Returns: Returns:
TypeExpr: the parsed type expression TypeExpr: the parsed type expression
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") return self.union_type()
template: Optional[TemplateExpr] = None
if self.check(TokenType.LEFT_BRACKET): def union_type(self) -> Type:
template = self.template_expr() types: list[Type] = [self.constraint_type()]
optional: bool = self.match(TokenType.QMARK) while self.match(TokenType.PIPE):
return TypeExpr( types.append(self.constraint_type())
location=name.location_to(self.previous()), if len(types) == 1:
name=name, return types[0]
template=template, return UnionType(
optional=optional, location=Location.span(types[0].location, types[-1].location),
types=types,
) )
def simple_type_expr(self) -> SimpleTypeExpr: def constraint_type(self) -> Type:
"""Parse a simple type expression type: Type = self.base_type()
if self.match(TokenType.WHERE):
constraint: Expr = self.constraint()
return ConstraintType(
location=Location.span(type.location, constraint.location),
type=type,
constraint=constraint,
)
return type
A simple type is just an identifier optionally followed by a '?' def base_type(self) -> Type:
if self.match(TokenType.LEFT_PAREN):
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return type
if self.check(TokenType.LEFT_BRACE):
return self.complex_type()
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
params: list[Type] = self.type_params()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
params=params,
)
return type
def type_params(self) -> list[Type]:
params: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
return params
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
return NamedType(
location=name.get_location(),
name=name,
)
def complex_type(self) -> Type:
"""Parse a type definition body
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns: Returns:
SimpleTypeExpr: the parsed simple type expression list[PropertyStmt]: the parsed type properties
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") left: Token = self.consume(
optional: bool = self.match(TokenType.QMARK) TokenType.LEFT_BRACE, "Expected '{' to start type body"
return SimpleTypeExpr( )
location=name.location_to(self.previous()), name=name, optional=optional properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
properties=properties,
) )
def constraint(self) -> Expr: def constraint(self) -> Expr:
@@ -308,27 +373,6 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression") raise self.error(self.peek(), "Expected expression")
def type_properties(self) -> list[PropertyStmt]:
"""Parse a type definition body
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
list[PropertyStmt]: the parsed type properties
"""
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return properties
def property_stmt(self) -> PropertyStmt: def property_stmt(self) -> PropertyStmt:
"""Parse a property statement """Parse a property statement
@@ -339,15 +383,11 @@ class MidasParser(Parser):
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name") self.consume(TokenType.COLON, "Expected ':' after property name")
type: TypeExpr = self.type_expr() type: Type = self.type_expr()
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return PropertyStmt( return PropertyStmt(
location=name.location_to(self.previous()), location=name.location_to(self.previous()),
name=name, name=name,
type=type, type=type,
constraint=constraint,
) )
def extend_declaration(self) -> ExtendStmt: def extend_declaration(self) -> ExtendStmt:
@@ -359,7 +399,7 @@ class MidasParser(Parser):
ExtendStmt: the parsed extension statement ExtendStmt: the parsed extension statement
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
type: TypeExpr = self.type_expr() type: Type = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = [] operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
@@ -380,11 +420,11 @@ class MidasParser(Parser):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
operand: TypeExpr = self.type_expr() operand: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
self.consume(TokenType.ARROW, "Expected '->' before result type") self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr() result: Type = self.type_expr()
return OpStmt( return OpStmt(
location=keyword.location_to(self.previous()), location=keyword.location_to(self.previous()),
@@ -406,7 +446,7 @@ class MidasParser(Parser):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name") self.consume(TokenType.COLON, "Expected ':' after subject name")
type: TypeExpr = self.type_expr() type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint() condition: Expr = self.constraint()

View File

@@ -1,11 +1,15 @@
from typing import Optional from typing import Optional
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.types import BaseType, SimpleType, Type from midas.checker.types import (
Type,
UnionType,
UnknownType,
)
from midas.resolver.builtin import define_builtins from midas.resolver.builtin import define_builtins
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]): class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry""" """A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self) -> None: def __init__(self) -> None:
@@ -94,20 +98,12 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
for stmt in stmts: for stmt in stmts:
stmt.accept(self) stmt.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
# TODO generics, optional, constraint type: Type = stmt.type.accept(self)
base: Type = self.get_type(stmt.base.name.lexeme) for param in stmt.params:
match base: if param.bound is not None:
case BaseType() | SimpleType(): param.bound.accept(self)
type = SimpleType( self.define_type(stmt.name.lexeme, type)
name=stmt.name.lexeme,
base=base,
)
self.define_type(type.name, type)
case _:
raise TypeError(f"Invalid base {base} for simple type")
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
@@ -127,27 +123,44 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type: def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
return self.get_type(expr.name.lexeme)
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ... def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ... def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ... def visit_get_expr(self, expr: m.GetExpr) -> None: ...
def visit_get_expr(self, expr: m.GetExpr) -> Type: ... def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ... def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self) return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ... def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ... def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ... def visit_named_type(self, type: m.NamedType) -> Type:
return self.get_type(type.name.lexeme)
def visit_type_expr(self, expr: m.TypeExpr) -> Type: def visit_generic_type(self, type: m.GenericType) -> Type:
return self.get_type(expr.name.lexeme) type_: Type = type.type.accept(self)
params: list[Type] = [param.accept(self) for param in type.params]
# TODO
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
def visit_union_type(self, type: m.UnionType) -> Type:
types: list[Type] = [type_.accept(self) for type_ in type.types]
return UnionType(alternatives=types)
def visit_complex_type(self, type: m.ComplexType) -> Type:
for prop in type.properties:
prop.accept(self)
# TODO
return UnknownType()