Compare commits

...

9 Commits

15 changed files with 626 additions and 494 deletions

View File

@@ -1,6 +1,6 @@
type Meter(float) type Meter = float
type Second(float) type Second = float
type MeterPerSecond(float) type MeterPerSecond = float
extend Meter { extend Meter {
op __add__(Meter) -> Meter op __add__(Meter) -> Meter

View File

@@ -4,13 +4,20 @@ def minimum(x: int, y: int):
else: else:
return y return y
a = 15 a = 15
b = 72 b = 72
c = minimum(a, b) c = minimum(a, b)
def factorial(n: int) -> int: def factorial(n: int) -> int:
if n <= 1: if n <= 1:
return 1 return 1
return n * factorial(n - 1) return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2" category = "Category 1" if a < 10 else "Category 2"
def foo() -> None:
pass

View File

@@ -283,7 +283,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def indented(self, text: str) -> str: def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt): def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0 self.level = 0
return expr.accept(self) return expr.accept(self)
@@ -314,13 +314,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
for op in stmt.operations: for op in stmt.operations:
res += op.accept(self) res += op.accept(self)
self.level -= 1 self.level -= 1
res += "\n" + self.indented("}") res += self.indented("}")
return res return res
def visit_op_stmt(self, stmt: m.OpStmt): def visit_op_stmt(self, stmt: m.OpStmt):
operand: str = stmt.operand.accept(self) operand: str = stmt.operand.accept(self)
result: str = stmt.result.accept(self) result: str = stmt.result.accept(self)
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}") return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
def visit_predicate_stmt(self, stmt: m.PredicateStmt): def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme name: str = stmt.name.lexeme

View File

@@ -398,7 +398,16 @@ class Checker(
def visit_variable_expr(self, expr: p.VariableExpr) -> Type: def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
return self.look_up_variable(expr.name, expr) or UnknownType() return self.look_up_variable(expr.name, expr) or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ... def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self)
# TODO: union type
if left != right:
self.error(
expr.location,
f"Operands must be of the same type, left={left} != right={right}",
)
return left
def visit_set_expr(self, expr: p.SetExpr) -> Type: ... def visit_set_expr(self, expr: p.SetExpr) -> Type: ...

View File

@@ -53,5 +53,6 @@ span {
&.keyword { &.keyword {
color: rgb(211, 72, 9); color: rgb(211, 72, 9);
pointer-events: none;
} }
} }

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar from typing import Generic, Optional, Protocol, TextIO, TypeVar
@@ -8,6 +9,7 @@ import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True) H = TypeVar("H", bound="Highlighter", contravariant=True)
@@ -22,6 +24,15 @@ class Locatable(Protocol):
def location(self) -> Optional[Location]: ... def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
class Highlighter(ABC): class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css" BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None EXTRA_CSS_PATH: Optional[Path] = None
@@ -206,34 +217,22 @@ class PythonHighlighter(
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ... def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]): class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css" EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]): def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self) node.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None: def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "simple-type") self.wrap(stmt, "type-stmt")
if stmt.template is not None: self.wrap(LocatableToken(stmt.name), "type-name")
stmt.template.accept(self) stmt.type.accept(self)
stmt.base.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
self.wrap(stmt, "complex-type")
if stmt.template is not None:
stmt.template.accept(self)
for prop in stmt.properties:
prop.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
self.wrap(stmt, "property") self.wrap(stmt, "property")
stmt.type.accept(self) stmt.type.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None: def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend") self.wrap(stmt, "extend")
@@ -243,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_op_stmt(self, stmt: m.OpStmt) -> None: def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self.wrap(stmt, "op") self.wrap(stmt, "op")
self.wrap(LocatableToken(stmt.name), "op-name")
stmt.operand.accept(self) stmt.operand.accept(self)
stmt.result.accept(self) stmt.result.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate") self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self) stmt.type.accept(self)
stmt.condition.accept(self) stmt.condition.accept(self)
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
self.wrap(expr, "simple-type-expr")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr") self.wrap(expr, "logical-expr")
expr.left.accept(self) expr.left.accept(self)
@@ -282,14 +280,29 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ... def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> None: def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(expr, "template") self.wrap(type, "named-type")
expr.type.accept(self)
def visit_type_expr(self, expr: m.TypeExpr) -> None: def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(expr, "type") self.wrap(type, "generic-type")
if expr.template is not None: type.type.accept(self)
expr.template.accept(self) for param in type.params:
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_union_type(self, type: m.UnionType) -> None:
self.wrap(type, "union-type")
for type_ in type.types:
type_.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for prop in type.properties:
prop.accept(self)
class DiagnosticsHighlighter(Highlighter): class DiagnosticsHighlighter(Highlighter):

