Merge pull request 'Improve syntax and types' (#5) from feat/improve-syntax-and-types into feat/basic-type-checker

Reviewed-on: #5
This commit was merged in pull request #5.
This commit is contained in:
2026-06-05 09:20:56 +00:00
31 changed files with 1119 additions and 881 deletions

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821] # ruff: disable[F821]
from __future__ import annotations from __future__ import annotations
# Prototype of custom type import to use valid Python syntax
import midas
midas.using("02_custom_types.midas")
# A data-frame using a custom type # A data-frame using a custom type
df: Frame[ df: Frame[
location: GeoLocation location: GeoLocation

View File

@@ -0,0 +1,33 @@
type Foo1 = float
type Foo2 = float where (_ > 3)
type Foo3 = int | float
type Foo4 = int where (_ > 3) | float where (_ > 3)
type Foo5 = (int | float) where (_ > 3)
type Foo6 = {
foo: float
bar: float where (_ > 3)
}
type Foo7[T] = T where (_ > 3)
type Foo8[A, B<:int] = {
a: A
b: B
}
type Complex = {
a: int
b: int
}
type Complex2 = Complex where (_.a > 3 & _.b < 5)
predicate Positive(n: int) = n >= 0
extend Foo1 {
op __add__(Foo1) -> Foo1
}
extend Foo7[T] {
op __add__(Foo7[T]) -> Foo7[T]
}
type Optional[T] = None | T

View File

@@ -1,6 +1,6 @@
type Meter(float) type Meter = float
type Second(float) type Second = float
type MeterPerSecond(float) type MeterPerSecond = float
extend Meter { extend Meter {
op __add__(Meter) -> Meter op __add__(Meter) -> Meter

View File

@@ -1,8 +1,6 @@
# type: ignore # type: ignore
# ruff: disable [F821] # ruff: disable [F821]
midas.using("02_simple_types.midas")
distance: Meter = cast(Meter, 123.45) distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7) time: Second = cast(Second, 6.7)
speed = distance / time speed = distance / time

View File

@@ -4,13 +4,20 @@ def minimum(x: int, y: int):
else: else:
return y return y
a = 15 a = 15
b = 72 b = 72
c = minimum(a, b) c = minimum(a, b)
def factorial(n: int) -> int: def factorial(n: int) -> int:
if n <= 1: if n <= 1:
return 1 return 1
return n * factorial(n - 1) return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"
category = "Category 1" if a < 10 else "Category 2"
def foo() -> None:
pass

View File

@@ -13,40 +13,38 @@ from midas.lexer.token import Token
###> Stmt | Statements ###> Stmt | Statements
class SimpleTypeStmt: class TypeStmt:
name: Token name: Token
template: Optional[TemplateExpr] params: list[Param]
base: TypeExpr type: Type
constraint: Optional[Expr]
@dataclass(frozen=True, kw_only=True)
class ComplexTypeStmt: class Param:
name: Token location: Location
template: Optional[TemplateExpr] name: Token
properties: list[PropertyStmt] bound: Optional[Type]
class PropertyStmt: class PropertyStmt:
name: Token name: Token
type: TypeExpr type: Type
constraint: Optional[Expr]
class ExtendStmt: class ExtendStmt:
type: TypeExpr type: Type
operations: list[OpStmt] operations: list[OpStmt]
class OpStmt: class OpStmt:
name: Token name: Token
operand: TypeExpr operand: Type
result: TypeExpr result: Type
class PredicateStmt: class PredicateStmt:
name: Token name: Token
subject: Token subject: Token
type: TypeExpr type: Type
condition: Expr condition: Expr
@@ -54,9 +52,6 @@ class PredicateStmt:
###> Expr | Expressions ###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr: class LogicalExpr:
@@ -97,14 +92,27 @@ class WildcardExpr:
token: Token token: Token
class TemplateExpr: ###<
type: TypeExpr
###> Type | Types
class TypeExpr: class NamedType:
name: Token name: Token
template: Optional[TemplateExpr]
optional: bool
class GenericType:
type: Type
params: list[Type]
class ConstraintType:
type: Type
constraint: Expr
class ComplexType:
properties: list[PropertyStmt]
###< ###<

View File

@@ -28,10 +28,7 @@ class Stmt(ABC):
class Visitor(ABC, Generic[T]): class Visitor(ABC, Generic[T]):
@abstractmethod @abstractmethod
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ... def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
@abstractmethod @abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ... def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
@@ -47,31 +44,25 @@ class Stmt(ABC):
@dataclass(frozen=True) @dataclass(frozen=True)
class SimpleTypeStmt(Stmt): class TypeStmt(Stmt):
name: Token name: Token
template: Optional[TemplateExpr] params: list[Param]
base: TypeExpr type: Type
constraint: Optional[Expr]
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_simple_type_stmt(self) return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class ComplexTypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_complex_type_stmt(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class PropertyStmt(Stmt): class PropertyStmt(Stmt):
name: Token name: Token
type: TypeExpr type: Type
constraint: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self) return visitor.visit_property_stmt(self)
@@ -79,7 +70,7 @@ class PropertyStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class ExtendStmt(Stmt): class ExtendStmt(Stmt):
type: TypeExpr type: Type
operations: list[OpStmt] operations: list[OpStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -89,8 +80,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True) @dataclass(frozen=True)
class OpStmt(Stmt): class OpStmt(Stmt):
name: Token name: Token
operand: TypeExpr operand: Type
result: TypeExpr result: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_op_stmt(self) return visitor.visit_op_stmt(self)
@@ -100,7 +91,7 @@ class OpStmt(Stmt):
class PredicateStmt(Stmt): class PredicateStmt(Stmt):
name: Token name: Token
subject: Token subject: Token
type: TypeExpr type: Type
condition: Expr condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T: def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -120,9 +111,6 @@ class Expr(ABC):
def accept(self, visitor: Visitor[T]) -> T: ... def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]): class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
@abstractmethod @abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ... def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@@ -147,21 +135,6 @@ class Expr(ABC):
@abstractmethod @abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ... def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@dataclass(frozen=True)
class SimpleTypeExpr(Expr):
name: Token
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_simple_type_expr(self)
@dataclass(frozen=True) @dataclass(frozen=True)
class LogicalExpr(Expr): class LogicalExpr(Expr):
@@ -233,19 +206,61 @@ class WildcardExpr(Expr):
return visitor.visit_wildcard_expr(self) return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True) #########
class TemplateExpr(Expr): # Types #
type: TypeExpr #########
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_template_expr(self) @dataclass(frozen=True, kw_only=True)
class Type(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_named_type(self, type: NamedType) -> T: ...
@abstractmethod
def visit_generic_type(self, type: GenericType) -> T: ...
@abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class TypeExpr(Expr): class NamedType(Type):
name: Token name: Token
template: Optional[TemplateExpr]
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T: def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_type_expr(self) return visitor.visit_named_type(self)
@dataclass(frozen=True)
class GenericType(Type):
type: Type
params: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@dataclass(frozen=True)
class ConstraintType(Type):
type: Type
constraint: Expr
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class ComplexType(Type):
properties: list[PropertyStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)

View File

@@ -85,40 +85,39 @@ class AstPrinter(Generic[T]):
child.accept(self) child.accept(self)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]): class MidasAstPrinter(
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
):
# Statements # Statements
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("SimpleTypeStmt") self._write_line("TypeStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template) self._write_line("params")
self._write_line("base")
with self._child_level(single=True):
stmt.base.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
self._write_line("ComplexTypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("properties", last=True)
with self._child_level(): with self._child_level():
for i, prop in enumerate(stmt.properties): for i, param in enumerate(stmt.params):
self._idx = i self._idx = i
if i == len(stmt.properties) - 1: if i == len(stmt.params) - 1:
self._mark_last() self._mark_last()
prop.accept(self) self._print_type_stmt_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_property_stmt(self, stmt: m.PropertyStmt): def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt") self._write_line("PropertyStmt")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"') self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type") self._write_line("type", last=True)
with self._child_level(single=True): with self._child_level(single=True):
stmt.type.accept(self) stmt.type.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt") self._write_line("ExtendStmt")
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
# Expressions # Expressions
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
self._write_line("SimpleTypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line(f"optional: {expr.optional}", last=True)
def visit_logical_expr(self, expr: m.LogicalExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr") self._write_line("LogicalExpr")
with self._child_level(): with self._child_level():
@@ -230,22 +223,48 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr") self._write_line("WildcardExpr")
def visit_template_expr(self, expr: m.TemplateExpr) -> None: def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("TemplateExpr") self._write_line("NamedType")
with self._child_level(single=True): with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level():
type.type.accept(self)
self._write_line("params", last=True)
with self._child_level():
for i, param in enumerate(type.params):
self._idx = i
if i == len(type.params) - 1:
self._mark_last()
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type") self._write_line("type")
with self._child_level(single=True): with self._child_level(single=True):
expr.type.accept(self) type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_type_expr(self, expr: m.TypeExpr): def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("TypeExpr") self._write_line("ComplexType")
with self._child_level(): with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"') self._write_line("properties", last=True)
self._write_optional_child("template", expr.template) with self._child_level():
self._write_line(f"optional: {expr.optional}", last=True) for i, prop in enumerate(type.properties):
self._idx = i
if i == len(type.properties) - 1:
self._mark_last()
prop.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]): class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4): def __init__(self, indent: int = 4):
self.indent: int = indent self.indent: int = indent
self.level: int = 0 self.level: int = 0
@@ -253,33 +272,28 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def indented(self, text: str) -> str: def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt): def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0 self.level = 0
return expr.accept(self) return expr.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt): def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = stmt.template.accept(self) if stmt.template is not None else "" template: str = ""
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})" if len(stmt.params) != 0:
if stmt.constraint is not None: params: list[str] = [
res += " where " + stmt.constraint.accept(self) self._print_type_template_param(param) for param in stmt.params
]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res) return self.indented(res)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt): def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
template: str = stmt.template.accept(self) if stmt.template is not None else "" res: str = param.name.lexeme
res: str = self.indented(f"type {stmt.name.lexeme}{template}") if param.bound is not None:
res += " {\n" res += "<:" + param.bound.accept(self)
self.level += 1
for prop in stmt.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res return res
def visit_property_stmt(self, stmt: m.PropertyStmt): def visit_property_stmt(self, stmt: m.PropertyStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}" res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
return self.indented(res) return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt): def visit_extend_stmt(self, stmt: m.ExtendStmt):
@@ -289,13 +303,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
for op in stmt.operations: for op in stmt.operations:
res += op.accept(self) res += op.accept(self)
self.level -= 1 self.level -= 1
res += "\n" + self.indented("}") res += self.indented("}")
return res return res
def visit_op_stmt(self, stmt: m.OpStmt): def visit_op_stmt(self, stmt: m.OpStmt):
operand: str = stmt.operand.accept(self) operand: str = stmt.operand.accept(self)
result: str = stmt.result.accept(self) result: str = stmt.result.accept(self)
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}") return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
def visit_predicate_stmt(self, stmt: m.PredicateStmt): def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme name: str = stmt.name.lexeme
@@ -304,9 +318,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
condition: str = stmt.condition.accept(self) condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}") return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
def visit_logical_expr(self, expr: m.LogicalExpr): def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self) left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme operator: str = expr.operator.lexeme
@@ -342,12 +353,30 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_wildcard_expr(self, expr: m.WildcardExpr): def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_" return "_"
def visit_template_expr(self, expr: m.TemplateExpr): def visit_named_type(self, type: m.NamedType) -> str:
return f"[{expr.type.accept(self)}]" return type.name.lexeme
def visit_type_expr(self, expr: m.TypeExpr): def visit_generic_type(self, type: m.GenericType) -> str:
template: str = expr.template.accept(self) if expr.template is not None else "" res: str = type.type.accept(self)
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}" if len(type.params) != 0:
params: list[str] = [param.accept(self) for param in type.params]
res += f"[{', '.join(params)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for prop in type.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
class PythonAstPrinter( class PythonAstPrinter(
@@ -600,11 +629,11 @@ class PythonAstPrinter(
self._write_line("test") self._write_line("test")
with self._child_level(single=True): with self._child_level(single=True):
expr.test.accept(self) expr.test.accept(self)
self._write_line("if_true") self._write_line("if_true")
with self._child_level(single=True): with self._child_level(single=True):
expr.if_true.accept(self) expr.if_true.accept(self)
self._write_line("if_false", last=True) self._write_line("if_false", last=True)
with self._child_level(single=True): with self._child_level(single=True):
expr.if_false.accept(self) expr.if_false.accept(self)

View File

@@ -9,7 +9,7 @@ from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
from midas.checker.types import BaseType, Function, SimpleType, Type, UnitType, UnknownType from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token from midas.lexer.token import Token
from midas.parser.midas import MidasParser from midas.parser.midas import MidasParser
@@ -34,9 +34,15 @@ class Checker(
): ):
"""A type checker which can use custom type definitions""" """A type checker which can use custom type definitions"""
def __init__(self, locals: dict[p.Expr, int], file_path: Path): def __init__(
self,
locals: dict[p.Expr, int],
source_path: Path,
types_paths: list[Path],
):
self.logger: logging.Logger = logging.getLogger("Checker") self.logger: logging.Logger = logging.getLogger("Checker")
self.file_path: Path = file_path self.source_path: Path = source_path
self.types_paths: list[Path] = types_paths
self.ctx: MidasResolver = MidasResolver() self.ctx: MidasResolver = MidasResolver()
self.global_env: Environment = Environment() self.global_env: Environment = Environment()
self.env: Environment = self.global_env self.env: Environment = self.global_env
@@ -46,7 +52,7 @@ class Checker(
def diagnostic(self, type: DiagnosticType, location: Location, message: str): def diagnostic(self, type: DiagnosticType, location: Location, message: str):
self.diagnostics.append( self.diagnostics.append(
Diagnostic( Diagnostic(
file_path=self.file_path, file_path=self.source_path,
location=location, location=location,
type=type, type=type,
message=message, message=message,
@@ -74,7 +80,7 @@ class Checker(
message=message, message=message,
) )
def evaluate(self, expr: p.Expr) -> Type: def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression """Evaluate the type of an expression
Args: Args:
@@ -85,13 +91,13 @@ class Checker(
""" """
return expr.accept(self) return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool: def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements """Evaluate a sequence of statements
Args: Args:
block (list[p.Stmt]): the statements to evaluate block (list[p.Stmt]): the statements to evaluate
env (Environment): the environment in which to evaluate env (Environment): the environment in which to evaluate
Returns: Returns:
bool: whether a return statement is present in the block bool: whether a return statement is present in the block
""" """
@@ -119,6 +125,12 @@ class Checker(
list[Diagnostic]: the list of diagnostics (errors, warning, etc.) list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
""" """
self.diagnostics = [] self.diagnostics = []
for path in self.types_paths:
self.import_midas(path)
self.logger.debug(f"Midas types: {self.ctx._types}")
self.logger.debug(f"Midas operations: {self.ctx._operations}")
for stmt in statements: for stmt in statements:
stmt.accept(self) stmt.accept(self)
@@ -140,30 +152,6 @@ class Checker(
return self.env.get_at(distance, name) return self.env.get_at(distance, name)
return self.global_env.get(name) return self.global_env.get(name)
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
"""Parse a Midas import statement
The statement should be written as `midas.using("path/to/types.midas")`
Args:
expr (p.CallExpr): the import call expression
Returns:
Optional[Path]: the path to the imported file, or None if the expression is malformed
"""
match expr:
case p.CallExpr(
callee=p.GetExpr(
object=p.VariableExpr(name="midas"),
name="using",
),
arguments=[
p.LiteralExpr(value=path),
],
):
return Path(path)
return None
def import_midas(self, path: Path) -> None: def import_midas(self, path: Path) -> None:
"""Import Midas definitions from a path """Import Midas definitions from a path
@@ -171,17 +159,14 @@ class Checker(
path (Path): the import path path (Path): the import path
""" """
self.logger.debug(f"Importing type definitions from {path}") self.logger.debug(f"Importing type definitions from {path}")
path = (self.file_path.parent / path).resolve()
lexer: MidasLexer = MidasLexer(path.read_text()) lexer: MidasLexer = MidasLexer(path.read_text())
tokens: list[Token] = lexer.process() tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens) parser: MidasParser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse() stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts) self.ctx.resolve(stmts)
self.logger.debug(f"Midas types: {self.ctx._types}")
self.logger.debug(f"Midas operations: {self.ctx._operations}")
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None: def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.evaluate(stmt.expr) self.type_of(stmt.expr)
def visit_function(self, stmt: p.Function) -> None: def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env) env: Environment = Environment(self.env)
@@ -223,7 +208,7 @@ class Checker(
for arg in pos_args + args + kw_args: for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type) env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None returns_hint: Optional[Type] = None
if stmt.returns is not None: if stmt.returns is not None:
returns_hint = stmt.returns.accept(self) returns_hint = stmt.returns.accept(self)
@@ -237,7 +222,7 @@ class Checker(
) )
self.env.define(stmt.name, inside_function) self.env.define(stmt.name, inside_function)
returned: bool = self.evaluate_block(stmt.body, env) returned: bool = self.process_block(stmt.body, env)
inferred_return: Type = UnknownType() inferred_return: Type = UnknownType()
if not returned: if not returned:
env.return_types.append(UnitType()) env.return_types.append(UnitType())
@@ -278,7 +263,7 @@ class Checker(
self.env.define(stmt.name, type) self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value: Type = self.evaluate(stmt.value) value: Type = self.type_of(stmt.value)
for target in stmt.targets: for target in stmt.targets:
if not isinstance(target, p.VariableExpr): if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}") self.logger.warning(f"Unsupported assignment to {target}")
@@ -317,8 +302,8 @@ class Checker(
) )
env: Environment = Environment(self.env) env: Environment = Environment(self.env)
body_returned: bool = self.evaluate_block(stmt.body, env) body_returned: bool = self.process_block(stmt.body, env)
else_returned: bool = self.evaluate_block(stmt.orelse, env) else_returned: bool = self.process_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types) self.env.return_types.extend(env.return_types)
if body_returned and else_returned: if body_returned and else_returned:
raise ReturnException() raise ReturnException()
@@ -329,8 +314,8 @@ class Checker(
self.logger.warning(f"Unsupported operator {expr.operator}") self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}") self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType() return UnknownType()
left: Type = self.evaluate(expr.left) left: Type = self.type_of(expr.left)
right: Type = self.evaluate(expr.right) right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right) result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None: if result is None:
@@ -347,8 +332,8 @@ class Checker(
self.logger.warning(f"Unsupported operator {expr.operator}") self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}") self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType() return UnknownType()
left: Type = self.evaluate(expr.left) left: Type = self.type_of(expr.left)
right: Type = self.evaluate(expr.right) right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right) result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None: if result is None:
@@ -362,10 +347,7 @@ class Checker(
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ... def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
def visit_call_expr(self, expr: p.CallExpr) -> Type: def visit_call_expr(self, expr: p.CallExpr) -> Type:
if path := self.parse_midas_import(expr): callee: Type = self.type_of(expr.callee)
self.import_midas(path)
return UnknownType()
callee: Type = self.evaluate(expr.callee)
if not isinstance(callee, Function): if not isinstance(callee, Function):
self.error(expr.callee.location, "Callee is not a function") self.error(expr.callee.location, "Callee is not a function")
return UnknownType() return UnknownType()
@@ -398,7 +380,16 @@ class Checker(
def visit_variable_expr(self, expr: p.VariableExpr) -> Type: def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
return self.look_up_variable(expr.name, expr) or UnknownType() return self.look_up_variable(expr.name, expr) or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ... def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self)
# TODO: union type
if left != right:
self.error(
expr.location,
f"Operands must be of the same type, left={left} != right={right}",
)
return left
def visit_set_expr(self, expr: p.SetExpr) -> Type: ... def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
@@ -417,7 +408,10 @@ class Checker(
true_type: Type = expr.if_true.accept(self) true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self) false_type: Type = expr.if_false.accept(self)
if true_type != false_type: if true_type != false_type:
self.error(expr.location, f"Type mismatch in ternary if branches: true={true_type} != false={false_type}") self.error(
expr.location,
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
)
return UnknownType() return UnknownType()
return true_type return true_type
@@ -448,10 +442,10 @@ class Checker(
list[MappedArgument]: the list of mapped arguments list[MappedArgument]: the list of mapped arguments
""" """
positional: list[tuple[p.Expr, Type]] = [ positional: list[tuple[p.Expr, Type]] = [
(arg, self.evaluate(arg)) for arg in call.arguments (arg, self.type_of(arg)) for arg in call.arguments
] ]
keywords: dict[str, tuple[p.Expr, Type]] = { keywords: dict[str, tuple[p.Expr, Type]] = {
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items() name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
} }
set_args: set[str] = set() set_args: set[str] = set()

View File

@@ -9,9 +9,9 @@ class BaseType:
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class SimpleType: class AliasType:
name: str name: str
base: BaseType | SimpleType type: Type
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@@ -39,4 +39,9 @@ class Function:
required: bool required: bool
Type = BaseType | SimpleType | UnknownType | UnitType | Function @dataclass(frozen=True, kw_only=True)
class ComplexType:
properties: dict[str, Type]
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType

View File

@@ -53,5 +53,6 @@ span {
&.keyword { &.keyword {
color: rgb(211, 72, 9); color: rgb(211, 72, 9);
pointer-events: none;
} }
} }

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar from typing import Generic, Optional, Protocol, TextIO, TypeVar
@@ -8,6 +9,7 @@ import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True) H = TypeVar("H", bound="Highlighter", contravariant=True)
@@ -22,6 +24,15 @@ class Locatable(Protocol):
def location(self) -> Optional[Location]: ... def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
class Highlighter(ABC): class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css" BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None EXTRA_CSS_PATH: Optional[Path] = None
@@ -206,34 +217,22 @@ class PythonHighlighter(
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ... def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]): class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css" EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]): def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self) node.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "simple-type") self.wrap(stmt, "type-stmt")
if stmt.template is not None: self.wrap(LocatableToken(stmt.name), "type-name")
stmt.template.accept(self) stmt.type.accept(self)
stmt.base.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
self.wrap(stmt, "complex-type")
if stmt.template is not None:
stmt.template.accept(self)
for prop in stmt.properties:
prop.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
self.wrap(stmt, "property") self.wrap(stmt, "property")
stmt.type.accept(self) stmt.type.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend") self.wrap(stmt, "extend")
@@ -243,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_op_stmt(self, stmt: m.OpStmt) -> None: def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self.wrap(stmt, "op") self.wrap(stmt, "op")
self.wrap(LocatableToken(stmt.name), "op-name")
stmt.operand.accept(self) stmt.operand.accept(self)
stmt.result.accept(self) stmt.result.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate") self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self) stmt.type.accept(self)
stmt.condition.accept(self) stmt.condition.accept(self)
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
self.wrap(expr, "simple-type-expr")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr") self.wrap(expr, "logical-expr")
expr.left.accept(self) expr.left.accept(self)
@@ -282,14 +280,24 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> None: def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(expr, "template") self.wrap(type, "named-type")
expr.type.accept(self)
def visit_type_expr(self, expr: m.TypeExpr) -> None: def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(expr, "type") self.wrap(type, "generic-type")
if expr.template is not None: type.type.accept(self)
expr.template.accept(self) for param in type.params:
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for prop in type.properties:
prop.accept(self)
class DiagnosticsHighlighter(Highlighter): class DiagnosticsHighlighter(Highlighter):

