Compare commits

...

9 Commits

15 changed files with 626 additions and 494 deletions

View File

@@ -1,6 +1,6 @@
type Meter(float)
type Second(float)
type MeterPerSecond(float)
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter

View File

@@ -4,13 +4,20 @@ def minimum(x: int, y: int):
else:
return y
a = 15
b = 72
c = minimum(a, b)
def factorial(n: int) -> int:
if n <= 1:
return 1
return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"
category = "Category 1" if a < 10 else "Category 2"
def foo() -> None:
pass

View File

@@ -283,7 +283,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt):
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0
return expr.accept(self)
@@ -314,13 +314,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
for op in stmt.operations:
res += op.accept(self)
self.level -= 1
res += "\n" + self.indented("}")
res += self.indented("}")
return res
def visit_op_stmt(self, stmt: m.OpStmt):
operand: str = stmt.operand.accept(self)
result: str = stmt.result.accept(self)
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme

View File

@@ -398,7 +398,16 @@ class Checker(
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
return self.look_up_variable(expr.name, expr) or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self)
# TODO: union type
if left != right:
self.error(
expr.location,
f"Operands must be of the same type, left={left} != right={right}",
)
return left
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...

View File

@@ -53,5 +53,6 @@ span {
&.keyword {
color: rgb(211, 72, 9);
pointer-events: none;
}
}

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar
@@ -8,6 +9,7 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True)
@@ -22,6 +24,15 @@ class Locatable(Protocol):
def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None
@@ -206,34 +217,22 @@ class PythonHighlighter(
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
self.wrap(stmt, "simple-type")
if stmt.template is not None:
stmt.template.accept(self)
stmt.base.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
self.wrap(stmt, "complex-type")
if stmt.template is not None:
stmt.template.accept(self)
for prop in stmt.properties:
prop.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "type-stmt")
self.wrap(LocatableToken(stmt.name), "type-name")
stmt.type.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
self.wrap(stmt, "property")
stmt.type.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend")
@@ -243,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self.wrap(stmt, "op")
self.wrap(LocatableToken(stmt.name), "op-name")
stmt.operand.accept(self)
stmt.result.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self)
stmt.condition.accept(self)
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
self.wrap(expr, "simple-type-expr")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
expr.left.accept(self)
@@ -282,14 +280,29 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self.wrap(expr, "template")
expr.type.accept(self)
def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(type, "named-type")
def visit_type_expr(self, expr: m.TypeExpr) -> None:
self.wrap(expr, "type")
if expr.template is not None:
expr.template.accept(self)
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for param in type.params:
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_union_type(self, type: m.UnionType) -> None:
self.wrap(type, "union-type")
for type_ in type.types:
type_.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for prop in type.properties:
prop.accept(self)
class DiagnosticsHighlighter(Highlighter):

View File

@@ -5,12 +5,12 @@ span {
font-style: italic;
}
&.simple-type {
--col: 108, 233, 108;
}
&.named-type,
&.generic-type,
&.constraint-type,
&.union-type,
&.complex-type {
--col: 233, 206, 108;
--col: 150, 150, 150;
}
&.constraint {
@@ -33,10 +33,6 @@ span {
--col: 193, 108, 233;
}
&.simple-type-expr {
--col: 150, 150, 150;
}
&.logical-expr,
&.binary-expr,
&.unary-expr,
@@ -48,7 +44,9 @@ span {
--col: 163, 117, 71;
}
&.type {
&.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200;
font-weight: bold;
}

View File

@@ -1,7 +1,6 @@
import ast
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TextIO, get_args
@@ -9,14 +8,14 @@ import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
@@ -142,14 +141,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
for err in parser.errors:
print(err.get_report())
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
@@ -176,5 +167,21 @@ def highlight(output: TextIO, file: TextIO):
highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")
if __name__ == "__main__":
midas()

View File

@@ -87,6 +87,9 @@ class PythonParser:
case ast.If():
return self.parse_if(node)
case ast.Pass():
return None
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return None
@@ -311,6 +314,13 @@ class PythonParser:
constraint=right_expr,
)
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
param=None,
)
case _:
raise UnsupportedSyntaxError(type_expr)