View File

@@ -5,12 +5,12 @@ span {
font-style: italic; font-style: italic;
} }
&.simple-type { &.named-type,
--col: 108, 233, 108; &.generic-type,
} &.constraint-type,
&.union-type,
&.complex-type { &.complex-type {
--col: 233, 206, 108; --col: 150, 150, 150;
} }
&.constraint { &.constraint {
@@ -33,10 +33,6 @@ span {
--col: 193, 108, 233; --col: 193, 108, 233;
} }
&.simple-type-expr {
--col: 150, 150, 150;
}
&.logical-expr, &.logical-expr,
&.binary-expr, &.binary-expr,
&.unary-expr, &.unary-expr,
@@ -48,7 +44,9 @@ span {
--col: 163, 117, 71; --col: 163, 117, 71;
} }
&.type { &.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200; --col: 200, 200, 200;
font-weight: bold; font-weight: bold;
} }

View File

@@ -1,7 +1,6 @@
import ast import ast
import json import json
import logging import logging
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, TextIO, get_args from typing import Optional, TextIO, get_args
@@ -9,14 +8,14 @@ import click
import midas.ast.midas as m import midas.ast.midas as m
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.checker.checker import Checker from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type from midas.checker.types import Type
from midas.cli.highlighter import ( from midas.cli.highlighter import (
DiagnosticsHighlighter, DiagnosticsHighlighter,
Highlighter, Highlighter,
LocatableToken,
MidasHighlighter, MidasHighlighter,
PythonHighlighter, PythonHighlighter,
) )
@@ -142,14 +141,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
for err in parser.errors: for err in parser.errors:
print(err.get_report()) print(err.get_report())
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
for stmt in stmts: for stmt in stmts:
highlighter.highlight(stmt) highlighter.highlight(stmt)
for token in tokens: for token in tokens:
@@ -176,5 +167,21 @@ def highlight(output: TextIO, file: TextIO):
highlighter.dump(output) highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
midas() midas()

View File

@@ -87,6 +87,9 @@ class PythonParser:
case ast.If(): case ast.If():
return self.parse_if(node) return self.parse_if(node)
case ast.Pass():
return None
case _: case _:
print(f"Unsupported statement: {ast.unparse(node)}") print(f"Unsupported statement: {ast.unparse(node)}")
return None return None
@@ -311,6 +314,13 @@ class PythonParser:
constraint=right_expr, constraint=right_expr,
) )
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
param=None,
)
case _: case _:
raise UnsupportedSyntaxError(type_expr) raise UnsupportedSyntaxError(type_expr)

View File