View File

@@ -5,12 +5,11 @@ span {
font-style: italic; font-style: italic;
} }
&.simple-type { &.named-type,
--col: 108, 233, 108; &.generic-type,
} &.constraint-type,
&.complex-type { &.complex-type {
--col: 233, 206, 108; --col: 150, 150, 150;
} }
&.constraint { &.constraint {
@@ -33,10 +32,6 @@ span {
--col: 193, 108, 233; --col: 193, 108, 233;
} }
&.simple-type-expr {
--col: 150, 150, 150;
}
&.logical-expr, &.logical-expr,
&.binary-expr, &.binary-expr,
&.unary-expr, &.unary-expr,
@@ -48,7 +43,9 @@ span {
--col: 163, 117, 71; --col: 163, 117, 71;
} }
&.type { &.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200; --col: 200, 200, 200;
font-weight: bold; font-weight: bold;
} }

View File

@@ -1,7 +1,6 @@
import ast import ast
import json import json
import logging import logging
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, TextIO, get_args from typing import Optional, TextIO, get_args
@@ -9,14 +8,14 @@ import click
import midas.ast.midas as m import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.checker.checker import Checker from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type from midas.checker.types import Type
from midas.cli.highlighter import ( from midas.cli.highlighter import (
DiagnosticsHighlighter, DiagnosticsHighlighter,
Highlighter, Highlighter,
LocatableToken,
MidasHighlighter, MidasHighlighter,
PythonHighlighter, PythonHighlighter,
) )
@@ -35,8 +34,9 @@ def midas():
@midas.command() @midas.command()
@click.option("-l", "--highlight", type=click.File("w")) @click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
def compile(highlight: Optional[TextIO], file: TextIO): def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
source: str = file.read() source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name) tree: ast.Module = ast.parse(source, filename=file.name)
@@ -44,7 +44,12 @@ def compile(highlight: Optional[TextIO], file: TextIO):
stmts: list[p.Stmt] = parser.parse_module(tree) stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver() resolver = Resolver()
resolver.resolve(*stmts) resolver.resolve(*stmts)
checker = Checker(resolver.locals, file_path=Path(file.name).resolve()) types_paths: list[Path] = [Path(t.name).resolve() for t in types]
checker = Checker(
resolver.locals,
source_path=Path(file.name).resolve(),
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts) diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics: for diagnostic in diagnostics:
print(diagnostic) print(diagnostic)
@@ -142,14 +147,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
for err in parser.errors: for err in parser.errors:
print(err.get_report()) print(err.get_report())
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
for stmt in stmts: for stmt in stmts:
highlighter.highlight(stmt) highlighter.highlight(stmt)
for token in tokens: for token in tokens:
@@ -176,5 +173,21 @@ def highlight(output: TextIO, file: TextIO):
highlighter.dump(output) highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
midas() midas()