View File

@@ -2,6 +2,7 @@ from typing import Optional
import midas.ast.midas as m
from midas.checker.types import (
AliasType,
Type,
UnionType,
UnknownType,
@@ -103,7 +104,8 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
for param in stmt.params:
if param.bound is not None:
param.bound.accept(self)
self.define_type(stmt.name.lexeme, type)
name: str = stmt.name.lexeme
self.define_type(name, AliasType(name=name, type=type))
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...

View File

@@ -1,6 +1,6 @@
type Meter(float)
type Second(float)
type MeterPerSecond(float)
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter

View File

@@ -1,15 +1,15 @@
// Simple custom type derived from float
type Custom(float)
type Custom = float
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
type Latitude = float where (-90 <= _ <= 90)
type Longitude = float where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T)
type Difference[T] = T
// Complex custom type, containing two values accessible through properties
type GeoLocation {
type GeoLocation = {
lat: Latitude
lon: Longitude
}
@@ -24,7 +24,7 @@ extend GeoLocation {
// For complex generics, you need to specify how the genericity the properties
// are handled
type Difference[GeoLocation] {
type Difference[GeoLocation] = {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
@@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
type Person = {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
age: None | (int where (0 <= _ < 150))
// Property referencing a predicate
height: float where StrictlyPositive

File diff suppressed because it is too large Load Diff

View File

@@ -2,56 +2,61 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
NamedType,
OpStmt,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeStmt,
UnaryExpr,
UnionType,
VariableExpr,
WildcardExpr,
)
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
class MidasAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
def _serialize_optional(
self, element: Optional[Stmt | Expr | Type]
) -> Optional[dict]:
if element is None:
return None
return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
return {
"_type": "SimpleTypeStmt",
"_type": "TypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"base": stmt.base.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
"params": [
self._serialize_type_stmt_template_param(param) for param in stmt.params
],
"type": stmt.type.accept(self),
}
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
return {
"_type": "ComplexTypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
}
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
@@ -59,7 +64,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"_type": "PropertyStmt",
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
}
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
@@ -86,13 +90,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"condition": stmt.condition.accept(self),
}
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return {
"_type": "LogicalExpr",
@@ -144,16 +141,34 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
return {"_type": "WildcardExpr"}
def visit_template_expr(self, expr: TemplateExpr) -> dict:
def visit_named_type(self, type: NamedType) -> dict:
return {
"_type": "TemplateExpr",
"type": expr.type.accept(self),
"_type": "NamedType",
"name": type.name.lexeme,
}
def visit_type_expr(self, expr: TypeExpr) -> dict:
def visit_generic_type(self, type: GenericType) -> dict:
return {
"_type": "TypeExpr",
"name": expr.name.lexeme,
"template": self._serialize_optional(expr.template),
"optional": expr.optional,
"_type": "GenericType",
"type": type.type.accept(self),
"params": self._serialize_list(type.params),
}
def visit_constraint_type(self, type: ConstraintType) -> dict:
return {
"_type": "ConstraintType",
"type": type.type.accept(self),
"constraint": type.constraint.accept(self),
}
def visit_union_type(self, type: UnionType) -> dict:
return {
"_type": "UnionType",
"types": self._serialize_list(type.types),
}
def visit_complex_type(self, type: ComplexType) -> dict:
return {
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
}

View File

@@ -22,6 +22,7 @@ from midas.ast.python import (
ReturnStmt,
SetExpr,
Stmt,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -245,3 +246,11 @@ class PythonAstJsonSerializer(
"type": expr.type.accept(self),
"expr": expr.expr.accept(self),
}
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
return {
"_type": "TernaryExpr",
"test": expr.test.accept(self),
"if_true": expr.if_true.accept(self),
"if_false": expr.if_false.accept(self),
}