29 Commits

Author SHA1 Message Date
288d15a9bc Merge pull request 'Usage documentation' (#7) from feat/usage-documentation into main
Reviewed-on: #7
2026-06-05 10:29:42 +00:00
504703d0f7 fix(cli): remove print in main command 2026-06-05 12:26:09 +02:00
e48895d0af docs: add usage documentation in README 2026-06-05 12:25:02 +02:00
13d32d0d27 Merge pull request 'Basic type checker' (#6) from feat/basic-type-checker into main
Reviewed-on: #6
2026-06-05 09:31:53 +00:00
19b9fdd623 Merge pull request 'Improve syntax and types' (#5) from feat/improve-syntax-and-types into feat/basic-type-checker
Reviewed-on: #5
2026-06-05 09:20:56 +00:00
ddcaebb51a fix: remove outdated syntax definition 2026-06-05 11:19:29 +02:00
f182312cd2 fix: update midas syntax definitions 2026-06-05 11:14:53 +02:00
73b21789d5 fix(tests): remove custom imports 2026-06-05 10:48:46 +02:00
5d7c724bc8 fix(cli): add types files argument 2026-06-05 10:44:20 +02:00
74b297c89c feat(checker): remove custom midas import
remove custom import statement (`midas.using`) in favor of passing type definition files as arguments to the checker
2026-06-05 10:43:52 +02:00
822a74acce refactor(checker): rename methods
improve a couple methods names, namely evaluate → type_of and evaluate_block → process_block
2026-06-03 13:03:41 +02:00
9a934fabfd tests: remove union type 2026-06-02 17:22:19 +02:00
828ec9a3fa fix!: remove union type 2026-06-02 17:19:17 +02:00
63a43d79dd chore: update examples 2026-06-02 13:07:53 +02:00
029caf4526 fix(tests): update tests with new syntax 2026-06-02 13:05:38 +02:00
1c5c418f1c fix(tests): serialize ternary expressions 2026-06-02 13:05:06 +02:00
a4139d4652 feat(checker): handle logical expressions 2026-06-02 13:03:07 +02:00
2fd2071d40 feat(parser): parse pass statement and None 2026-06-02 13:02:45 +02:00
97b1ee8ab8 feat(cli): add format command 2026-06-02 13:00:43 +02:00
dee479def5 fix(checker): wrap type definitions in AliasType 2026-06-02 13:00:03 +02:00
c8536e20d2 feat(tests): update Midas serializer 2026-06-02 12:38:58 +02:00
d70137775f feat(cli): update highlighter with new nodes 2026-06-02 12:29:39 +02:00
35ceda99aa chore: tidy 2026-06-02 11:45:49 +02:00
7f3d74ee49 feat(checker)!: resolve new types 2026-06-02 11:44:31 +02:00
b9f378de6f feat(parser)!: update Midas parser with new nodes 2026-06-02 11:42:35 +02:00
ccb17c7290 feat(parser)!: add new Midas AST nodes 2026-06-02 11:41:53 +02:00
505779310a feat: add new midas syntax example 2026-06-02 11:40:42 +02:00
bea3f399ad feat(checker): handle ternary expression 2026-06-01 15:02:12 +02:00
55060bfecd feat(parser): add ternary statement 2026-06-01 15:00:21 +02:00
36 changed files with 1274 additions and 879 deletions

View File

@@ -5,3 +5,82 @@
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
## Requirements
- Python 3.11+
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
## Installation
1. Clone the repository
```shell
git clone https://git.kb28.ch/HEL/midas.git
```
2. Go in the project directory
```shell
cd midas
```
3. Install the CLI as a user-wide tool
```shell
uv tool install .
```
4. You can now run the `midas` command from anywhere
```shell
midas --help
```
## Commands
### Compiling
> [!NOTE]
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
```shell
midas compile -t types.midas source.py
```
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
### Highlighting
```shell
midas utils highlight source.py
# or
midas utils highlight types.midas
```
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
### Dumping the AST
```shell
midas utils dump-ast source.py
# or
midas utils dump-ast types.midas
```
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
## Tests
Several snapshot tests are available to assert the good behaviour of the parsers and type checker. They can be run as follows:
```shell
uv run -m tests.midas run -a
uv run -m tests.python run -a
uv run -m tests.checker run -a
```
**Available subcommands:**
- Run all tests: `run -a`
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
- Update all tests: `update -a`
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821]
from __future__ import annotations
# Prototype of custom type import to use valid Python syntax
import midas
midas.using("02_custom_types.midas")
# A data-frame using a custom type
df: Frame[
location: GeoLocation

View File

@@ -0,0 +1,33 @@
type Foo1 = float
type Foo2 = float where (_ > 3)
type Foo3 = int | float
type Foo4 = int where (_ > 3) | float where (_ > 3)
type Foo5 = (int | float) where (_ > 3)
type Foo6 = {
foo: float
bar: float where (_ > 3)
}
type Foo7[T] = T where (_ > 3)
type Foo8[A, B<:int] = {
a: A
b: B
}
type Complex = {
a: int
b: int
}
type Complex2 = Complex where (_.a > 3 & _.b < 5)
predicate Positive(n: int) = n >= 0
extend Foo1 {
op __add__(Foo1) -> Foo1
}
extend Foo7[T] {
op __add__(Foo7[T]) -> Foo7[T]
}
type Optional[T] = None | T

View File

@@ -1,6 +1,6 @@
type Meter(float)
type Second(float)
type MeterPerSecond(float)
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter

View File

@@ -1,8 +1,6 @@
# type: ignore
# ruff: disable [F821]
midas.using("02_simple_types.midas")
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -4,11 +4,20 @@ def minimum(x: int, y: int):
else:
return y
a = 15
b = 72
c = minimum(a, b)
def factorial(n: int) -> int:
if n <= 1:
return 1
return n * factorial(n - 1)
return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"
def foo() -> None:
pass

View File

@@ -13,40 +13,38 @@ from midas.lexer.token import Token
###> Stmt | Statements
class SimpleTypeStmt:
class TypeStmt:
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
params: list[Param]
type: Type
class ComplexTypeStmt:
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
class PropertyStmt:
name: Token
type: TypeExpr
constraint: Optional[Expr]
type: Type
class ExtendStmt:
type: TypeExpr
type: Type
operations: list[OpStmt]
class OpStmt:
name: Token
operand: TypeExpr
result: TypeExpr
operand: Type
result: Type
class PredicateStmt:
name: Token
subject: Token
type: TypeExpr
type: Type
condition: Expr
@@ -54,9 +52,6 @@ class PredicateStmt:
###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
@@ -97,14 +92,27 @@ class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
###<
###> Type | Types
class TypeExpr:
class NamedType:
name: Token
template: Optional[TemplateExpr]
optional: bool
class GenericType:
type: Type
params: list[Type]
class ConstraintType:
type: Type
constraint: Expr
class ComplexType:
properties: list[PropertyStmt]
###<

View File

@@ -139,4 +139,10 @@ class CastExpr:
expr: Expr
class TernaryExpr:
test: Expr
if_true: Expr
if_false: Expr
###<

View File

@@ -28,10 +28,7 @@ class Stmt(ABC):
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
@abstractmethod
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
@@ -47,31 +44,25 @@ class Stmt(ABC):
@dataclass(frozen=True)
class SimpleTypeStmt(Stmt):
class TypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
params: list[Param]
type: Type
@dataclass(frozen=True, kw_only=True)
class Param:
location: Location
name: Token
bound: Optional[Type]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_simple_type_stmt(self)
@dataclass(frozen=True)
class ComplexTypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_complex_type_stmt(self)
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class PropertyStmt(Stmt):
name: Token
type: TypeExpr
constraint: Optional[Expr]
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self)
@@ -79,7 +70,7 @@ class PropertyStmt(Stmt):
@dataclass(frozen=True)
class ExtendStmt(Stmt):
type: TypeExpr
type: Type
operations: list[OpStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -89,8 +80,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class OpStmt(Stmt):
name: Token
operand: TypeExpr
result: TypeExpr
operand: Type
result: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_op_stmt(self)
@@ -100,7 +91,7 @@ class OpStmt(Stmt):
class PredicateStmt(Stmt):
name: Token
subject: Token
type: TypeExpr
type: Type
condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -120,9 +111,6 @@ class Expr(ABC):
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@@ -147,21 +135,6 @@ class Expr(ABC):
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@dataclass(frozen=True)
class SimpleTypeExpr(Expr):
name: Token
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_simple_type_expr(self)
@dataclass(frozen=True)
class LogicalExpr(Expr):
@@ -233,19 +206,61 @@ class WildcardExpr(Expr):
return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True)
class TemplateExpr(Expr):
type: TypeExpr
#########
# Types #
#########
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_template_expr(self)
@dataclass(frozen=True, kw_only=True)
class Type(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_named_type(self, type: NamedType) -> T: ...
@abstractmethod
def visit_generic_type(self, type: GenericType) -> T: ...
@abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@dataclass(frozen=True)
class TypeExpr(Expr):
class NamedType(Type):
name: Token
template: Optional[TemplateExpr]
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_expr(self)
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_named_type(self)
@dataclass(frozen=True)
class GenericType(Type):
type: Type
params: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@dataclass(frozen=True)
class ConstraintType(Type):
type: Type
constraint: Expr
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class ComplexType(Type):
properties: list[PropertyStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)

View File

@@ -85,40 +85,39 @@ class AstPrinter(Generic[T]):
child.accept(self)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
class MidasAstPrinter(
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
):
# Statements
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
self._write_line("SimpleTypeStmt")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("TypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("base")
with self._child_level(single=True):
stmt.base.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
self._write_line("ComplexTypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("properties", last=True)
self._write_line("params")
with self._child_level():
for i, prop in enumerate(stmt.properties):
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.properties) - 1:
if i == len(stmt.params) - 1:
self._mark_last()
prop.accept(self)
self._print_type_stmt_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type")
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
# Expressions
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
self._write_line("SimpleTypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line(f"optional: {expr.optional}", last=True)
def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr")
with self._child_level():
@@ -230,22 +223,48 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self._write_line("TemplateExpr")
with self._child_level(single=True):
def visit_named_type(self, type: m.NamedType) -> None:
self._write_line("NamedType")
with self._child_level():
self._write_line(f'name: "{type.name.lexeme}"', last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
self._write_line("GenericType")
with self._child_level():
self._write_line("type")
with self._child_level():
type.type.accept(self)
self._write_line("params", last=True)
with self._child_level():
for i, param in enumerate(type.params):
self._idx = i
if i == len(type.params) - 1:
self._mark_last()
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self._write_line("ConstraintType")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
type.type.accept(self)
self._write_line("constraint", last=True)
with self._child_level(single=True):
type.constraint.accept(self)
def visit_type_expr(self, expr: m.TypeExpr):
self._write_line("TypeExpr")
def visit_complex_type(self, type: m.ComplexType) -> None:
self._write_line("ComplexType")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_optional_child("template", expr.template)
self._write_line(f"optional: {expr.optional}", last=True)
self._write_line("properties", last=True)
with self._child_level():
for i, prop in enumerate(type.properties):
self._idx = i
if i == len(type.properties) - 1:
self._mark_last()
prop.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
@@ -253,33 +272,28 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def indented(self, text: str) -> str:
return " " * (self.level * self.indent) + text
def print(self, expr: m.Expr | m.Stmt):
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
self.level = 0
return expr.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [
self._print_type_template_param(param) for param in stmt.params
]
template = f"[{', '.join(params)}]"
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
return self.indented(res)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for prop in stmt.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
res: str = param.name.lexeme
if param.bound is not None:
res += "<:" + param.bound.accept(self)
return res
def visit_property_stmt(self, stmt: m.PropertyStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
return self.indented(res)
def visit_extend_stmt(self, stmt: m.ExtendStmt):
@@ -289,13 +303,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
for op in stmt.operations:
res += op.accept(self)
self.level -= 1
res += "\n" + self.indented("}")
res += self.indented("}")
return res
def visit_op_stmt(self, stmt: m.OpStmt):
operand: str = stmt.operand.accept(self)
result: str = stmt.result.accept(self)
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}\n")
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
@@ -304,9 +318,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
@@ -342,12 +353,30 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_"
def visit_template_expr(self, expr: m.TemplateExpr):
return f"[{expr.type.accept(self)}]"
def visit_named_type(self, type: m.NamedType) -> str:
return type.name.lexeme
def visit_type_expr(self, expr: m.TypeExpr):
template: str = expr.template.accept(self) if expr.template is not None else ""
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
def visit_generic_type(self, type: m.GenericType) -> str:
res: str = type.type.accept(self)
if len(type.params) != 0:
params: list[str] = [param.accept(self) for param in type.params]
res += f"[{', '.join(params)}]"
return res
def visit_constraint_type(self, type: m.ConstraintType) -> str:
res: str = type.type.accept(self)
res += " where " + type.constraint.accept(self)
return res
def visit_complex_type(self, type: m.ComplexType) -> str:
res: str = "{\n"
self.level += 1
for prop in type.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
return res
class PythonAstPrinter(
@@ -465,7 +494,8 @@ class PythonAstPrinter(
self._write_line("IfStmt")
with self._child_level():
self._write_line("test")
stmt.test.accept(self)
with self._child_level(single=True):
stmt.test.accept(self)
self._write_line("body")
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
@@ -592,3 +622,18 @@ class PythonAstPrinter(
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")
with self._child_level():
self._write_line("test")
with self._child_level(single=True):
expr.test.accept(self)
self._write_line("if_true")
with self._child_level(single=True):
expr.if_true.accept(self)
self._write_line("if_false", last=True)
with self._child_level(single=True):
expr.if_false.accept(self)

View File

@@ -220,6 +220,9 @@ class Expr(ABC):
@abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ...
@abstractmethod
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
@dataclass(frozen=True)
class BinaryExpr(Expr):
@@ -312,3 +315,13 @@ class CastExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)
@dataclass(frozen=True)
class TernaryExpr(Expr):
test: Expr
if_true: Expr
if_false: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_ternary_expr(self)

View File

@@ -34,9 +34,15 @@ class Checker(
):
"""A type checker which can use custom type definitions"""
def __init__(self, locals: dict[p.Expr, int], file_path: Path):
def __init__(
self,
locals: dict[p.Expr, int],
source_path: Path,
types_paths: list[Path],
):
self.logger: logging.Logger = logging.getLogger("Checker")
self.file_path: Path = file_path
self.source_path: Path = source_path
self.types_paths: list[Path] = types_paths
self.ctx: MidasResolver = MidasResolver()
self.global_env: Environment = Environment()
self.env: Environment = self.global_env
@@ -46,7 +52,7 @@ class Checker(
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
self.diagnostics.append(
Diagnostic(
file_path=self.file_path,
file_path=self.source_path,
location=location,
type=type,
message=message,
@@ -74,7 +80,7 @@ class Checker(
message=message,
)
def evaluate(self, expr: p.Expr) -> Type:
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression
Args:
@@ -85,13 +91,13 @@ class Checker(
"""
return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements
Args:
block (list[p.Stmt]): the statements to evaluate
env (Environment): the environment in which to evaluate
Returns:
bool: whether a return statement is present in the block
"""
@@ -119,6 +125,12 @@ class Checker(
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
"""
self.diagnostics = []
for path in self.types_paths:
self.import_midas(path)
self.logger.debug(f"Midas types: {self.ctx._types}")
self.logger.debug(f"Midas operations: {self.ctx._operations}")
for stmt in statements:
stmt.accept(self)
@@ -140,30 +152,6 @@ class Checker(
return self.env.get_at(distance, name)
return self.global_env.get(name)
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
"""Parse a Midas import statement
The statement should be written as `midas.using("path/to/types.midas")`
Args:
expr (p.CallExpr): the import call expression
Returns:
Optional[Path]: the path to the imported file, or None if the expression is malformed
"""
match expr:
case p.CallExpr(
callee=p.GetExpr(
object=p.VariableExpr(name="midas"),
name="using",
),
arguments=[
p.LiteralExpr(value=path),
],
):
return Path(path)
return None
def import_midas(self, path: Path) -> None:
"""Import Midas definitions from a path
@@ -171,17 +159,14 @@ class Checker(
path (Path): the import path
"""
self.logger.debug(f"Importing type definitions from {path}")
path = (self.file_path.parent / path).resolve()
lexer: MidasLexer = MidasLexer(path.read_text())
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts)
self.logger.debug(f"Midas types: {self.ctx._types}")
self.logger.debug(f"Midas operations: {self.ctx._operations}")
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.evaluate(stmt.expr)
self.type_of(stmt.expr)
def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env)
@@ -223,7 +208,7 @@ class Checker(
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
if stmt.returns is not None:
returns_hint = stmt.returns.accept(self)
@@ -237,7 +222,7 @@ class Checker(
)
self.env.define(stmt.name, inside_function)
returned: bool = self.evaluate_block(stmt.body, env)
returned: bool = self.process_block(stmt.body, env)
inferred_return: Type = UnknownType()
if not returned:
env.return_types.append(UnitType())
@@ -278,7 +263,7 @@ class Checker(
self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value: Type = self.evaluate(stmt.value)
value: Type = self.type_of(stmt.value)
for target in stmt.targets:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
@@ -317,8 +302,8 @@ class Checker(
)
env: Environment = Environment(self.env)
body_returned: bool = self.evaluate_block(stmt.body, env)
else_returned: bool = self.evaluate_block(stmt.orelse, env)
body_returned: bool = self.process_block(stmt.body, env)
else_returned: bool = self.process_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types)
if body_returned and else_returned:
raise ReturnException()
@@ -329,8 +314,8 @@ class Checker(
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.evaluate(expr.left)
right: Type = self.evaluate(expr.right)
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
@@ -347,8 +332,8 @@ class Checker(
self.logger.warning(f"Unsupported operator {expr.operator}")
self.warning(expr.location, f"Unsupported operator {expr.operator}")
return UnknownType()
left: Type = self.evaluate(expr.left)
right: Type = self.evaluate(expr.right)
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
if result is None:
@@ -362,10 +347,7 @@ class Checker(
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
def visit_call_expr(self, expr: p.CallExpr) -> Type:
if path := self.parse_midas_import(expr):
self.import_midas(path)
return UnknownType()
callee: Type = self.evaluate(expr.callee)
callee: Type = self.type_of(expr.callee)
if not isinstance(callee, Function):
self.error(expr.callee.location, "Callee is not a function")
return UnknownType()
@@ -398,13 +380,41 @@ class Checker(
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
return self.look_up_variable(expr.name, expr) or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = expr.left.accept(self)
right: Type = expr.right.accept(self)
# TODO: union type
if left != right:
self.error(
expr.location,
f"Operands must be of the same type, left={left} != right={right}",
)
return left
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return expr.type.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = expr.test.accept(self)
# TODO Allow subtypes or any type
if test_type != self.ctx.get_type("bool"):
self.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
true_type: Type = expr.if_true.accept(self)
false_type: Type = expr.if_false.accept(self)
if true_type != false_type:
self.error(
expr.location,
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
)
return UnknownType()
return true_type
def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base)
@@ -432,10 +442,10 @@ class Checker(
list[MappedArgument]: the list of mapped arguments
"""
positional: list[tuple[p.Expr, Type]] = [
(arg, self.evaluate(arg)) for arg in call.arguments
(arg, self.type_of(arg)) for arg in call.arguments
]
keywords: dict[str, tuple[p.Expr, Type]] = {
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
}
set_args: set[str] = set()

View File

@@ -9,9 +9,9 @@ class BaseType:
@dataclass(frozen=True, kw_only=True)
class SimpleType:
class AliasType:
name: str
base: BaseType | SimpleType
type: Type
@dataclass(frozen=True, kw_only=True)
@@ -39,4 +39,9 @@ class Function:
required: bool
Type = BaseType | SimpleType | UnknownType | UnitType | Function
@dataclass(frozen=True, kw_only=True)
class ComplexType:
properties: dict[str, Type]
Type = BaseType | AliasType | UnknownType | UnitType | Function | ComplexType

View File

@@ -53,5 +53,6 @@ span {
&.keyword {
color: rgb(211, 72, 9);
pointer-events: none;
}
}

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar
@@ -8,6 +9,7 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True)
@@ -22,6 +24,15 @@ class Locatable(Protocol):
def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None
@@ -203,35 +214,25 @@ class PythonHighlighter(
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
self.wrap(stmt, "simple-type")
if stmt.template is not None:
stmt.template.accept(self)
stmt.base.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
self.wrap(stmt, "complex-type")
if stmt.template is not None:
stmt.template.accept(self)
for prop in stmt.properties:
prop.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "type-stmt")
self.wrap(LocatableToken(stmt.name), "type-name")
stmt.type.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
self.wrap(stmt, "property")
stmt.type.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend")
@@ -241,17 +242,16 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self.wrap(stmt, "op")
self.wrap(LocatableToken(stmt.name), "op-name")
stmt.operand.accept(self)
stmt.result.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self)
stmt.condition.accept(self)
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
self.wrap(expr, "simple-type-expr")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
expr.left.accept(self)
@@ -280,14 +280,24 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self.wrap(expr, "template")
expr.type.accept(self)
def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(type, "named-type")
def visit_type_expr(self, expr: m.TypeExpr) -> None:
self.wrap(expr, "type")
if expr.template is not None:
expr.template.accept(self)
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for param in type.params:
param.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for prop in type.properties:
prop.accept(self)
class DiagnosticsHighlighter(Highlighter):

View File

@@ -5,12 +5,11 @@ span {
font-style: italic;
}
&.simple-type {
--col: 108, 233, 108;
}
&.named-type,
&.generic-type,
&.constraint-type,
&.complex-type {
--col: 233, 206, 108;
--col: 150, 150, 150;
}
&.constraint {
@@ -33,10 +32,6 @@ span {
--col: 193, 108, 233;
}
&.simple-type-expr {
--col: 150, 150, 150;
}
&.logical-expr,
&.binary-expr,
&.unary-expr,
@@ -48,7 +43,9 @@ span {
--col: 163, 117, 71;
}
&.type {
&.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200;
font-weight: bold;
}

View File

@@ -1,7 +1,6 @@
import ast
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TextIO, get_args
@@ -9,14 +8,14 @@ import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.checker.types import Type
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
@@ -30,13 +29,14 @@ from midas.utils import UniversalJSONDumper
@click.group()
def midas():
click.echo("Welcome to Midas!")
pass
@midas.command()
@click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.argument("file", type=click.File("r"))
def compile(highlight: Optional[TextIO], file: TextIO):
def compile(highlight: Optional[TextIO], file: TextIO, types: tuple[TextIO]):
logging.basicConfig(level=logging.DEBUG)
source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name)
@@ -44,7 +44,12 @@ def compile(highlight: Optional[TextIO], file: TextIO):
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
checker = Checker(resolver.locals, file_path=Path(file.name).resolve())
types_paths: list[Path] = [Path(t.name).resolve() for t in types]
checker = Checker(
resolver.locals,
source_path=Path(file.name).resolve(),
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics:
print(diagnostic)
@@ -142,14 +147,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
for err in parser.errors:
print(err.get_report())
@dataclass(frozen=True)
class LocatableToken:
token: Token
@property
def location(self) -> Location:
return self.token.get_location()
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
@@ -176,5 +173,21 @@ def highlight(output: TextIO, file: TextIO):
highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
source: str = file.read()
printer = MidasPrinter()
lexer = MidasLexer(source, file=file.name)
tokens: list[Token] = lexer.process()
parser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
for err in parser.errors:
print(err.get_report())
for stmt in stmts:
output.write(printer.print(stmt) + "\n")
if __name__ == "__main__":
midas()

View File

@@ -40,8 +40,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.AND)
case "?":
self.add_token(TokenType.QMARK)
# case ",":
# self.add_token(TokenType.COMMA)
case ",":
self.add_token(TokenType.COMMA)
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"):

View File

@@ -17,7 +17,7 @@ class TokenType(Enum):
LEFT_BRACE = auto()
RIGHT_BRACE = auto()
COLON = auto()
# COMMA = auto()
COMMA = auto()
UNDERSCORE = auto()
ARROW = auto()
AND = auto()

View File

@@ -3,21 +3,22 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
NamedType,
OpStmt,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
@@ -81,7 +82,7 @@ class MidasParser(Parser):
self.synchronize()
return None
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
def type_declaration(self) -> TypeStmt:
"""Parse a type declaration
A type declaration can either be a simple type alias or a new complex type.
@@ -107,33 +108,22 @@ class MidasParser(Parser):
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
params: list[TypeStmt.Param] = []
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
params = self.type_stmt_params()
if self.match(TokenType.LEFT_PAREN):
base: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return SimpleTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
)
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
def template_expr(self) -> TemplateExpr:
type: Type = self.type_expr()
return TypeStmt(
location=keyword.location_to(self.previous()),
name=name,
params=params,
type=type,
)
def type_stmt_params(self) -> list[TypeStmt.Param]:
"""Parse a generic template expression
A template is written `[TypeExpr]`
@@ -141,16 +131,27 @@ class MidasParser(Parser):
Returns:
TemplateExpr: the parsed template expression
"""
left: Token = self.consume(
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
)
type: TypeExpr = self.type_expr()
right: Token = self.consume(
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
)
return TemplateExpr(location=left.location_to(right), type=type)
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
params: list[TypeStmt.Param] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeStmt.Param(
location=name.location_to(self.previous()),
name=name,
bound=bound,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
return params
def type_expr(self) -> TypeExpr:
def type_expr(self) -> Type:
"""Parse a type expression
A type is an identifier, optionally followed by a template expression.
@@ -159,30 +160,82 @@ class MidasParser(Parser):
Returns:
TypeExpr: the parsed type expression
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
return self.constraint_type()
def constraint_type(self) -> Type:
type: Type = self.base_type()
if self.match(TokenType.WHERE):
constraint: Expr = self.constraint()
return ConstraintType(
location=Location.span(type.location, constraint.location),
type=type,
constraint=constraint,
)
return type
def base_type(self) -> Type:
if self.match(TokenType.LEFT_PAREN):
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
return type
if self.check(TokenType.LEFT_BRACE):
return self.complex_type()
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
optional: bool = self.match(TokenType.QMARK)
return TypeExpr(
location=name.location_to(self.previous()),
params: list[Type] = self.type_params()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
params=params,
)
return type
def type_params(self) -> list[Type]:
params: list[Type] = []
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
params.append(self.type_expr())
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
return params
def named_type(self) -> Type:
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
return NamedType(
location=name.get_location(),
name=name,
template=template,
optional=optional,
)
def simple_type_expr(self) -> SimpleTypeExpr:
"""Parse a simple type expression
def complex_type(self) -> Type:
"""Parse a type definition body
A simple type is just an identifier optionally followed by a '?'
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
SimpleTypeExpr: the parsed simple type expression
list[PropertyStmt]: the parsed type properties
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
optional: bool = self.match(TokenType.QMARK)
return SimpleTypeExpr(
location=name.location_to(self.previous()), name=name, optional=optional
left: Token = self.consume(
TokenType.LEFT_BRACE, "Expected '{' to start type body"
)
properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
properties=properties,
)
def constraint(self) -> Expr:
@@ -308,27 +361,6 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression")
def type_properties(self) -> list[PropertyStmt]:
"""Parse a type definition body
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
list[PropertyStmt]: the parsed type properties
"""
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return properties
def property_stmt(self) -> PropertyStmt:
"""Parse a property statement
@@ -339,15 +371,11 @@ class MidasParser(Parser):
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name")
type: TypeExpr = self.type_expr()
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
type: Type = self.type_expr()
return PropertyStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
constraint=constraint,
)
def extend_declaration(self) -> ExtendStmt:
@@ -359,7 +387,7 @@ class MidasParser(Parser):
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
type: TypeExpr = self.type_expr()
type: Type = self.type_expr()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
@@ -380,11 +408,11 @@ class MidasParser(Parser):
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
operand: TypeExpr = self.type_expr()
operand: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr()
result: Type = self.type_expr()
return OpStmt(
location=keyword.location_to(self.previous()),
@@ -406,7 +434,7 @@ class MidasParser(Parser):
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: TypeExpr = self.type_expr()
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()

View File

@@ -22,6 +22,7 @@ from midas.ast.python import (
MidasType,
ReturnStmt,
Stmt,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -86,6 +87,9 @@ class PythonParser:
case ast.If():
return self.parse_if(node)
case ast.Pass():
return None
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return None
@@ -310,6 +314,13 @@ class PythonParser:
constraint=right_expr,
)
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
param=None,
)
case _:
raise UnsupportedSyntaxError(type_expr)
@@ -389,6 +400,9 @@ class PythonParser:
case ast.Call():
return self.parse_call(node)
case ast.IfExp():
return self.parse_ternary(node)
case ast.Constant(value=value):
return LiteralExpr(location=location, value=value)
@@ -478,3 +492,11 @@ class PythonParser:
if arg.arg is not None # Should always be True, type checker happy
},
)
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
return TernaryExpr(
location=Location.from_ast(node),
test=self.parse_expr(node.test),
if_true=self.parse_expr(node.body),
if_false=self.parse_expr(node.orelse),
)

View File

@@ -16,6 +16,7 @@ def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
result=t3,
)
def basic_op(ctx: MidasResolver, type: Type, op: str):
ctx.define_operation(
left=type,
@@ -68,4 +69,4 @@ def define_builtins(ctx: MidasResolver):
op(ctx, float, "__gt__", int, bool) # float > int = bool
op(ctx, float, "__le__", int, bool) # float <= int = bool
op(ctx, float, "__ge__", int, bool) # float >= int = bool
op(ctx, float, "__eq__", int, bool) # float == int = bool
op(ctx, float, "__eq__", int, bool) # float == int = bool

View File

@@ -1,11 +1,15 @@
from typing import Optional
import midas.ast.midas as m
from midas.checker.types import BaseType, SimpleType, Type
from midas.checker.types import (
AliasType,
Type,
UnknownType,
)
from midas.resolver.builtin import define_builtins
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self) -> None:
@@ -94,20 +98,13 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
for stmt in stmts:
stmt.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
# TODO generics, optional, constraint
base: Type = self.get_type(stmt.base.name.lexeme)
match base:
case BaseType() | SimpleType():
type = SimpleType(
name=stmt.name.lexeme,
base=base,
)
self.define_type(type.name, type)
case _:
raise TypeError(f"Invalid base {base} for simple type")
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
type: Type = stmt.type.accept(self)
for param in stmt.params:
if param.bound is not None:
param.bound.accept(self)
name: str = stmt.name.lexeme
self.define_type(name, AliasType(name=name, type=type))
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
@@ -127,27 +124,40 @@ class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type:
return self.get_type(expr.name.lexeme)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ...
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ...
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ...
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
def visit_get_expr(self, expr: m.GetExpr) -> Type: ...
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ...
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ...
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ...
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ...
def visit_named_type(self, type: m.NamedType) -> Type:
return self.get_type(type.name.lexeme)
def visit_type_expr(self, expr: m.TypeExpr) -> Type:
return self.get_type(expr.name.lexeme)
def visit_generic_type(self, type: m.GenericType) -> Type:
type_: Type = type.type.accept(self)
params: list[Type] = [param.accept(self) for param in type.params]
# TODO
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
def visit_complex_type(self, type: m.ComplexType) -> Type:
for prop in type.properties:
prop.accept(self)
# TODO
return UnknownType()

View File

@@ -180,3 +180,8 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)

View File

@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
Equality ::= Comparison (EqualityOp Comparison)*
Constraint ::= Equality ("&" Equality)*
SimpleType ::= Identifier "?"?
Template ::= "[" Type "]"
Type ::= Identifier Template? "?"?
TemplateParam ::= Identifier ("<:" Type)?
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
TypeProperty ::= Identifier ":" Type
ComplexType ::= "{" TypeProperty* "}"
NamedType ::= Identifier
TypeParams ::= "[" (Type ("," Type)*)? "]"
GenericType ::= NamedType TypeParams?
GroupedType ::= "(" Type ")"
BaseType ::= GroupedType | ComplexType | GenericType
ConstraintType ::= BaseType ("where" Constraint)?
Type ::= ConstraintType
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
ComplexTypeBody ::= "{" TypeProperty* "}"
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
ExtendBody ::= "{" OpDefinition* "}"
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
TypeStatement ::= "type" Identifier Template? "=" Type
ExtendStatement ::= "extend" Type ExtendBody
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint

View File

@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
{[`constraint` 'equality'*"&"]}
```
#let simple-type = ```
{[`simple-type` 'identifier' <!, "?">]}
#let template-param = ```
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
```
#let template = ```
{[`template` "[" 'type' "]"]}
```
#let type = ```
{[`type` 'identifier' <!, 'template'> <!, "?">]}
{[`template` "[" <!, 'template-param'*","> "]"]}
```
#let type-property = ```
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
{[`type-property` 'identifier' ":" 'type']}
```
#let type-body = ```
{[`type-body` "{" <!, 'type-property'*!> "}"]}
#let complex-type = ```
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
```
#let named-type = ```
{[`named-type` 'identifier']}
```
#let type-params = ```
{[`type-params` "[" <!, 'type'*","> "]"]}
```
#let generic-type = ```
{[`generic-type` 'named-type' <!, 'type-params'>]}
```
#let grouped-type = ```
{[`grouped-type` "(" 'type' ")"]}
```
#let base-type = ```
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
```
#let constraint-type = ```
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
```
#let type = ```
{[`type` 'constraint-type']}
```
#let type-statement = ```
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
```
#let op-definition = ```
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
comparison: comparison,
equality: equality,
constraint: constraint,
simple-type: simple-type,
template-param: template-param,
template: template,
type: type,
type-property: type-property,
type-body: type-body,
complex-type: complex-type,
named-type: named-type,
type-params: type-params,
generic-type: generic-type,
grouped-type: grouped-type,
base-type: base-type,
constraint-type: constraint-type,
type: type,
type-statement: type-statement,
op-definition: op-definition,
extend-statement: extend-statement,
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
#let inline = (
"grouping",
"value",
"template-param",
"template",
"simple-type",
"type-property",
"type-body",
"complex-type",
"type-params",
"named-type",
"grouped-type",
"generic-type",
"base-type",
"constraint-type",
"op-definition",
"type-statement",
"extend-statement",

View File

@@ -141,3 +141,9 @@ class Tester(ABC):
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)
case None:
print("No subcommand provided. Available subcommands: run, update")
sys.exit(1)
case _:
print(f"Unknown subcommand '{args.subcommand}'")
sys.exit(1)

View File

@@ -1,6 +1,6 @@
type Meter(float)
type Second(float)
type MeterPerSecond(float)
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter

View File

@@ -1,8 +1,6 @@
# type: ignore
# ruff: disable [F821]
midas.using("04_custom_types.midas")
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -1,15 +1,15 @@
// Simple custom type derived from float
type Custom(float)
type Custom = float
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
type Latitude = float where (-90 <= _ <= 90)
type Longitude = float where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T)
type Difference[T] = T
// Complex custom type, containing two values accessible through properties
type GeoLocation {
type GeoLocation = {
lat: Latitude
lon: Longitude
}
@@ -24,7 +24,7 @@ extend GeoLocation {
// For complex generics, you need to specify how the genericity the properties
// are handled
type Difference[GeoLocation] {
type Difference[GeoLocation] = {
lat: Difference[Latitude]
lon: Difference[Longitude]
}
@@ -44,11 +44,11 @@ predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
type Person = {
name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
age: Optional[int where (0 <= _ < 150)]
// Property referencing a predicate
height: float where StrictlyPositive

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821]
from __future__ import annotations
import midas
midas.using("02_custom_types.midas")
df: Frame[
location: GeoLocation
]

View File

@@ -1,26 +1,5 @@
{
"stmts": [
{
"_type": "ExpressionStmt",
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "midas"
},
"name": "using"
},
"arguments": [
{
"_type": "LiteralExpr",
"value": "02_custom_types.midas"
}
],
"keywords": {}
}
},
{
"_type": "TypeAssign",
"name": "df",

View File

@@ -33,6 +33,10 @@ class CheckerTester(Tester):
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
types_paths: list[Path] = []
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
types_paths.append(types_path)
source: str = path.read_text()
tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser()
@@ -40,7 +44,11 @@ class CheckerTester(Tester):
resolver = Resolver()
resolver.resolve(*stmts)
result: CaseResult = CaseResult()
checker = Checker(resolver.locals, file_path=path)
checker = Checker(
resolver.locals,
source_path=path,
types_paths=types_paths,
)
diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics:
result.diagnostics.append(

View File

@@ -2,56 +2,60 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
NamedType,
OpStmt,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
class MidasAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]:
return [stmt.accept(self) for stmt in stmts]
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
def _serialize_optional(
self, element: Optional[Stmt | Expr | Type]
) -> Optional[dict]:
if element is None:
return None
return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
return {
"_type": "SimpleTypeStmt",
"_type": "TypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"base": stmt.base.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
"params": [
self._serialize_type_stmt_template_param(param) for param in stmt.params
],
"type": stmt.type.accept(self),
}
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
def _serialize_type_stmt_template_param(self, param: TypeStmt.Param) -> dict:
return {
"_type": "ComplexTypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
}
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
@@ -59,7 +63,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"_type": "PropertyStmt",
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
}
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
@@ -86,13 +89,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"condition": stmt.condition.accept(self),
}
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return {
"_type": "LogicalExpr",
@@ -144,16 +140,28 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
return {"_type": "WildcardExpr"}
def visit_template_expr(self, expr: TemplateExpr) -> dict:
def visit_named_type(self, type: NamedType) -> dict:
return {
"_type": "TemplateExpr",
"type": expr.type.accept(self),
"_type": "NamedType",
"name": type.name.lexeme,
}
def visit_type_expr(self, expr: TypeExpr) -> dict:
def visit_generic_type(self, type: GenericType) -> dict:
return {
"_type": "TypeExpr",
"name": expr.name.lexeme,
"template": self._serialize_optional(expr.template),
"optional": expr.optional,
"_type": "GenericType",
"type": type.type.accept(self),
"params": self._serialize_list(type.params),
}
def visit_constraint_type(self, type: ConstraintType) -> dict:
return {
"_type": "ConstraintType",
"type": type.type.accept(self),
"constraint": type.constraint.accept(self),
}
def visit_complex_type(self, type: ComplexType) -> dict:
return {
"_type": "ComplexType",
"properties": self._serialize_list(type.properties),
}

View File

@@ -22,6 +22,7 @@ from midas.ast.python import (
ReturnStmt,
SetExpr,
Stmt,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -245,3 +246,11 @@ class PythonAstJsonSerializer(
"type": expr.type.accept(self),
"expr": expr.expr.accept(self),
}
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
return {
"_type": "TernaryExpr",
"test": expr.test.accept(self),
"if_true": expr.if_true.accept(self),
"if_false": expr.if_false.accept(self),
}