View File

@@ -40,8 +40,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.AND) self.add_token(TokenType.AND)
case "?": case "?":
self.add_token(TokenType.QMARK) self.add_token(TokenType.QMARK)
# case ",": case ",":
# self.add_token(TokenType.COMMA) self.add_token(TokenType.COMMA)
case "_" if not self.is_identifier_char(self.peek_next(), start=False): case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE) self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"): case "-" if self.match(">"):

View File

@@ -17,7 +17,7 @@ class TokenType(Enum):
LEFT_BRACE = auto() LEFT_BRACE = auto()
RIGHT_BRACE = auto() RIGHT_BRACE = auto()
COLON = auto() COLON = auto()
# COMMA = auto() COMMA = auto()
UNDERSCORE = auto() UNDERSCORE = auto()
ARROW = auto() ARROW = auto()
AND = auto() AND = auto()

View File

@@ -3,21 +3,22 @@ from typing import Optional
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexType,
ConstraintType,
Expr, Expr,
ExtendStmt, ExtendStmt,
GenericType,
GetExpr, GetExpr,
GroupingExpr, GroupingExpr,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
NamedType,
OpStmt, OpStmt,
PredicateStmt, PredicateStmt,
PropertyStmt, PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt, Stmt,
TemplateExpr, Type,
TypeExpr, TypeStmt,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
@@ -81,7 +82,7 @@ class MidasParser(Parser):
self.synchronize() self.synchronize()
return None return None
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt: def type_declaration(self) -> TypeStmt:
"""Parse a type declaration """Parse a type declaration
A type declaration can either be a simple type alias or a new complex type. A type declaration can either be a simple type alias or a new complex type.
@@ -107,33 +108,22 @@ class MidasParser(Parser):
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None params: list[TypeStmt.Param] = []
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr() params = self.type_stmt_params()
if self.match(TokenType.LEFT_PAREN): self.consume(TokenType.EQUAL, "Expected '=' before type definition")
base: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return SimpleTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
)
def template_expr(self) -> TemplateExpr: type: Type = self.type_expr()
return TypeStmt(
location=keyword.location_to(self.previous()),
name=name,
params=params,
type=type,
)
def type_stmt_params(self) -> list[TypeStmt.Param]:
"""Parse a generic template expression """Parse a generic template expression
A template is written `[TypeExpr]` A template is written `[TypeExpr]`
@@ -141,16 +131,27 @@ class MidasParser(Parser):
Returns: Returns:
TemplateExpr: the parsed template expression TemplateExpr: the parsed template expression
""" """
left: Token = self.consume( self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
TokenType.LEFT_BRACKET, "Missing '[' before template expression" params: list[TypeStmt.Param] = []
) while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
type: TypeExpr = self.type_expr() name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
right: Token = self.consume( bound: Optional[Type] = None
TokenType.RIGHT_BRACKET, "Missing ']' after template expression" if self.match(TokenType.LESS):
) self.consume(TokenType.COLON, "Expected ':' after '<'")
return TemplateExpr(location=left.location_to(right), type=type) bound = self.type_expr()
params.append(
TypeStmt.Param(
location=name.location_to(self.previous()),
name=name,
bound=bound,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
return params
def type_expr(self) -> TypeExpr: def type_expr(self) -> Type:
"""Parse a type expression """Parse a type expression
A type is an identifier, optionally followed by a template expression. A type is an identifier, optionally followed by a template expression.
@@ -159,30 +160,82 @@ class MidasParser(Parser):
Returns: Returns:
TypeExpr: the parsed type expression TypeExpr: the parsed type expression
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") return self.constraint_type()
template: Optional[TemplateExpr] = None
def constraint_type(self) -> Type:
type: Type = self.base_type()
if self.match(TokenType.WHERE):
constraint: Expr = self.constraint()
return ConstraintType(
location=Location.span(type.location, constraint.location),
type=type,
constraint=constraint,
)
return type
def base_type(self) -> Type:
if self.match(TokenType.LEFT_PAREN):
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return type
if self.check(TokenType.LEFT_BRACE):
return self.complex_type()
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET): if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr() params: list[Type] = self.type_params()
optional: bool = self.match(TokenType.QMARK) return GenericType(
return TypeExpr( location=Location.span(type.location, self.previous().get_location()),
location=name.location_to(self.previous()), type=type,
params=params,
)
return type
def type_params(self) -> list[Type]:
params: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
return params
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
return NamedType(
location=name.get_location(),
name=name, name=name,
template=template,
optional=optional,
) )
def simple_type_expr(self) -> SimpleTypeExpr: def complex_type(self) -> Type:
"""Parse a simple type expression """Parse a type definition body
A simple type is just an identifier optionally followed by a '?' A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns: Returns:
SimpleTypeExpr: the parsed simple type expression list[PropertyStmt]: the parsed type properties
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name") left: Token = self.consume(
optional: bool = self.match(TokenType.QMARK) TokenType.LEFT_BRACE, "Expected '{' to start type body"
return SimpleTypeExpr( )
location=name.location_to(self.previous()), name=name, optional=optional properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
properties=properties,
) )
def constraint(self) -> Expr: def constraint(self) -> Expr:
@@ -308,27 +361,6 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression") raise self.error(self.peek(), "Expected expression")
def type_properties(self) -> list[PropertyStmt]:
"""Parse a type definition body
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
list[PropertyStmt]: the parsed type properties
"""
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return properties
def property_stmt(self) -> PropertyStmt: def property_stmt(self) -> PropertyStmt:
"""Parse a property statement """Parse a property statement
@@ -339,15 +371,11 @@ class MidasParser(Parser):
""" """
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name") self.consume(TokenType.COLON, "Expected ':' after property name")
type: TypeExpr = self.type_expr() type: Type = self.type_expr()
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return PropertyStmt( return PropertyStmt(
location=name.location_to(self.previous()), location=name.location_to(self.previous()),
name=name, name=name,
type=type, type=type,
constraint=constraint,
) )
def extend_declaration(self) -> ExtendStmt: def extend_declaration(self) -> ExtendStmt:
@@ -359,7 +387,7 @@ class MidasParser(Parser):
ExtendStmt: the parsed extension statement ExtendStmt: the parsed extension statement
""" """
keyword: Token = self.previous() keyword: Token = self.previous()
type: TypeExpr = self.type_expr() type: Type = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body") self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = [] operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE): while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
@@ -380,11 +408,11 @@ class MidasParser(Parser):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name") name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type") self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
operand: TypeExpr = self.type_expr() operand: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
self.consume(TokenType.ARROW, "Expected '->' before result type") self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr() result: Type = self.type_expr()
return OpStmt( return OpStmt(
location=keyword.location_to(self.previous()), location=keyword.location_to(self.previous()),
@@ -406,7 +434,7 @@ class MidasParser(Parser):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject") self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name") subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name") self.consume(TokenType.COLON, "Expected ':' after subject name")
type: TypeExpr = self.type_expr() type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject") self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject") self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint() condition: Expr = self.constraint()