@@ -2,6 +2,7 @@ from typing import Optional
import midas.ast.midas as m import midas.ast.midas as m
from midas.checker.types import ( from midas.checker.types import (
AliasType,
Type, Type,
UnionType, UnionType,
UnknownType, UnknownType,
@@ -103,7 +104,8 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[T
for param in stmt.params: for param in stmt.params:
if param.bound is not None: if param.bound is not None:
param.bound.accept(self) param.bound.accept(self)
self.define_type(stmt.name.lexeme, type) name: str = stmt.name.lexeme
self.define_type(name, AliasType(name=name, type=type))
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ... def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...

View File

@@ -1,6 +1,6 @@
type Meter(float) type Meter = float
type Second(float) type Second = float
type MeterPerSecond(float) type MeterPerSecond = float
extend Meter { extend Meter {
op __add__(Meter) -> Meter op __add__(Meter) -> Meter

View File

@@ -1,15 +1,15 @@
// Simple custom type derived from float // Simple custom type derived from float
type Custom(float) type Custom = float
// Simple custom types with constraints // Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90) type Latitude = float where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180) type Longitude = float where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float // Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T) type Difference[T] = T
// Complex custom type, containing two values accessible through properties // Complex custom type, containing two values accessible through properties
type GeoLocation { type GeoLocation = {
lat: Latitude lat: Latitude
lon: Longitude lon: Longitude
} }
@@ -24,7 +24,7 @@ extend GeoLocation {
// For complex generics, you need to specify how the genericity the properties // For complex generics, you need to specify how the genericity the properties
// are handled // are handled
type Difference[GeoLocation] { type Difference[GeoLocation] = {
lat: Difference[Latitude] lat: Difference[Latitude]
lon: Difference[Longitude] lon: Difference[Longitude]
} }
@@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10) predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66) predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person { type Person = {
name: str name: str
// Property with an inline constraint // Property with an inline constraint
age: int? where (0 <= _ < 150) age: None | (int where (0 <= _ < 150))
// Property referencing a predicate // Property referencing a predicate
height: float where StrictlyPositive height: float where StrictlyPositive

File diff suppressed because it is too large Load Diff

View File

@@ -2,56 +2,61 @@ from typing import Optional, Sequence
from midas.ast.midas import ( from midas.ast.midas import (
BinaryExpr, BinaryExpr,
ComplexTypeStmt, ComplexType,
ConstraintType,
Expr, Expr,
ExtendStmt, ExtendStmt,
GenericType,
GetExpr, GetExpr,
GroupingExpr, GroupingExpr,
LiteralExpr, LiteralExpr,
LogicalExpr, LogicalExpr,
NamedType,
OpStmt, OpStmt,
PredicateStmt, PredicateStmt,
PropertyStmt, PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt, Stmt,
TemplateExpr, Type,
TypeExpr, TypeStmt,
UnaryExpr, UnaryExpr,
UnionType,
VariableExpr, VariableExpr,
WildcardExpr, WildcardExpr,
) )
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]): class MidasAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure""" """An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]: def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts] return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]: def _serialize_optional(
self, element: Optional[Stmt | Expr | Type]
) -> Optional[dict]:
if element is None: if element is None:
return None return None
return element.accept(self) return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]: def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
return [element.accept(self) for element in elements] return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict: def visit_type_stmt(self, stmt: TypeStmt) -> dict:
return { return {
"_type": "SimpleTypeStmt", "_type": "TypeStmt",
"name": stmt.name.lexeme, "name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template), "params": [
"base": stmt.base.accept(self), self._serialize_type_stmt_template_param(param) for param in stmt.params
"constraint": self._serialize_optional(stmt.constraint), ],
"type": stmt.type.accept(self),
} }
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict: def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
return { return {
"_type": "ComplexTypeStmt", "name": param.name.lexeme,
"name": stmt.name.lexeme, "bound": self._serialize_optional(param.bound),
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
} }
def visit_property_stmt(self, stmt: PropertyStmt) -> dict: def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
@@ -59,7 +64,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"_type": "PropertyStmt", "_type": "PropertyStmt",
"name": stmt.name.lexeme, "name": stmt.name.lexeme,
"type": stmt.type.accept(self), "type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
} }
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict: def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
@@ -86,13 +90,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"condition": stmt.condition.accept(self), "condition": stmt.condition.accept(self),
} }
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict: def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return { return {
"_type": "LogicalExpr", "_type": "LogicalExpr",
@@ -144,16 +141,34 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict: def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
return {"_type": "WildcardExpr"} return {"_type": "WildcardExpr"}
def visit_template_expr(self, expr: TemplateExpr) -> dict: def visit_named_type(self, type: NamedType) -> dict:
return { return {
"_type": "TemplateExpr", "_type": "NamedType",
"type": expr.type.accept(self), "name": type.name.lexeme,
} }
def visit_type_expr(self, expr: TypeExpr) -> dict: def visit_generic_type(self, type: GenericType) -> dict:
return { return {
"_type": "TypeExpr", "_type": "GenericType",
"name": expr.name.lexeme, "type": type.type.accept(self),
"template": self._serialize_optional(expr.template), "params": self._serialize_list(type.params),
"optional": expr.optional, }
def visit_constraint_type(self, type: ConstraintType) -> dict:
return {
"_type": "ConstraintType",
"type": type.type.accept(self),
"constraint": type.constraint.accept(self),
}
def visit_union_type(self, type: UnionType) -> dict:
return {
"_type": "UnionType",
"types": self._serialize_list(type.types),
}
def visit_complex_type(self, type: ComplexType) -> dict:
return {
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
} }

View File

@@ -22,6 +22,7 @@ from midas.ast.python import (
ReturnStmt, ReturnStmt,
SetExpr, SetExpr,
Stmt, Stmt,
TernaryExpr,
TypeAssign, TypeAssign,
UnaryExpr, UnaryExpr,
VariableExpr, VariableExpr,
@@ -245,3 +246,11 @@ class PythonAstJsonSerializer(
"type": expr.type.accept(self), "type": expr.type.accept(self),
"expr": expr.expr.accept(self), "expr": expr.expr.accept(self),
} }
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
return {
"_type": "TernaryExpr",
"test": expr.test.accept(self),
"if_true": expr.if_true.accept(self),
"if_false": expr.if_false.accept(self),
}