View File

@@ -87,6 +87,9 @@ class PythonParser:
case ast.If(): case ast.If():
return self.parse_if(node) return self.parse_if(node)
case ast.Pass():
return None
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return None
@@ -311,6 +314,13 @@ class PythonParser:
constraint=right_expr, constraint=right_expr,
) )
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
param=None,
)
case _: case _:
raise UnsupportedSyntaxError(type_expr) raise UnsupportedSyntaxError(type_expr)

View File

@@ -16,6 +16,7 @@ def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
result=t3, result=t3,
) )
def basic_op(ctx: MidasResolver, type: Type, op: str): def basic_op(ctx: MidasResolver, type: Type, op: str):
ctx.define_operation( ctx.define_operation(
left=type, left=type,
@@ -68,4 +69,4 @@ def define_builtins(ctx: MidasResolver):
op(ctx, float, "__gt__", int, bool) # float > int = bool op(ctx, float, "__gt__", int, bool) # float > int = bool
op(ctx, float, "__le__", int, bool) # float <= int = bool op(ctx, float, "__le__", int, bool) # float <= int = bool
op(ctx, float, "__ge__", int, bool) # float >= int = bool op(ctx, float, "__ge__", int, bool) # float >= int = bool
op(ctx, float, "__eq__", int, bool) # float == int = bool op(ctx, float, "__eq__", int, bool) # float == int = bool

View File

@@ -1,11 +1,15 @@
from typing import Optional from typing import Optional
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.types import BaseType, SimpleType, Type from midas.checker.types import (
AliasType,
Type,
UnknownType,
)
from midas.resolver.builtin import define_builtins from midas.resolver.builtin import define_builtins
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]): class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry""" """A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self) -> None: def __init__(self) -> None:
@@ -94,20 +98,13 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
for stmt in stmts: for stmt in stmts:
stmt.accept(self) stmt.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
# TODO generics, optional, constraint type: Type = stmt.type.accept(self)
base: Type = self.get_type(stmt.base.name.lexeme) for param in stmt.params:
match base: if param.bound is not None:
case BaseType() | SimpleType(): param.bound.accept(self)
type = SimpleType( name: str = stmt.name.lexeme
name=stmt.name.lexeme, self.define_type(name, AliasType(name=name, type=type))
base=base,
)
self.define_type(type.name, type)
case _:
raise TypeError(f"Invalid base {base} for simple type")
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
@@ -127,27 +124,40 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ... def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type: def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
return self.get_type(expr.name.lexeme)
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ... def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ... def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ... def visit_get_expr(self, expr: m.GetExpr) -> None: ...
def visit_get_expr(self, expr: m.GetExpr) -> Type: ... def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ... def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self) return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ... def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ... def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ... def visit_named_type(self, type: m.NamedType) -> Type:
return self.get_type(type.name.lexeme)
def visit_type_expr(self, expr: m.TypeExpr) -> Type: def visit_generic_type(self, type: m.GenericType) -> Type:
return self.get_type(expr.name.lexeme) type_: Type = type.type.accept(self)
params: list[Type] = [param.accept(self) for param in type.params]
# TODO
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
def visit_complex_type(self, type: m.ComplexType) -> Type:
for prop in type.properties:
prop.accept(self)
# TODO
return UnknownType()

View File

@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
Equality ::= Comparison (EqualityOp Comparison)* Equality ::= Comparison (EqualityOp Comparison)*
Constraint ::= Equality ("&" Equality)* Constraint ::= Equality ("&" Equality)*
SimpleType ::= Identifier "?"? TemplateParam ::= Identifier ("<:" Type)?
Template ::= "[" Type "]" Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
Type ::= Identifier Template? "?"?
TypeProperty ::= Identifier ":" Type
ComplexType ::= "{" TypeProperty* "}"
NamedType ::= Identifier
TypeParams ::= "[" (Type ("," Type)*)? "]"
GenericType ::= NamedType TypeParams?
GroupedType ::= "(" Type ")"
BaseType ::= GroupedType | ComplexType | GenericType
ConstraintType ::= BaseType ("where" Constraint)?
Type ::= ConstraintType
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
ComplexTypeBody ::= "{" TypeProperty* "}"
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
ExtendBody ::= "{" OpDefinition* "}" ExtendBody ::= "{" OpDefinition* "}"
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody) TypeStatement ::= "type" Identifier Template? "=" Type
ExtendStatement ::= "extend" Type ExtendBody ExtendStatement ::= "extend" Type ExtendBody
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint

View File

@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
{[`constraint` 'equality'*"&"]} {[`constraint` 'equality'*"&"]}
``` ```
#let simple-type = ``` #let template-param = ```
{[`simple-type` 'identifier' <!, "?">]} {[`template-param` 'identifier' <!, ["<:" 'type']>]}
``` ```
#let template = ``` #let template = ```
{[`template` "[" 'type' "]"]} {[`template` "[" <!, 'template-param'*","> "]"]}
```
#let type = ```
{[`type` 'identifier' <!, 'template'> <!, "?">]}
``` ```
#let type-property = ``` #let type-property = ```
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]} {[`type-property` 'identifier' ":" 'type']}
``` ```
#let type-body = ``` #let complex-type = ```
{[`type-body` "{" <!, 'type-property'*!> "}"]} {[`complex-type` "{" <!, 'type-property'*!> "}"]}
```
#let named-type = ```
{[`named-type` 'identifier']}
```
#let type-params = ```
{[`type-params` "[" <!, 'type'*","> "]"]}
```
#let generic-type = ```
{[`generic-type` 'named-type' <!, 'type-params'>]}
```
#let grouped-type = ```
{[`grouped-type` "(" 'type' ")"]}
```
#let base-type = ```
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
```
#let constraint-type = ```
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
```
#let type = ```
{[`type` 'constraint-type']}
``` ```
#let type-statement = ``` #let type-statement = ```
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]} {[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
``` ```
#let op-definition = ``` #let op-definition = ```
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
comparison: comparison, comparison: comparison,
equality: equality, equality: equality,
constraint: constraint, constraint: constraint,
simple-type: simple-type, template-param: template-param,
template: template, template: template,
type: type,
type-property: type-property, type-property: type-property,
type-body: type-body, complex-type: complex-type,
named-type: named-type,
type-params: type-params,
generic-type: generic-type,
grouped-type: grouped-type,
base-type: base-type,
constraint-type: constraint-type,
type: type,
type-statement: type-statement, type-statement: type-statement,
op-definition: op-definition, op-definition: op-definition,
extend-statement: extend-statement, extend-statement: extend-statement,
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
#let inline = ( #let inline = (
"grouping", "grouping",
"value", "value",
"template-param",
"template", "template",
"simple-type",
"type-property", "type-property",
"type-body", "complex-type",
"type-params",
"named-type",
"grouped-type",
"generic-type",
"base-type",
"constraint-type",
"op-definition", "op-definition",
"type-statement", "type-statement",
"extend-statement", "extend-statement",

View File

@@ -1,6 +1,6 @@
type Meter(float) type Meter = float
type Second(float) type Second = float
type MeterPerSecond(float) type MeterPerSecond = float
extend Meter { extend Meter {
op __add__(Meter) -> Meter op __add__(Meter) -> Meter

View File

@@ -1,8 +1,6 @@
# type: ignore # type: ignore
# ruff: disable [F821] # ruff: disable [F821]
midas.using("04_custom_types.midas")
distance: Meter = cast(Meter, 123.45) distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7) time: Second = cast(Second, 6.7)
speed = distance / time speed = distance / time

View File

@@ -1,15 +1,15 @@
// Simple custom type derived from float // Simple custom type derived from float
type Custom(float) type Custom = float
// Simple custom types with constraints // Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90) type Latitude = float where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180) type Longitude = float where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float // Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T) type Difference[T] = T
// Complex custom type, containing two values accessible through properties // Complex custom type, containing two values accessible through properties
type GeoLocation { type GeoLocation = {
lat: Latitude lat: Latitude
lon: Longitude lon: Longitude
} }
@@ -24,7 +24,7 @@ extend GeoLocation {
// For complex generics, you need to specify how the genericity the properties // For complex generics, you need to specify how the genericity the properties
// are handled // are handled
type Difference[GeoLocation] { type Difference[GeoLocation] = {
lat: Difference[Latitude] lat: Difference[Latitude]
lon: Difference[Longitude] lon: Difference[Longitude]
} }
@@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10) predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66) predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person { type Person = {
name: str name: str
// Property with an inline constraint // Property with an inline constraint
age: int? where (0 <= _ < 150) age: Optional[int where (0 <= _ < 150)]
// Property referencing a predicate // Property referencing a predicate
height: float where StrictlyPositive height: float where StrictlyPositive

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821] # ruff: disable[F821]
from __future__ import annotations from __future__ import annotations
import midas
midas.using("02_custom_types.midas")
df: Frame[ df: Frame[
location: GeoLocation location: GeoLocation
] ]

View File

@@ -1,26 +1,5 @@
{ {
"stmts": [ "stmts": [
{
"_type": "ExpressionStmt",
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "midas"
},
"name": "using"
},
"arguments": [
{
"_type": "LiteralExpr",
"value": "02_custom_types.midas"
}
],
"keywords": {}
}
},
{ {
"_type": "TypeAssign", "_type": "TypeAssign",
"name": "df", "name": "df",

View File

@@ -33,6 +33,10 @@ class CheckerTester(Tester):
if not path.is_file(): if not path.is_file():
raise TypeError(f"Test '{path}' is not a file") raise TypeError(f"Test '{path}' is not a file")
types_paths: list[Path] = []
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
types_paths.append(types_path)
source: str = path.read_text() source: str = path.read_text()
tree: ast.Module = ast.parse(source, filename=path) tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser() parser = PythonParser()
@@ -40,7 +44,11 @@ class CheckerTester(Tester):
resolver = Resolver() resolver = Resolver()
resolver.resolve(*stmts) resolver.resolve(*stmts)
result: CaseResult = CaseResult() result: CaseResult = CaseResult()
checker = Checker(resolver.locals, file_path=path) checker = Checker(
resolver.locals,
source_path=path,
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts) diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics: for diagnostic in diagnostics:
result.diagnostics.append( result.diagnostics.append(

View File

@@ -2,56 +2,60 @@ from typing import Optional, Sequence
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexType,
ConstraintType,
Expr, Expr,
ExtendStmt, ExtendStmt,
GenericType,
GetExpr, GetExpr,
GroupingExpr, GroupingExpr,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
NamedType,
OpStmt, OpStmt,
PredicateStmt, PredicateStmt,
PropertyStmt, PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt, Stmt,
TemplateExpr, Type,
TypeExpr, TypeStmt,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
) )
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): class MidasAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure""" """An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]: def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts] return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]: def _serialize_optional(
self, element: Optional[Stmt | Expr | Type]
) -> Optional[dict]:
if element is None: if element is None:
return None return None
return element.accept(self) return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]: def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
return [element.accept(self) for element in elements] return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict: def visit_type_stmt(self, stmt: TypeStmt) -> dict:
return { return {
"_type": "SimpleTypeStmt", "_type": "TypeStmt",
"name": stmt.name.lexeme, "name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template), "params": [
"base": stmt.base.accept(self), self._serialize_type_stmt_template_param(param) for param in stmt.params
"constraint": self._serialize_optional(stmt.constraint), ],
"type": stmt.type.accept(self),
} }
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict: def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
return { return {
"_type": "ComplexTypeStmt", "name": param.name.lexeme,
"name": stmt.name.lexeme, "bound": self._serialize_optional(param.bound),
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
} }
def visit_property_stmt(self, stmt: PropertyStmt) -> dict: def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
@@ -59,7 +63,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"_type": "PropertyStmt", "_type": "PropertyStmt",
"name": stmt.name.lexeme, "name": stmt.name.lexeme,
"type": stmt.type.accept(self), "type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
} }
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict: def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
@@ -86,13 +89,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"condition": stmt.condition.accept(self), "condition": stmt.condition.accept(self),
} }
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict: def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return { return {
"_type": "LogicalExpr", "_type": "LogicalExpr",
@@ -144,16 +140,28 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict: def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
return {"_type": "WildcardExpr"} return {"_type": "WildcardExpr"}
def visit_template_expr(self, expr: TemplateExpr) -> dict: def visit_named_type(self, type: NamedType) -> dict:
return { return {
"_type": "TemplateExpr", "_type": "NamedType",
"type": expr.type.accept(self), "name": type.name.lexeme,
} }
def visit_type_expr(self, expr: TypeExpr) -> dict: def visit_generic_type(self, type: GenericType) -> dict:
return { return {
"_type": "TypeExpr", "_type": "GenericType",
"name": expr.name.lexeme, "type": type.type.accept(self),
"template": self._serialize_optional(expr.template), "params": self._serialize_list(type.params),
"optional": expr.optional, }
def visit_constraint_type(self, type: ConstraintType) -> dict:
return {
"_type": "ConstraintType",
"type": type.type.accept(self),
"constraint": type.constraint.accept(self),
}
def visit_complex_type(self, type: ComplexType) -> dict:
return {
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
} }

View File

@@ -22,6 +22,7 @@ from midas.ast.python import (
ReturnStmt, ReturnStmt,
SetExpr, SetExpr,
Stmt, Stmt,
TernaryExpr,
TypeAssign, TypeAssign,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -245,3 +246,11 @@ class PythonAstJsonSerializer(
"type": expr.type.accept(self), "type": expr.type.accept(self),
"expr": expr.expr.accept(self), "expr": expr.expr.accept(self),
} }
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
return {
"_type": "TernaryExpr",
"test": expr.test.accept(self),
"if_true": expr.if_true.accept(self),
"if_false": expr.if_false.accept(self),
}