79 Commits

Author SHA1 Message Date
f7c43837b5 Merge pull request 'CLI tweaks' (#22) from fix/cli-tweaks into main
Reviewed-on: #22
2026-06-24 12:18:07 +00:00
32ed62a6f1 fix(cli): show summary of diagnostic counts 2026-06-24 14:11:39 +02:00
66f39acec0 fix(cli): show all diagnostics in types command
combine type checker diagnostics with judgements info diagnostics
2026-06-24 14:11:15 +02:00
6c04e2fee4 feat(cli): add compile option to ignore errors 2026-06-24 14:10:30 +02:00
2bb2e0a684 Merge pull request 'Unsafe cast' (#21) from feat/unsafe-cast into main
Reviewed-on: #21
2026-06-24 12:00:03 +00:00
5630320d21 chore: use unsafe_cast in demo script 2026-06-24 13:57:38 +02:00
9f05ba3224 feat: handle unsafe casts 2026-06-24 13:51:14 +02:00
5fbe965919 feat(checker): add typing submodule with cast functions 2026-06-24 13:40:23 +02:00
252a5abdfd Merge pull request 'Static evalution of casts on literals' (#20) from feat/literal-static-constraints into main
Reviewed-on: #20
2026-06-24 09:32:54 +00:00
55fba6a088 tests: update test without evaluated casts 2026-06-24 11:28:44 +02:00
70ce263ea2 feat(gen): skip assertions for evaluated casts
avoid generating a runtime assertion for a cast which has already been checked statically
2026-06-24 11:28:43 +02:00
e1d5eac8b8 feat(checker): evaluate constraints statically on literals 2026-06-24 11:10:09 +02:00
82666a4918 feat(checker): add evaluator
add an evaluator class to evaluate expressions using literal values
2026-06-24 11:08:15 +02:00
45f84a2f23 feat(checker): add debug diagnostics 2026-06-24 11:07:42 +02:00
dedfcb4dbb feat(checker): store builtin python functions in preamble 2026-06-24 11:05:36 +02:00
d9ea6365ea tests: update with cast expression judgement 2026-06-23 16:49:38 +02:00
9c7a93412c Merge pull request 'Fixes and small demo' (#19) from feat/demonstration into main
Reviewed-on: #19
2026-06-23 08:15:56 +00:00
d6b8fbfb60 chore: improve demo example 2026-06-23 10:03:24 +02:00
b290c59ac4 fix(gen): add bases for ConstraintType and TypeVar 2026-06-23 00:25:43 +02:00
093f2bc477 fix(checker): lookup member on typevar bound 2026-06-23 00:24:37 +02:00
7c771c4070 feat(checker): add input function to preamble 2026-06-23 00:22:38 +02:00
a50a207385 fix(gen): don't generate stubs for builtin types 2026-06-22 15:40:31 +02:00
7e5ea5e414 chore: add example to demonstrate some features 2026-06-22 15:29:39 +02:00
0ba0266bae fix(checker): check general subtype case for AppliedType
this adds the case where we check whether AppliedType <: Type, and delegates to the body

this may not be a legitimate rule, or may need to be refined
2026-06-22 15:27:06 +02:00
216c80f08c fix(checker): produce judgement for expression in cast 2026-06-22 15:24:51 +02:00
f75d7722a1 fix(checker): look up members on constraint type 2026-06-22 15:24:18 +02:00
2f29c47274 fix(gen): assert type var bound 2026-06-22 15:23:53 +02:00
80af2b9048 fix(checker): handle is_subtype of TypeVar 2026-06-22 14:44:51 +02:00
577454ee7e fix(checker): make UnknownType a top type for subtyping 2026-06-22 14:15:18 +02:00
878693383e feat(cli): add watch option to stubs command 2026-06-22 14:14:05 +02:00
0b91de75a8 feat(checker): handle type vars in python functions 2026-06-22 14:13:25 +02:00
739871c101 Merge pull request 'Generic call unification' (#18) from feat/unification into main
Reviewed-on: #18
2026-06-21 11:41:48 +00:00
4395e9339b fix(checker): abort unification on conflict 2026-06-21 13:36:07 +02:00
29e601128d tests: add unification test 2026-06-21 13:19:17 +02:00
b591f5508f fix(checker): make map definition generic 2026-06-21 13:17:35 +02:00
41d0c84bbe feat(checker): add unifier
add unifier class to infer type parameters from local call context
2026-06-21 13:12:27 +02:00
cccf2f8f9f Merge pull request 'Stubs generator' (#17) from feat/stubs-gen into main
Reviewed-on: #17
2026-06-20 15:44:34 +00:00
3f48c2138f chore: add stubs command to README 2026-06-20 17:44:15 +02:00
e4ab27673d fix(gen): handle TypeVar variance in stubs generator 2026-06-20 17:34:40 +02:00
b02ecc6326 fix(gen): handle ConstraintType in stubs generator 2026-06-20 17:34:22 +02:00
9e83079910 fix(cli): add missing methods to highlighter 2026-06-20 17:23:18 +02:00
ec468dd982 feat(cli): add stubs command 2026-06-20 17:10:25 +02:00
3edc25d778 feat(gen): add base for stubs generator 2026-06-20 17:10:24 +02:00
451e54b009 fix(checker): handle calls to AliasType 2026-06-20 17:10:24 +02:00
0dc14f67aa fix(checker): allow substitutyping type vars in GenericType and TopType 2026-06-20 17:10:23 +02:00
ff79f25628 fix(checker): store member kind in registry 2026-06-20 17:10:23 +02:00
12782dda1e Merge pull request 'Variance inference and subtyping' (#16) from feat/variance into main
Reviewed-on: #16
2026-06-20 14:55:01 +00:00
48a20b4aa0 tests: add tests for variance inference and subtyping 2026-06-20 16:48:19 +02:00
9467187313 feat(checker): use variance in subtype check 2026-06-20 16:30:30 +02:00
cd8f14153d feat(checker): infer type variables variance 2026-06-20 13:39:32 +02:00
6eea0c02e0 Merge pull request 'Constraint types' (#15) from feat/constraint-type into main
Reviewed-on: #15
2026-06-19 20:21:04 +00:00
3205e7b961 fix(checker): change back warning to errors 2026-06-19 22:13:10 +02:00
0aba134290 tests: add predicates and constraints test 2026-06-19 22:13:10 +02:00
1f0bcab2ca fix(checker) minor tweaks 2026-06-19 22:13:09 +02:00
db8d88ef35 feat(parser): parse strings in Midas files 2026-06-19 22:13:09 +02:00
7695d50537 fix(parser): correctly parse keyword arguments 2026-06-19 22:13:08 +02:00
8461d05fa6 fix(checker): handle all operations and calls in predicates 2026-06-19 22:13:08 +02:00
43d2118db7 fix(checker): lookup predicate variables in preamble 2026-06-19 22:13:07 +02:00
6a87b5396f feat(cli): print predicate with dump-registry 2026-06-19 22:13:07 +02:00
e6a581ba6e fix(checker): typo in docstring 2026-06-19 22:13:07 +02:00
2a7aac69ed fix(checker): change some diagnostics to warnings
temporarily change type errors in predicates to warnings until operations are fully type checked
2026-06-19 22:13:06 +02:00
eb5bf19c61 feat(gen): generate type hints for functions 2026-06-19 22:13:06 +02:00
657406ea01 feat(gen): handle predicate aliases
handle cases where a predicate is defined as an alias, i.e. without any parameters
2026-06-19 22:13:05 +02:00
2974386110 fix(parser): fix call expr location span 2026-06-19 22:13:05 +02:00
92ca6b6732 feat(types): detect constraint base subtyping 2026-06-19 22:13:04 +02:00
6aacdb98b7 feat(checker): type check predicate body 2026-06-19 22:13:04 +02:00
1b100b6ceb fix(gen): remove id from named predicate function 2026-06-19 22:13:03 +02:00
6b4c7d27bc fix(tests): update generator tester 2026-06-19 22:13:03 +02:00
2523d638f7 feat(gen): generate predicate functions 2026-06-19 22:13:02 +02:00
5fc7461e29 feat(gen): generate basic constraint assertion 2026-06-19 22:13:02 +02:00
c5154bde81 feat(types): add ConstraintType 2026-06-19 22:13:02 +02:00
d07e8ac0ca refactor: ensure exhaustiveness in some match/case 2026-06-19 22:13:01 +02:00
3380995082 tests: update with new predicate AST representation 2026-06-19 22:13:01 +02:00
7efc44c496 fix(tests): correctly serialize param name 2026-06-19 22:13:00 +02:00
ca94443699 feat(midas): generalize param spec of predicate and parse 2026-06-19 22:12:59 +02:00
c513a85cf2 feat(midas): add CallExpr 2026-06-19 22:12:59 +02:00
2a106c5d07 refactor: add param spec for FunctionType 2026-06-19 22:12:58 +02:00
9672dfd588 Merge pull request 'Update README' (#14) from fix/update-readme into main
Reviewed-on: #14
2026-06-19 13:25:09 +00:00
7639ccc94d chore: update README with new commands 2026-06-19 15:23:49 +02:00
51 changed files with 4441 additions and 322 deletions

View File

@@ -1,4 +1,4 @@
# Midas
<h1>Midas</h1>
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
@@ -6,6 +6,25 @@
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
<details>
<summary><strong>Table of Contents</strong></summary>
- [Requirements](#requirements)
- [Installation](#installation)
- [Commands](#commands)
- [Type Checking](#type-checking)
- [Compiling](#compiling)
- [Formatting](#formatting)
- [Highlighting](#highlighting)
- [Dumping the AST](#dumping-the-ast)
- [Dumping the Registry](#dumping-the-registry)
- [Generating Stubs](#generating-stubs)
- [Showing Type Judgements](#showing-type-judgements)
- [Validating Definitions](#validating-definitions)
- [Tests](#tests)
</details>
## Requirements
- Python 3.11+
@@ -32,10 +51,26 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede
## Commands
### Compiling
<!--
check
compile
format
highlight
parse
dump_registry
types
validate
-->
> [!NOTE]
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
### Type Checking
```shell
midas check -t types.midas source.py
```
This command parses the given files and run the type checkers against the Midas definitions and Python program. Diagnostics are then printed showing warnings and errors.
### Compiling
```shell
midas compile -t types.midas source.py
@@ -43,14 +78,22 @@ midas compile -t types.midas source.py
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
### Formatting
```shell
midas format types.midas
midas format types.midas -o formatted.midas
```
This command parses the given Midas file and outputs a pretty printed file from the AST.
### Highlighting
```shell
midas utils highlight source.py
# or
midas utils highlight types.midas
midas highlight source.py
midas highlight source.py -o highlighted.html
midas highlight types.midas
midas highlight types.midas -o highlighted.html
```
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
@@ -60,14 +103,43 @@ The optional `-o FILE` option can be used to specify an output path. By default,
### Dumping the AST
```shell
midas utils dump-ast source.py
# or
midas utils dump-ast types.midas
midas parse source.py
midas parse types.midas
```
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `--raw` flags lets you toggle the custom AST parsing. With `--raw`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
### Dumping the Registry
```shell
midas dump-registry -t types.midas
```
This command processes the given Midas definitions and dumps the contents of the types registry.
### Generating Stubs
```shell
midas stubs types.midas -o stubs.pyi
```
This command generate Python stubs from a Midas definition file
### Showing Type Judgements
```shell
midas types -t types.midas source.py
```
This command type checks the given Python source file and logs all typing judgements made by the type checker.
### Validating Definitions
```shell
midas validate types.midas
```
This command lets you validate a Midas definition file by running the parser and type checker, verifying syntax and references.
## Tests
@@ -77,6 +149,7 @@ Several snapshot tests are available to assert the good behaviour of the parsers
uv run -m tests.midas run -a
uv run -m tests.python run -a
uv run -m tests.checker run -a
uv run -m tests.generator run -a
```
**Available subcommands:**

View File

@@ -0,0 +1,15 @@
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
predicate is_ratio = in_range(0, 1)
type Currency = float
type Price[T <: Currency] = T where _ >= 0
extend Price[T <: Currency] {
def __add__: fn(Price[T], /) -> Price[T]
}
type EUR = Currency
type USD = Currency
type CHF = Currency
type Discount = float where is_ratio(_)

View File

@@ -0,0 +1,35 @@
from typing import TypeVar
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
from midas.typing import cast, unsafe_cast
T = TypeVar("T", bound=Currency)
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
return cast(Price[T], (1.0 - discount) * amount)
a1 = cast(Price[EUR], 3.2)
a2 = cast(Price[USD], 10.4)
r1 = cast(Discount, 0.2)
print(apply_discount(a1, r1))
print(apply_discount(a2, r1))
a3 = a1 + a1
a4 = a1 + a2 # cannot add euros and dollars
a3 = a2 # cannot change variable type
dyn_price = float(input("Price (CHF): "))
dyn_discount = float(input("Discount (0.0-1.0): "))
discounted = apply_discount(
cast(Price[CHF], dyn_price),
cast(Discount, dyn_discount),
)
print(f"Discounted: CHF {discounted}")
large_data = [i * 10 for i in range(100)]
prices = unsafe_cast(list[Price[EUR]], large_data)

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from typing import Generic, TypeVar
class Currency(float): ...
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
class Price(Currency, Generic[_T0]):
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
class EUR(Currency): ...
class USD(Currency): ...
class CHF(Currency): ...
class Discount(float): ...

View File

@@ -26,6 +26,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
###<
@@ -50,9 +58,8 @@ class ExtendStmt:
class PredicateStmt:
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
###<
@@ -78,6 +85,12 @@ class UnaryExpr:
right: Expr
class CallExpr:
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
class GetExpr:
expr: Expr
name: Token
@@ -128,9 +141,7 @@ class ExtensionType:
class FunctionType:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -145,6 +145,7 @@ class LogicalExpr:
class CastExpr:
type: MidasType
expr: Expr
unsafe: bool
class TernaryExpr:

View File

@@ -27,6 +27,14 @@ class MemberKind(Enum):
METHOD = auto()
@dataclass(frozen=True, kw_only=True)
class ParamSpec:
l_paren: Token
pos: list[FunctionType.Argument]
mixed: list[FunctionType.Argument]
kw: list[FunctionType.Argument]
##############
# Statements #
##############
@@ -86,9 +94,8 @@ class ExtendStmt(Stmt):
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
subject: Token
type: Type
condition: Expr
params: list[ParamSpec]
body: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_predicate_stmt(self)
@@ -116,6 +123,9 @@ class Expr(ABC):
@abstractmethod
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
@abstractmethod
def visit_call_expr(self, expr: CallExpr) -> T: ...
@abstractmethod
def visit_get_expr(self, expr: GetExpr) -> T: ...
@@ -161,6 +171,16 @@ class UnaryExpr(Expr):
return visitor.visit_unary_expr(self)
@dataclass(frozen=True)
class CallExpr(Expr):
callee: Expr
arguments: list[Expr]
keywords: dict[str, Expr]
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_call_expr(self)
@dataclass(frozen=True)
class GetExpr(Expr):
expr: Expr
@@ -279,9 +299,7 @@ class ExtensionType(Type):
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
params: ParamSpec
returns: Type
@dataclass(frozen=True, kw_only=True)

View File

@@ -150,13 +150,17 @@ class MidasAstPrinter(
self._write_line("PredicateStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line(f'subject: "{stmt.subject.lexeme}"')
self._write_line("type")
self._write_line("params")
with self._child_level():
for i, spec in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._visit_param_spec(spec)
self._write_line("body", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_line("condition", last=True)
with self._child_level(single=True):
stmt.condition.accept(self)
stmt.body.accept(self)
# Expressions
@@ -195,6 +199,29 @@ class MidasAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self._write_line("CallExpr")
with self._child_level():
self._write_line("callee")
with self._child_level(single=True):
expr.callee.accept(self)
self._write_line("arguments")
with self._child_level():
for i, arg in enumerate(expr.arguments):
self._idx = i
if i == len(expr.arguments) - 1:
self._mark_last()
arg.accept(self)
self._write_line("keywords", last=True)
with self._child_level():
for i, (name, arg) in enumerate(expr.keywords.items()):
self._idx = i
if i == len(expr.keywords) - 1:
self._mark_last()
self._write_line(name)
with self._child_level(single=True):
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr):
self._write_line("GetExpr")
with self._child_level():
@@ -276,34 +303,41 @@ class MidasAstPrinter(
def visit_function_type(self, type: m.FunctionType) -> None:
self._write_line("FunctionType")
with self._child_level():
self._write_line("pos_args")
with self._child_level():
for i, arg in enumerate(type.pos_args):
self._idx = i
if i == len(type.pos_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("args")
with self._child_level():
for i, arg in enumerate(type.args):
self._idx = i
if i == len(type.args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw_args")
with self._child_level():
for i, arg in enumerate(type.kw_args):
self._idx = i
if i == len(type.kw_args) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("params")
with self._child_level(single=True):
self._visit_param_spec(type.params)
self._write_line("returns", last=True)
with self._child_level(single=True):
type.returns.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
self._write_line("ParamSpec")
with self._child_level():
self._write_line("pos")
with self._child_level():
for i, arg in enumerate(spec.pos):
self._idx = i
if i == len(spec.pos) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("mixed")
with self._child_level():
for i, arg in enumerate(spec.mixed):
self._idx = i
if i == len(spec.mixed) - 1:
self._mark_last()
self._print_function_arg(arg)
self._write_line("kw", last=True)
with self._child_level():
for i, arg in enumerate(spec.kw):
self._idx = i
if i == len(spec.kw) - 1:
self._mark_last()
self._print_function_arg(arg)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
self._write_line("Argument")
with self._child_level():
@@ -367,10 +401,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme
type: str = stmt.type.accept(self)
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
body: str = stmt.body.accept(self)
return self.indented(f"predicate {name}{sig} = {body}")
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
@@ -389,6 +422,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
right: str = expr.right.accept(self)
return f"{operator}{right}"
def visit_call_expr(self, expr: m.CallExpr) -> str:
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
]
return f"{expr.callee.accept(self)}({', '.join(args)})"
def visit_get_expr(self, expr: m.GetExpr):
expr_: str = expr.expr.accept(self)
name: str = expr.name.lexeme
@@ -436,9 +475,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
def visit_function_type(self, type: m.FunctionType) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
spec: str = self._visit_param_spec(type.params)
return f"fn {spec} -> {type.returns.accept(self)}"
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
args: list[str] = pos_args
if len(pos_args) != 0:
@@ -447,8 +490,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
if len(kw_args) != 0:
args.append("*")
args += kw_args
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
return f"({', '.join(args)})"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
res: str = ""
@@ -715,9 +757,10 @@ class PythonAstPrinter(
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
self._write_line("expr", last=True)
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")

View File

@@ -350,6 +350,7 @@ class LogicalExpr(Expr):
class CastExpr(Expr):
type: MidasType
expr: Expr
unsafe: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)

View File

@@ -15,6 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str"},
"float": {"int"},
"int": {"bool"},
}

View File

@@ -9,6 +9,7 @@ class DiagnosticType(StrEnum):
ERROR = "Error"
WARNING = "Warning"
INFO = "Info"
DEBUG = "Debug"
@dataclass(frozen=True)

172
midas/checker/evaluator.py Normal file
View File

@@ -0,0 +1,172 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.midas as m
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Function, Predicate
from midas.lexer.token import TokenType
@dataclass(frozen=True, kw_only=True)
class PartialPredicate(Predicate):
scope: dict[str, Any]
class Evaluator(m.Expr.Visitor[Any]):
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
self.types: TypesRegistry = types
self.reporter: Optional[FileReporter] = reporter
self.preamble: Preamble = Preamble(self.types)
self.scopes: list[dict[str, Any]] = [{}]
def evaluate(self, expr: m.Expr) -> Any:
value: Any = expr.accept(self)
if self.reporter is not None:
self.reporter.debug(expr.location, f"Value: {value}")
return value
def get_value(self, name: str) -> Any:
scope: dict[str, Any] = self.scopes[-1]
return scope[name]
def set_value(self, name: str, value: Any, force_declare: bool = False):
if not force_declare:
for scope in reversed(self.scopes):
if name in scope:
scope[name] = value
return
self.scopes[-1][name] = value
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
def left():
return self.evaluate(expr.left)
def right():
return self.evaluate(expr.right)
match expr.operator.type:
case TokenType.AND:
return left() and right()
case _:
raise NotImplementedError
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return left - right
case TokenType.STAR:
return left * right
case TokenType.SLASH:
return left / right
case TokenType.GREATER:
return left > right
case TokenType.GREATER_EQUAL:
return left >= right
case TokenType.LESS:
return left < right
case TokenType.LESS_EQUAL:
return left <= right
case TokenType.EQUAL_EQUAL:
return left == right
case TokenType.BANG_EQUAL:
return left != right
case _:
raise NotImplementedError
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return -right
case _:
raise NotImplementedError
def visit_call_expr(self, expr: m.CallExpr) -> Any:
callee: Any = self.evaluate(expr.callee)
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
kwargs: dict[str, Any] = {
name: self.evaluate(arg) for name, arg in expr.keywords.items()
}
match callee:
case Predicate():
return self._evaluate_predicate(callee, args, kwargs)
case _ if callable(callee):
return callee(*args, **kwargs)
case _:
return NotImplementedError
def visit_get_expr(self, expr: m.GetExpr) -> Any:
obj: Any = self.evaluate(expr.expr)
return getattr(obj, expr.name.lexeme)
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
name: str = expr.name.lexeme
for scope in reversed(self.scopes):
if name in scope:
return scope[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
if predicate.alias:
return self.evaluate(predicate.body)
return predicate
glob: Optional[Callable] = self.preamble.get_py_func(name)
if glob is not None:
return glob
raise NameError(f"Unknown variable '{name}'")
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
return self.evaluate(expr.expr)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
return expr.value
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
return self.get_value("_")
def _evaluate_predicate(
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
) -> Any:
res: Any = None
if isinstance(predicate, PartialPredicate):
self.scopes.append(predicate.scope)
else:
self.scopes.append({})
match predicate.type:
case Function(returns=Function() as inner):
self._map_args(predicate.type, args, kwargs)
res = PartialPredicate(
type=inner,
body=predicate.body,
alias=False,
scope=self.scopes[-1],
)
case Function():
self._map_args(predicate.type, args, kwargs)
res = self.evaluate(predicate.body)
case _:
raise NotImplementedError
self.scopes.pop()
return res
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
positional: list[Function.Argument] = function.pos_args + function.args
keywords: dict[str, Function.Argument] = {
arg.name: arg for arg in function.args + function.kw_args
}
for i, arg in enumerate(args):
param: Function.Argument = positional[i]
self.set_value(param.name, arg)
for name, arg in kwargs.items():
param: Function.Argument = keywords[name]
self.set_value(param.name, arg)

View File

@@ -1,27 +1,65 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
from midas.ast.location import Location
from midas.checker.builtins import define_builtins
from midas.checker.environment import Environment
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
Type,
TypeVar,
UnknownType,
unfold_type,
)
from midas.checker.variance import VarianceInferrer
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
@dataclass(frozen=True, kw_only=True)
class TypedParamSpec:
pos: list[Function.Argument]
mixed: list[Function.Argument]
kw: list[Function.Argument]
TypedExpr = tuple[m.Expr, Type]
class ReturnException(Exception):
pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: m.Expr
type: Type
argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
function: Function
mapped: list[MappedArgument]
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
"""A resolver which evaluates Midas type definitions and build a registry"""
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
@@ -31,12 +69,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.types: TypesRegistry = types
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
self._current_name: Optional[str] = None
define_builtins(self.types)
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
self.process(builtins_path.read_text(), str(builtins_path))
self._bool: Type = self.get_type("bool")
self._preamble: Environment = Preamble(self.types)
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
lexer: MidasLexer = MidasLexer(source)
@@ -47,6 +91,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self.reporter.error(error.token.get_location(), error.message)
self.resolve(stmts)
def type_of(self, expr: m.Expr) -> Type:
type: Type = expr.accept(self)
return type
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -63,6 +111,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return self._local_variables[name]
return self.types.get_type(name)
def get_variable(self, name: str) -> Type:
if name in self._predicate_params:
return self._predicate_params[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
return predicate.type
global_: Optional[Type] = self._preamble.get(name)
if global_ is not None:
return global_
raise NameError(f"Unknown variable '{name}'")
def resolve(self, stmts: list[m.Stmt]):
"""Process a sequence of statements
@@ -72,6 +133,16 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
for stmt in stmts:
stmt.accept(self)
for name, type in self.types._types.items():
if isinstance(type, GenericType):
inferrer = VarianceInferrer(self.types)
self.types._types[name] = inferrer.infer(type)
def assert_bool(self, expr: m.Expr):
type: Type = self.type_of(expr)
if not self.types.is_subtype(type, self._bool):
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
name: str = stmt.name.lexeme
self._current_name = name
@@ -106,31 +177,163 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
for spec in stmt.params:
for param in spec.mixed:
assert param.name is not None
self._predicate_params[param.name.lexeme] = param.type.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
type: Type = self.type_of(stmt.body)
params: list[TypedParamSpec] = [
self._visit_param_spec(spec) for spec in stmt.params
]
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
if not self._is_valid_predicate(type):
self.reporter.error(
stmt.body.location,
f"Predicate function body must evaluate to a boolean, got {type}",
)
if len(params) != 0:
type = self._bool
for spec in reversed(params):
type = Function(
pos_args=spec.pos,
args=spec.mixed,
kw_args=spec.kw,
returns=type,
)
self._predicate_params = {}
self.types.define_predicate(
stmt.name.lexeme,
Predicate(
type=type,
body=stmt.body,
alias=len(params) == 0,
),
)
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
def _is_valid_predicate(self, body: Type) -> bool:
match body:
case Function(returns=returns):
return self._is_valid_predicate(returns)
case _ if self.types.is_subtype(body, self._bool):
return True
case _:
return False
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.reporter.warning(expr.location, "GetExpr not yet supported")
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
self.assert_bool(expr.left)
self.assert_bool(expr.right)
return self._bool
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
self.reporter.warning(expr.location, "VariableExpr not yet supported")
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator.lexeme}"
)
return UnknownType()
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def _visit_binary_expr(
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method)
if operation is None:
self.reporter.error(
location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator.lexeme}"
)
return UnknownType()
operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None:
self.reporter.error(
expr.location,
f"Undefined operation {method} for {operand}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result or UnknownType()
def visit_call_expr(self, expr: m.CallExpr) -> Type:
callee: Type = expr.callee.accept(self)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return (
self._get_call_result(
expr.location,
callee,
positional,
keywords,
)
or UnknownType()
)
def visit_get_expr(self, expr: m.GetExpr) -> Type:
object: Type = expr.expr.accept(self)
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
)
return UnknownType()
return member
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
return self.get_variable(expr.name.lexeme)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
return expr.expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must be before int
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
return self.get_variable("_")
def visit_named_type(self, type: m.NamedType) -> Type:
name: str = type.name.lexeme
@@ -153,10 +356,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
return UnknownType()
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
type_: Type = type.type.accept(self)
type.constraint.accept(self)
# TODO
return UnknownType()
return ConstraintType(
type=type.type.accept(self),
constraint=type.constraint,
)
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
return ComplexType(
@@ -172,8 +375,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
)
def visit_function_type(self, type: m.FunctionType) -> Type:
n_pos_args: int = len(type.pos_args)
n_args: int = len(type.args)
params: TypedParamSpec = self._visit_param_spec(type.params)
return Function(
pos_args=params.pos,
args=params.mixed,
kw_args=params.kw,
returns=type.returns.accept(self),
)
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
n_pos: int = len(spec.pos)
n_mixed: int = len(spec.mixed)
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
return Function.Argument(
@@ -183,14 +395,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
required=arg.required,
)
return Function(
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
kw_args=[
process_arg(arg, i + n_pos_args + n_args)
for i, arg in enumerate(type.kw_args)
],
returns=type.returns.accept(self),
return TypedParamSpec(
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def _resolve_type_params(self, params: list[m.TypeParam]):
@@ -204,3 +412,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
self._local_variables[name] = var
vars.append(var)
return vars
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Type]:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Type: the return type of the call, or `None` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped, report_errors)
if not valid:
return None
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords, report_errors
)
if function is None:
return None
return function.returns
case AppliedType(body=body):
return self._get_call_result(
location, body, positional, keywords, report_errors
)
case UnknownType():
return UnknownType()
case _:
if report_errors:
self.reporter.error(location, f"{callee} is not callable")
return None
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.types.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
if report_errors:
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
if report_errors:
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
if report_errors:
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
"""Check whether the given argument mappings are subtype/supertype of one another
This function checks whether the argument mappings `mapped1` are subtypes
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
of the corresponding parameter in `mapped2`, `False` is returned.
This is used to check whether a given overload is
a more specific function/ a subtype of another.
Args:
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
Returns:
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
"""
by_expr: dict[m.Expr, Type] = {}
for arg in mapped1:
by_expr[arg.expr] = arg.argument.type
for arg in mapped2:
type2: Type = arg.argument.type
type1: Type = by_expr[arg.expr]
if not self.types.is_subtype(type1, type2):
return False
return True

View File

@@ -1,7 +1,9 @@
import ast
from typing import Type
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
from midas.lexer.token import TokenType
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.Add: "__add__",
ast.Sub: "__sub__",
ast.Mult: "__mul__",
@@ -17,9 +19,9 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = {
ast.FloorDiv: "__floordiv__",
}
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
ast.Eq: "__eq__",
# ast.NotEq: "__noteq__",
ast.NotEq: "__eq__",
ast.Lt: "__lt__",
ast.LtE: "__le__",
ast.Gt: "__gt__",
@@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
# ast.NotIn: "__notin__",
}
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
ast.Invert: "__invert__",
# ast.Not: "",
ast.UAdd: "__pos__",
ast.USub: "__neg__",
}
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
# TokenType.PLUS: "__add__",
TokenType.MINUS: "__sub__",
TokenType.STAR: "__mul__",
TokenType.SLASH: "__truediv__",
# TokenType.MODULO: "__mod__",
# TokenType.POW: "__pow__",
# ast.BitOr: "__or__",
# ast.BitXor: "__xor__",
# ast.BitAnd: "__and__",
# ast.FloorDiv: "__floordiv__",
TokenType.EQUAL_EQUAL: "__eq__",
TokenType.BANG_EQUAL: "__eq__",
TokenType.LESS: "__lt__",
TokenType.LESS_EQUAL: "__le__",
TokenType.GREATER: "__gt__",
TokenType.GREATER_EQUAL: "__ge__",
# ast.Is: "__is__",
# ast.IsNot: "__isnot__",
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
# ast.Invert: "__invert__",
# ast.Not: "",
# TokenType.PLUS: "__pos__",
TokenType.MINUS: "__neg__",
}

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Optional
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
@@ -16,16 +17,18 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {}
self._def_type_constructor("object")
self._def_type_constructor("float")
self._def_type_constructor("int")
self._def_type_constructor("bool")
self._def_type_constructor("str")
self._def_type_constructor("object", object)
self._def_type_constructor("float", float)
self._def_type_constructor("int", int)
self._def_type_constructor("bool", bool)
self._def_type_constructor("str", str)
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
py_function=list,
)
# TODO: use sink
@@ -33,6 +36,7 @@ class Preamble(Environment):
name="print",
pos=[Param("object", TopType())],
returns=UnitType(),
py_function=print,
)
map_in = TypeVar(name="T", bound=None)
@@ -52,17 +56,25 @@ class Preamble(Environment):
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out],
py_function=map,
)
self._def_function(
name="input",
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str):
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType(), required=False)],
returns=self._types.get_type(name),
py_function=py_function,
)
def _make_function(
@@ -109,6 +121,7 @@ class Preamble(Environment):
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None,
):
function: Type = self._make_function(
name=name,
@@ -119,3 +132,8 @@ class Preamble(Environment):
type_vars=type_vars,
)
self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]:
return self._python_funcs.get(name)

View File

@@ -1,12 +1,18 @@
import ast
import logging
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.evaluator import Evaluator
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
PY_UNARY_METHODS,
)
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
@@ -14,13 +20,19 @@ from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ConstraintType,
Function,
GenericType,
OverloadedFunction,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
unfold_type,
)
from midas.checker.unifier import Unifier
from midas.parser.python import PythonParser
from midas.utils import TypedAST
@@ -63,6 +75,7 @@ class PythonTyper(
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
self.evaluated_casts: list[p.CastExpr] = []
def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path)
@@ -76,10 +89,15 @@ class PythonTyper(
self.env = self.global_env
self.locals = resolver.locals
self.judgements = []
self.evaluated_casts = []
self.check(stmts)
return TypedAST(stmts=stmts, judgements=self.judgements)
return TypedAST(
stmts=stmts,
judgements=self.judgements,
evaluated_casts=self.evaluated_casts,
)
def judge(self, expr: p.Expr, type: Type):
"""Record a typing judgement
@@ -223,7 +241,8 @@ class PythonTyper(
)
pos += 1
for arg in pos_args + args + kw_args:
all_args: list[Function.Argument] = pos_args + args + kw_args
for arg in all_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
@@ -264,12 +283,25 @@ class PythonTyper(
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
function: Type = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
generic_params: list[TypeVar] = []
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
for type in all_types:
if isinstance(type, TypeVar):
if type not in generic_params:
generic_params.append(type)
if len(generic_params) != 0:
function = GenericType(
name=stmt.name,
params=generic_params,
body=function,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
@@ -377,7 +409,7 @@ class PythonTyper(
pass
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -388,7 +420,7 @@ class PythonTyper(
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -421,7 +453,7 @@ class PythonTyper(
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
@@ -447,6 +479,10 @@ class PythonTyper(
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
match expr.callee:
case p.VariableExpr(name="TypeVar"):
return self.define_typevar(expr) or UnknownType()
callee: Type = self.type_of(expr.callee)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
@@ -512,7 +548,16 @@ class PythonTyper(
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return self.resolve_type_expr(expr.type)
subject_type: Type = self.type_of(expr.expr)
target_type: Type = self.resolve_type_expr(expr.type)
is_lit, lit_value = self._get_literal(expr.expr)
if is_lit:
evaluated: bool = self._evaluate_cast_statically(
expr, subject_type, target_type, lit_value
)
if evaluated:
self.evaluated_casts.append(expr)
return target_type
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = self.type_of(expr.test)
@@ -653,7 +698,7 @@ class PythonTyper(
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accomodate
The function doesn't take the raw expression as a parameter to accommodate
for desugared calls such as for operators.
Args:
@@ -700,6 +745,28 @@ class PythonTyper(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
if report_errors:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
self.reporter.error(
location,
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
)
return None
return self._get_call_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
if report_errors:
self.reporter.error(
@@ -752,7 +819,7 @@ class PythonTyper(
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambigously, or `None`.
determined unambiguously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
@@ -1005,3 +1072,147 @@ class PythonTyper(
report_errors=False,
)
return result
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
def is_kw_true(name: str) -> bool:
match call.keywords.get(name):
case p.LiteralExpr(value=True):
return True
case _:
return False
match call:
case p.CallExpr(
arguments=[p.LiteralExpr(value=str() as name)],
):
bound: Optional[Type] = None
variance: Variance = Variance.INVARIANT
if "bound" in call.keywords:
bound_type: p.MidasType = self._parse_type_from_expr(
call.keywords["bound"]
)
bound = self.resolve_type_expr(bound_type)
if is_kw_true("covariant"):
variance = Variance.COVARIANT
if is_kw_true("contravariant"):
if variance == Variance.COVARIANT:
self.reporter.warning(
call.keywords["contravariant"].location,
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
)
variance = Variance.INVARIANT
else:
variance = Variance.CONTRAVARIANT
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
self.types.define_type(name, var)
return var
case _:
self.reporter.warning(
call.location, "Invalid usage of 'TypeVar', skipping"
)
return None
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
location: Location = expr.location
parser = PythonParser()
match expr:
case p.LiteralExpr(value=str() as value):
node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body)
case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, param=None)
case _:
raise NotImplementedError
def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]:
match expr:
case p.LiteralExpr(value=value):
return True, value
case p.ListExpr(items=items):
values: list[Any] = []
for item in items:
is_lit, value = self._get_literal(item)
if not is_lit:
return False, None
values.append(value)
return True, values
case p.DictExpr(keys=keys, values=values):
pairs: list[tuple[Any, Any]] = []
for key, value in zip(keys, values):
key_val = None
if key is not None:
is_lit, key_val = self._get_literal(key)
if not is_lit:
return False, None
is_lit, value_val = self._get_literal(value)
if not is_lit:
return False, None
if key is None:
# TODO: check that value is always a dict
assert isinstance(value_val, dict)
pairs.extend(value_val.items())
else:
pairs.append((key_val, value_val))
return True, dict(pairs)
case _:
return False, None
def _evaluate_cast_statically(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool:
match target_type:
case AliasType(type=base):
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value
)
case AppliedType(body=body):
return self._evaluate_cast_statically(
expr, subject_type, body, lit_value
)
case ConstraintType(type=base, constraint=constraint):
evaluated: bool = True
if not self._evaluate_cast_statically(
expr, subject_type, base, lit_value
):
evaluated = False
evaluator = Evaluator(self.types)
evaluator.set_value("_", lit_value)
res = evaluator.evaluate(constraint)
if not res:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
self.reporter.error(
expr.location,
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
)
evaluated = False
return evaluated
case BaseType():
# TODO: do we want to allow cast(float, int)? would require runtime conversion
if not self.types.is_subtype(
subject_type, target_type
) or not self.types.is_subtype(target_type, subject_type):
self.reporter.error(
expr.location,
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
)
return False
return True
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"
)
return False

View File

@@ -9,14 +9,17 @@ from midas.checker.types import (
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
Predicate,
TopType,
Type,
TypeVar,
UnknownType,
Variance,
substitute_typevars,
)
@@ -32,6 +35,7 @@ class TypesRegistry:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Member]] = {}
self._predicates: dict[str, Predicate] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
@@ -101,6 +105,11 @@ class TypesRegistry:
else:
members[member_name] = Member(kind=kind, type=member_type)
def define_predicate(self, name: str, predicate: Predicate):
if name in self._predicates:
raise ValueError(f"Predicate {name} already defined")
self._predicates[name] = predicate
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
@@ -121,6 +130,19 @@ class TypesRegistry:
case (_, TopType()):
return True
case (_, UnknownType()):
return True
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (_, TypeVar(bound=bound)):
if bound is None:
return True
return self.is_subtype(type1, bound)
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -138,10 +160,30 @@ class TypesRegistry:
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
case (TypeVar(bound=bound), _):
if bound is None:
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
case (
AppliedType(name=name1, args=args1),
AppliedType(name=name2, args=args2),
) if (
name1 == name2
):
generic: Type = self.get_type(name1)
assert isinstance(generic, GenericType)
for param, arg1, arg2 in zip(generic.params, args1, args2):
variance: Variance = param.variance
if variance in {Variance.INVARIANT, Variance.COVARIANT}:
if not self.is_subtype(arg1, arg2):
return False
return self.is_subtype(bound, type2)
if variance in {Variance.INVARIANT, Variance.CONTRAVARIANT}:
if not self.is_subtype(arg2, arg1):
return False
return True
# TODO: verify legitimacy
case (AppliedType(body=body), _):
return self.is_subtype(body, type2)
return False
@@ -359,9 +401,18 @@ class TypesRegistry:
)
return self.lookup_member(base, member_name)
case ConstraintType(type=base):
return self.lookup_member(base, member_name)
case TypeVar(bound=bound) if bound is not None:
return self.lookup_member(bound, member_name)
case UnknownType():
return UnknownType()
case _:
self.logger.debug(f"Can't get member on {type}")
return None
def lookup_predicate(self, name: str) -> Optional[Predicate]:
return self._predicates.get(name)

View File

@@ -61,3 +61,10 @@ class FileReporter:
location=location,
message=message,
)
def debug(self, location: Location, message: str):
self.report(
type=DiagnosticType.DEBUG,
location=location,
message=message,
)

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from enum import StrEnum
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@dataclass(frozen=True, kw_only=True)
@@ -99,15 +103,27 @@ class ExtensionType:
return f"{self.base} & {self.extension}"
class Variance(StrEnum):
INVARIANT = "INVARIANT"
COVARIANT = "COVARIANT"
CONTRAVARIANT = "CONTRAVARIANT"
@dataclass(frozen=True, kw_only=True)
class TypeVar:
name: str
bound: Optional[Type]
variance: Variance = Variance.INVARIANT
def __str__(self) -> str:
variance: str = {
Variance.COVARIANT: "+",
Variance.CONTRAVARIANT: "-",
}.get(self.variance, "")
res: str = f"{variance}{self.name}"
if self.bound is not None:
return f"{self.name} <: {self.bound}"
return self.name
res = f"{res} <: {self.bound}"
return res
@dataclass(frozen=True, kw_only=True)
@@ -130,6 +146,16 @@ class AppliedType:
return f"{self.name}[{', '.join(map(str, self.args))}]"
@dataclass(frozen=True, kw_only=True)
class ConstraintType:
type: Type
constraint: m.Expr
def __str__(self) -> str:
printer = MidasPrinter()
return f"{self.type} where {printer.print(self.constraint)}"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -198,6 +224,12 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case ConstraintType():
return ConstraintType(
type=substitute_typevars(type.type, substitutions),
constraint=type.constraint,
)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
@@ -221,9 +253,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
case UnknownType() | UnitType():
return type
case _:
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
def unfold_type(type: Type) -> Type:
match type:
@@ -233,6 +269,65 @@ def unfold_type(type: Type) -> Type:
return type
def to_annotation(type: Type) -> str:
def _args_annotation(func: Function) -> str:
if len(func.kw_args) != 0:
return "..."
args: str = ", ".join(
to_annotation(arg.type) for arg in func.pos_args + func.args
)
return f"[{args}]"
match type:
case TopType():
return "Any"
case BaseType(name=name):
return name
case AliasType(name=name):
return name
case UnknownType():
return "Any"
case UnitType():
return "None"
case Function(returns=returns):
params_annot: str = _args_annotation(type)
return f"Callable[{params_annot}, {to_annotation(returns)}]"
case OverloadedFunction():
return "Callable"
case ComplexType() | ExtensionType():
raise NotImplementedError
case TypeVar(name=name):
return name
case GenericType(name=name, params=params):
return f"{name}[{', '.join(map(to_annotation, params))}]"
case AppliedType(name=name, args=args):
return f"{name}[{', '.join(map(to_annotation, args))}]"
case ConstraintType():
return str(type)
case _:
assert_never(type)
@dataclass(frozen=True, kw_only=True)
class Predicate:
type: Type
body: m.Expr
alias: bool
Type = (
TopType
| BaseType
@@ -246,4 +341,5 @@ Type = (
| TypeVar
| GenericType
| AppliedType
| ConstraintType
)

169
midas/checker/unifier.py Normal file
View File

@@ -0,0 +1,169 @@
import logging
from typing import Optional
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
Function,
GenericType,
TopType,
Type,
TypeVar,
)
class UnificationError(Exception): ...
class Unifier:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.logger: logging.Logger = logging.getLogger("Unifier")
def unify_call(
self,
type: GenericType,
positional: list[Type],
keywords: dict[str, Type],
) -> Optional[Type]:
concrete_func: Function = Function(
pos_args=[
Function.Argument(
pos=i,
name=str(i),
type=arg,
required=True,
)
for i, arg in enumerate(positional)
],
args=[],
kw_args=[
Function.Argument(
pos=len(positional) + i,
name=name,
type=arg,
required=True,
)
for i, (name, arg) in enumerate(keywords.items())
],
returns=TopType(), # TODO: use expected type
)
return self.unify_generic(type, concrete_func, match_return=False)
def unify_generic(
self,
template: GenericType,
concrete: Type,
match_return: bool = True,
) -> Optional[Type]:
substitutions: dict[str, Type]
try:
substitutions = self.match(template.body, concrete, match_return)
except UnificationError:
return None
args: list[Type] = []
for param in template.params:
if param.name not in substitutions:
return None
args.append(substitutions[param.name])
applied: Type = self.types.apply_generic(template, args)
return applied
def match(
self,
template: Type,
concrete: Type,
match_return: bool = True,
) -> dict[str, Type]:
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
# substitutions, check that the constraint is respected
match (template, concrete):
case (TypeVar(name=name), _):
return {name: concrete}
case (
AppliedType(name=template_name, args=template_args),
AppliedType(name=concrete_name, args=concrete_args),
) if template_name == concrete_name and len(template_args) == len(
concrete_args
):
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in zip(template_args, concrete_args):
new_substistutions: dict[str, Type] = self.match(
template_arg, concrete_arg
)
substitutions = self.merge(substitutions, new_substistutions)
return substitutions
case (Function(), Function()):
mapped: list[tuple[Function.Argument, Function.Argument]] = (
self.map_params(template, concrete)
)
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in mapped:
arg_subs: dict[str, Type] = self.match(
template_arg.type, concrete_arg.type
)
substitutions = self.merge(substitutions, arg_subs)
if match_return:
return_subs: dict[str, Type] = self.match(
template.returns, concrete.returns
)
substitutions = self.merge(substitutions, return_subs)
return substitutions
case _:
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
return {}
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
merged: dict[str, Type] = subs1.copy()
for k, v in subs2.items():
if k in merged and merged[k] != v:
self.logger.debug(
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
)
raise UnificationError
merged[k] = v
return merged
def map_params(
self, func1: Function, func2: Function
) -> list[tuple[Function.Argument, Function.Argument]]:
pos1: list[Function.Argument] = func1.pos_args
mixed1: list[Function.Argument] = func1.args
kw1: list[Function.Argument] = func1.kw_args
pos2: list[Function.Argument] = func2.pos_args
mixed2: list[Function.Argument] = func2.args
kw2: list[Function.Argument] = func2.kw_args
mapped: list[tuple[Function.Argument, Function.Argument]] = []
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
for arg1 in pos1:
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
for arg1 in mixed1:
# Match both positionally and by name, conflicts are caught
# when merging substitutions
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
for arg1 in kw1:
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
return mapped

129
midas/checker/variance.py Normal file
View File

@@ -0,0 +1,129 @@
from typing import Literal, Optional, cast
from midas.checker.registry import Member, TypesRegistry
from midas.checker.types import (
AppliedType,
ConstraintType,
Function,
GenericType,
OverloadedFunction,
Type,
TypeVar,
Variance,
)
Polarity = Literal[-1, 0, 1]
class Tracker:
def __init__(self, vars: list[TypeVar]) -> None:
self.vars: list[TypeVar] = vars
self.refs: dict[str, set[Polarity]] = {var.name: set() for var in self.vars}
def record(self, var: TypeVar, polarity: Polarity):
self.refs[var.name].add(polarity)
def get_updated_vars(self) -> list[TypeVar]:
return [
TypeVar(
name=var.name, bound=var.bound, variance=self.get_variance(var.name)
)
for var in self.vars
]
def get_variance(self, name: str) -> Variance:
refs: set[Polarity] = self.refs[name]
if refs == {-1}:
return Variance.CONTRAVARIANT
if refs == {1}:
return Variance.COVARIANT
return Variance.INVARIANT
def __contains__(self, item: TypeVar | str):
if isinstance(item, TypeVar):
return item.name in self
return item in self.refs
class VarianceInferrer:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.tracker: Tracker = Tracker([])
def infer(self, type: GenericType) -> GenericType:
self.tracker = Tracker(type.params)
self.walk(type.body, 1, type.name)
members: dict[str, Member] = self.types._members.get(type.name, {})
for name, member in members.items():
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
return GenericType(
name=type.name,
params=self.tracker.get_updated_vars(),
body=type.body,
)
def walk(
self,
type: Type,
polarity: Polarity,
base_name: str,
path: Optional[list[str]] = None,
):
if path is None:
path = []
match type:
# Arguments are negative positions -> flip polarity
# Return is positive position -> keep polarity
case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args):
all_args: list[Function.Argument] = pos_args + mixed_args + kw_args
for arg in all_args:
self.walk(
arg.type,
-polarity,
base_name,
path + [f"arg:'{arg.name}'"],
)
self.walk(type.returns, polarity, base_name, path + ["return"])
# Walk all overloads
case OverloadedFunction(overloads=overloads):
for overload in overloads:
self.walk(overload, polarity, base_name, path)
# If same name as root generic -> skip
# Get inferred variance of parameters and multiply with current
# polarity to recurse through arguments
case AppliedType(name=name, args=args):
# TODO: handle mutually recursive types
if name == base_name:
return
generic: Type = self.types.get_type(name)
assert isinstance(generic, GenericType)
params: list[TypeVar] = generic.params
polarities: dict[Variance, Polarity] = {
Variance.INVARIANT: 0,
Variance.COVARIANT: 1,
Variance.CONTRAVARIANT: -1,
}
for arg, param in zip(args, params):
param_polarity: Polarity = polarities[param.variance]
self.walk(
arg,
cast(Polarity, polarity * param_polarity),
base_name,
path + [f"applied:'{name}'"],
)
# Walk base type
case ConstraintType(type=base):
self.walk(base, polarity, base_name, path + ["constraint"])
# Reached end
# If tracked, record polarity
case TypeVar():
if type in self.tracker:
self.tracker.record(type, polarity)

View File

@@ -19,9 +19,11 @@ from midas.utils import TypedAST
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
file: TextIO,
types: tuple[TextIO],
ignore_errors: bool,
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
@@ -35,8 +37,10 @@ def compile(
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
if not ignore_errors and any(
map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)
):
sys.exit(1)
generator = Generator(workdir=source_path.parent)
generator = Generator(workdir=source_path.parent, types=checker.types)
generator.generate(typed_ast, source_path)

View File

@@ -8,7 +8,9 @@ from typing import TextIO
import click
from midas.ast.printer import MidasPrinter
from midas.checker.checker import TypeChecker
from midas.checker.registry import Member
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
@@ -35,10 +37,30 @@ def dump_registry(
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
print("##### Types #####")
for name, type in checker.types._types.items():
members: dict[str, Type] = checker.types._members.get(name, {})
print(f"{name} = {base_type(type)}")
members: dict[str, Member] = checker.types._members.get(name, {})
params: str = ""
if isinstance(type, GenericType):
params = ", ".join(map(str, type.params))
params = f"[{params}]"
print(f"{name}{params} = {base_type(type)}")
if len(members) != 0:
print(" " * 4 + "Members:")
for member_name, member_type in members.items():
print(" " * 8 + f"{member_name}: {member_type}")
for member_name, member in members.items():
kind: str = member.kind.name
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
print("##### Predicates #####")
printer = MidasPrinter()
for name, predicate in checker.types._predicates.items():
body: str = printer.print(predicate.body)
if predicate.alias:
print(f"{name}: {predicate.type} = {body}")
else:
print(f"{name}{predicate.type}:")
body = "\n".join(
" " + ("return " if i == 0 else "") + line
for i, line in enumerate(body.split("\n"))
)
print(body)

View File

@@ -1,27 +1,64 @@
import ast
import time
from pathlib import Path
from typing import TextIO
import black
import click
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
def generate_stubs(in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(source_path)
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))
output: str = ast.unparse(module)
output = black.format_str(output, mode=black.Mode(is_pyi=True))
out_path.write_text(output)
class Handler(FileSystemEventHandler):
def __init__(self, in_path: Path, out_path: Path) -> None:
super().__init__()
self.in_path: Path = in_path
self.out_path: Path = out_path
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
generate_stubs(self.in_path, self.out_path)
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: TextIO,
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:
print(f"Watching {source_path}...")
print("Press CTRL+C to stop")
handler = Handler(source_path, out_path)
observer = Observer()
observer.schedule(handler, str(source_path))
observer.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()

View File

@@ -41,6 +41,7 @@ def types(
message=f"Type: {type}",
)
)
diagnostics.extend(checker.diagnostics)
printer = DiagnosticPrinter()
printer.print_all(diagnostics)

View File

@@ -228,6 +228,13 @@ class PythonHighlighter(
for item in expr.items:
item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
key.accept(self)
for value in expr.values:
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self)
expr.index.accept(self)
@@ -240,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None:
expr.step.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -266,8 +277,9 @@ class MidasHighlighter(
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self)
stmt.condition.accept(self)
for spec in stmt.params:
self._visit_param_spec(spec)
stmt.body.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
@@ -283,6 +295,14 @@ class MidasHighlighter(
self.wrap(expr, "unary-expr")
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.wrap(expr, "call-expr")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.wrap(expr, "get-expr")
expr.expr.accept(self)
@@ -318,8 +338,7 @@ class MidasHighlighter(
def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function")
for arg in type.pos_args + type.args + type.kw_args:
arg.type.accept(self)
self._visit_param_spec(type.params)
type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
@@ -327,6 +346,10 @@ class MidasHighlighter(
type.base.accept(self)
type.extension.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
from pathlib import Path
from typing import Optional
@@ -7,6 +8,13 @@ from midas.cli.ansi import Ansi
class DiagnosticPrinter:
COLORS: dict[DiagnosticType, int] = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
DiagnosticType.DEBUG: Ansi.MAGENTA,
}
def __init__(self) -> None:
self.files: dict[Optional[str], list[str]] = {}
@@ -22,10 +30,25 @@ class DiagnosticPrinter:
return self.files[filename]
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
by_type: dict[DiagnosticType, int] = defaultdict(int)
for diagnostic in diagnostics:
filename: Optional[str] = diagnostic.file_path
lines = self.get_lines(filename)
self.print(lines, diagnostic, indent=indent)
by_type[diagnostic.type] += 1
if len(diagnostics) == 0:
return
counts: list[str] = []
for type in DiagnosticType:
if type not in by_type:
continue
count: int = by_type[type]
color: int = self.COLORS.get(type, Ansi.WHITE)
counts.append(f"{Ansi.FG(color)}{type.value}s{Ansi.RESET}: {count}")
print(" ".join(counts))
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
"""Pretty-print a diagnostic, showing some context if possible
@@ -55,11 +78,7 @@ class DiagnosticPrinter:
before: str = line[:start_offset]
after: str = line[end_offset:]
color: int = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
}.get(diagnostic.type, Ansi.WHITE)
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
cursor: str = (

View File

@@ -0,0 +1,224 @@
import ast
from typing import Optional
import midas.ast.midas as m
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
Function,
Predicate,
Type,
to_annotation,
)
from midas.lexer.token import TokenType
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
TokenType.AND: ast.And,
# TokenType.OR: ast.Or,
}
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
# TokenType.PLUS: ast.Add,
TokenType.MINUS: ast.Sub,
TokenType.STAR: ast.Mult,
TokenType.SLASH: ast.Div,
}
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
# TokenType.PLUS: ast.UAdd,
TokenType.MINUS: ast.USub,
}
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
TokenType.GREATER: ast.Gt,
TokenType.GREATER_EQUAL: ast.GtE,
TokenType.LESS: ast.Lt,
TokenType.LESS_EQUAL: ast.LtE,
TokenType.EQUAL_EQUAL: ast.Eq,
TokenType.BANG_EQUAL: ast.NotEq,
}
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
def __init__(self, types: TypesRegistry):
self.types: TypesRegistry = types
self._id: int = 0
self._definitions: list[ast.stmt] = []
self._aliases: dict[str, str] = {}
def get_definitions(self) -> list[ast.stmt]:
return self._definitions
def generate(self, expr: m.Expr) -> ast.expr:
match expr:
case m.VariableExpr():
return expr.accept(self)
case _:
func = Function(
pos_args=[],
args=[
Function.Argument(
pos=0,
name="_",
type=self.types.get_type("Any"),
required=True,
)
],
kw_args=[],
returns=self.types.get_type("bool"),
)
alias: str = self.make_alias(None)
definition: ast.stmt = self.make_definition(
alias, Predicate(type=func, body=expr, alias=False)
)
self._definitions.append(definition)
return ast.Name(id=alias)
def make_alias(self, name: Optional[str]) -> str:
suffix: str
if name is None:
suffix = f"p{self._id}"
self._id += 1
else:
suffix = name
alias: str = f"__midas_{suffix}__"
return alias
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
body: ast.expr = predicate.body.accept(self)
if predicate.alias:
return ast.Assign(
targets=[
ast.Name(id=name),
],
value=body,
)
return self.make_func(name, [ast.Return(value=body)], predicate.type)
def make_args(self, func: Function) -> ast.arguments:
return ast.arguments(
posonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.pos_args
],
args=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.args
],
kwonlyargs=[
ast.arg(
arg=arg.name,
annotation=ast.Constant(value=to_annotation(arg.type)),
)
for arg in func.kw_args
],
defaults=[],
kw_defaults=[],
)
def make_func(
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
) -> ast.stmt:
match type:
case Function(returns=Function()):
inner_name: str = f"inner{level}"
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=[
self.make_func(inner_name, inner_body, type.returns, level + 1),
ast.Return(value=ast.Name(id=inner_name)),
],
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case Function():
return ast.FunctionDef(
name=name,
args=self.make_args(type),
body=inner_body,
returns=ast.Constant(value=to_annotation(type.returns)),
decorator_list=[],
)
case _:
raise ValueError(f"Expected function, got {type!r}")
def get_predicate(self, name: str) -> Optional[ast.expr]:
if name not in self._aliases:
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is None:
return None
alias: str = self.make_alias(name)
self._aliases[name] = alias
self._definitions.append(self.make_definition(alias, predicate))
return ast.Name(id=self._aliases[name])
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
return ast.BoolOp(
op=LOGICAL_OPERATORS[expr.operator.type](),
values=[
expr.left.accept(self),
expr.right.accept(self),
],
)
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
op: TokenType = expr.operator.type
if op in BINARY_OPERATORS:
return ast.BinOp(
left=expr.left.accept(self),
op=BINARY_OPERATORS[op](),
right=expr.right.accept(self),
)
if op in COMPARISON_OPERATORS:
return ast.Compare(
left=expr.left.accept(self),
ops=[COMPARISON_OPERATORS[op]()],
comparators=[expr.right.accept(self)],
)
raise ValueError(f"Unexpected binary operator {op}")
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
return ast.UnaryOp(
op=UNARY_OPERATORS[expr.operator.type](),
operand=expr.right.accept(self),
)
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
return ast.Call(
func=expr.callee.accept(self),
args=[arg.accept(self) for arg in expr.arguments],
keywords=[
ast.keyword(arg=name, value=arg.accept(self))
for name, arg in expr.keywords.items()
],
)
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
return ast.Attribute(
value=expr.expr.accept(self),
attr=expr.name.lexeme,
)
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
name: str = expr.name.lexeme
if (p := self.get_predicate(name)) is not None:
return p
return ast.Name(id=name)
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
return expr.accept(self)
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
return ast.Constant(value=expr.value)
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
return ast.Name(id="_")

View File

@@ -2,15 +2,19 @@ import ast
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -19,7 +23,9 @@ from midas.checker.types import (
Type,
TypeVar,
UnitType,
UnknownType,
)
from midas.generator.constraints import ConstraintGenerator
from midas.utils import TypedAST
@@ -30,26 +36,29 @@ class Scope:
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def __init__(self, workdir: Path) -> None:
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True)
self.rel_src_path: Path = Path()
self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
evaluated_casts=[],
)
self._alias_count: int = 0
self._predicate_count: int = 0
self._scopes: list[Scope] = []
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.relative_to(self.workdir)
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
module = ast.Module(body=body, type_ignores=[])
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
@@ -59,6 +68,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
module: ast.AST = self.generate_ast(typed_ast, src_path)
compiled: str = ast.unparse(module)
if out_path is None:
if self.build_dir.exists():
shutil.rmtree(self.build_dir)
self.build_dir.mkdir(parents=True, exist_ok=True)
out_path = (self.build_dir / self.rel_src_path).resolve()
try:
_ = out_path.relative_to(self.build_dir)
@@ -120,6 +132,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self)
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)
@@ -246,7 +262,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
return generated
def _make_alias(self, expr: ast.expr) -> ast.expr:
name: str = f"__midas_alias_{self._alias_count}__"
name: str = f"__midas_a{self._alias_count}__"
alias = ast.Name(id=name)
self._alias_count += 1
self._scopes[-1].aliases.append(name)
@@ -276,6 +292,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
match type:
case UnknownType():
pass
case BaseType(name=name):
self._add_assert(
ast.Call(
@@ -301,8 +320,17 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_assert_message(src_location, expr, type),
)
case AppliedType():
self._make_cast_asserts(src_location, expr, type.body)
case AppliedType(body=body):
self._make_cast_asserts(src_location, expr, body)
case ConstraintType(type=base, constraint=constraint):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case (
TopType()
@@ -314,8 +342,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
):
raise NotImplementedError(f"Can't make assertion for type {type}")
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
@@ -339,3 +368,36 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
ast.Constant(f" to {type}"),
]
)
def _make_constraint_assert(
self, src_location: Location, expr: ast.expr, constraint: m.Expr
):
test_func: ast.expr = self._get_constraint(constraint)
self._add_assert(
ast.Call(
func=test_func,
args=[expr],
keywords=[],
),
self._make_constraint_assert_message(src_location, expr, constraint),
)
def _make_constraint_assert_message(
self, location: Location, expr: ast.expr, constraint: m.Expr
) -> ast.expr:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
return ast.Constant(
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
)
def _get_constraint(self, expr: m.Expr) -> ast.expr:
for expr2, constraint in self._constraints:
if expr2 == expr:
return constraint
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint

View File

@@ -1,5 +1,5 @@
import ast
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
@@ -8,6 +8,7 @@ from midas.checker.types import (
AppliedType,
BaseType,
ComplexType,
ConstraintType,
ExtensionType,
Function,
GenericType,
@@ -17,6 +18,7 @@ from midas.checker.types import (
TypeVar,
UnitType,
UnknownType,
Variance,
substitute_typevars,
)
@@ -37,6 +39,18 @@ class StubsGenerator:
self.stubs = []
self.typing_imports = set()
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
match type:
case BaseType(name=name_) if name == name_:
continue
case GenericType(
name=name1,
body=BaseType(name=name2),
) if (
name == name1 == name2
):
continue
self.generate_stub(name, type)
imports = [
@@ -84,6 +98,7 @@ class StubsGenerator:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
@@ -111,6 +126,13 @@ class StubsGenerator:
],
body_subsitutions | substitutions,
)
case ConstraintType(type=base):
return self.get_bases(base)
case TypeVar(bound=bound) if bound is not None:
return [self.dump_type(bound)], {}
case _:
return [], {}
@@ -148,15 +170,20 @@ class StubsGenerator:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
@@ -176,6 +203,7 @@ class StubsGenerator:
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
@@ -188,6 +216,7 @@ class StubsGenerator:
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
@@ -199,6 +228,12 @@ class StubsGenerator:
slice=args,
)
case ConstraintType():
return self.dump_type(type.type)
case _:
assert_never(type)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
@@ -313,6 +348,29 @@ class StubsGenerator:
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
kwargs: list[ast.keyword] = []
if var.bound is not None:
kwargs.append(
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
)
if var.variance == Variance.COVARIANT:
kwargs.append(
ast.keyword(
arg="covariant",
value=ast.Constant(value=True),
)
)
elif var.variance == Variance.CONTRAVARIANT:
kwargs.append(
ast.keyword(
arg="contravariant",
value=ast.Constant(value=True),
)
)
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
@@ -321,16 +379,7 @@ class StubsGenerator:
args=[
ast.Constant(value=name),
],
keywords=(
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
keywords=kwargs,
),
)
)

View File

@@ -69,6 +69,8 @@ class MidasLexer(Lexer):
):
self.advance()
self.add_token(TokenType.WHITESPACE)
case '"' | "'":
self.scan_string(char)
case _:
if char.isdigit():
self.scan_number()
@@ -78,6 +80,17 @@ class MidasLexer(Lexer):
self.error("Unexpected character")
return None
def scan_string(self, opening: str):
while self.peek() != opening and not self.is_at_end():
self.advance()
if self.is_at_end():
self.error("Unterminated string")
self.advance()
value: str = self.source[self.start + 1 : self.idx - 1]
self.add_token(TokenType.STRING, value)
def scan_number(self):
"""Scan the rest of number and add it as a token

View File

@@ -43,6 +43,7 @@ class TokenType(Enum):
TRUE = auto()
FALSE = auto()
NONE = auto()
STRING = auto()
# Keywords
TYPE = auto()

View File

@@ -3,6 +3,7 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -17,6 +18,7 @@ from midas.ast.midas import (
MemberKind,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -265,6 +267,9 @@ class MidasParser(Parser):
Returns:
Expr: the parsed constraint expression
"""
return self.expression()
def expression(self) -> Expr:
return self.and_()
def and_(self) -> Expr:
@@ -331,7 +336,55 @@ class MidasParser(Parser):
right: Expr = self.unary()
location: Location = Location.span(operator.get_location(), right.location)
return UnaryExpr(location=location, operator=operator, right=right)
return self.reference()
return self.call()
def call(self) -> Expr:
expr: Expr = self.reference()
while self.match(TokenType.LEFT_PAREN):
expr = self.finish_call(expr)
return expr
def finish_call(self, callee: Expr) -> Expr:
pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {}
keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True
keyword: Token = self.advance()
self.advance()
value: Expr = self.expression()
name: str = keyword.lexeme
if name in kw_args:
self.error(
self.peek(),
f"Multiple values passed for '{name}', only the last occurrence will be used",
)
kw_args[name] = value
else:
value = self.expression()
if self.check(TokenType.EQUAL):
if keywords:
raise self.error(self.peek(), "Invalid keyword argument name")
else:
raise self.error(
self.peek(),
"Cannot pass positional arguments after a keyword argument",
)
pos_args.append(value)
if not self.match(TokenType.COMMA):
break
r_paren: Token = self.consume(
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
)
return CallExpr(
location=Location.span(callee.location, r_paren.get_location()),
callee=callee,
arguments=pos_args,
keywords=kw_args,
)
def reference(self) -> Expr:
"""Parse an attribute access expression or a simpler expression
@@ -365,6 +418,9 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.STRING):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
@@ -453,23 +509,35 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
params: list[ParamSpec] = []
while self.check(TokenType.LEFT_PAREN):
params.append(self.function_args())
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
body: Expr = self.constraint()
return PredicateStmt(
location=keyword.location_to(self.previous()),
name=name,
subject=subject,
type=type,
condition=condition,
params=params,
body=body,
)
def function(self) -> FunctionType:
params: ParamSpec = self.function_args()
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=params.l_paren.location_to(self.previous()),
params=params,
returns=result,
)
def function_args(self) -> ParamSpec:
l_paren: Token = self.consume(
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
)
@@ -526,14 +594,4 @@ class MidasParser(Parser):
self.error(token, "Unnamed mixed argument")
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: Type = self.type_expr()
return FunctionType(
location=l_paren.location_to(self.previous()),
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=result,
)
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)

View File

@@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception):
class PythonParser:
CAST_FUNCTION = "cast"
UNSAFE_CAST_FUNCTION = "unsafe_cast"
def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = []
@@ -423,6 +424,9 @@ class PythonParser:
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call():
return self.parse_call(node)
@@ -527,16 +531,19 @@ class PythonParser:
return expr
def parse_cast(self, node: ast.Call) -> CastExpr:
assert isinstance(node.func, ast.Name)
func: str = node.func.id
match node:
case ast.Call(args=[type, expr], keywords=[]):
return CastExpr(
location=Location.from_ast(node),
type=self._parse_type(type),
expr=self.parse_expr(expr),
unsafe=func == self.UNSAFE_CAST_FUNCTION,
)
case _:
raise InvalidSyntaxError(
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
f"Invalid call to {func}, expected type and expression"
)
def parse_call(self, node: ast.Call) -> CallExpr:

34
midas/typing.py Normal file
View File

@@ -0,0 +1,34 @@
from typing import cast as typing_cast
cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""
unsafe_cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: -
This operation is unsound, use at your own risk!
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""

View File

@@ -62,3 +62,4 @@ class UniversalJSONDumper:
class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]

View File

@@ -8,7 +8,11 @@ authors = [
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["click>=8.4.1"]
dependencies = [
"black>=26.5.1",
"click>=8.4.1",
"watchdog>=6.0.0",
]
[project.urls]
Homepage = "https://git.kbk28.ch/HEL/midas"

View File

@@ -1,6 +1,19 @@
{
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L4:30",
"to": "L4:36"
},
"expr": {
"_type": "LiteralExpr",
"value": 123.45
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L4:18",
@@ -16,7 +29,8 @@
"expr": {
"_type": "LiteralExpr",
"value": 123.45
}
},
"unsafe": false
},
"type": {
"name": "Meter",
@@ -25,6 +39,19 @@
}
}
},
{
"location": {
"from": "L5:28",
"to": "L5:31"
},
"expr": {
"_type": "LiteralExpr",
"value": 6.7
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L5:15",
@@ -40,7 +67,8 @@
"expr": {
"_type": "LiteralExpr",
"value": 6.7
}
},
"unsafe": false
},
"type": {
"name": "Second",

View File

@@ -0,0 +1,59 @@
// T is invariant (unused)
type Unused[T] = object
// T is covariant
type Covariant[T] = object
// T is contravariant
type Contravariant[T] = object
// T is invariant
type Invariant[T] = object
extend Covariant[T] {
def foo: fn() -> T
}
extend Contravariant[T] {
def foo: fn(T, /) -> None
}
extend Invariant[T] {
def foo: fn(T, /) -> T
}
// T is covariant
type Coco[T] = object
extend Coco[T] {
def foo: fn() -> Covariant[T]
}
// T is contravariant
type Cocontra[T] = object
extend Cocontra[T] {
def foo: fn() -> Contravariant[T]
}
// T is contravariant
type Contraco[T] = object
extend Contraco[T] {
def foo: fn(Covariant[T], /) -> None
}
// T is covariant
type Contracontra[T] = object
extend Contracontra[T] {
def foo: fn(Contravariant[T], /) -> None
}
type T1[T] = object
type T2[T] = object
extend T1[T] {
def foo: fn() -> T2[T]
}
extend T2[T] {
def foo: fn() -> T1[T]
}

View File

@@ -0,0 +1,52 @@
from _ import (
T1,
T2,
Coco,
Cocontra,
Contraco,
Contracontra,
Contravariant,
Covariant,
Invariant,
Unused,
)
unused: Unused
covariant: Covariant
contravariant: Contravariant
invariant: Invariant
coco: Coco
cocontra: Cocontra
contraco: Contraco
contracontra: Contracontra
t1: T1
t2: T2
# Dummy print to prudce judgements for the expressions
print(
unused,
covariant,
contravariant,
invariant,
coco,
cocontra,
contraco,
contracontra,
t1,
t2,
)
cov1: Covariant[float]
cov2: Covariant[int]
cov1 = cov2 # Ok because int <: float => Covariant[int] <: Covariant[float]
cov2 = cov1 # Invalid
contra1: Contravariant[float]
contra2: Contravariant[int]
contra1 = contra2 # Invalid
contra2 = contra1 # Ok because int <: float => Covariant[float] <: Covariant[int]
inv1: Invariant[float]
inv2: Invariant[int]
inv1 = inv2 # Invalid
inv2 = inv1 # Invalid

View File

@@ -0,0 +1,512 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
28,
4
],
"end": [
28,
13
]
},
"message": "Too many positional arguments"
},
{
"type": "Error",
"location": {
"start": [
42,
0
],
"end": [
42,
11
]
},
"message": "Cannot assign Covariant[float] to variable 'cov2' of type Covariant[int]"
},
{
"type": "Error",
"location": {
"start": [
46,
0
],
"end": [
46,
17
]
},
"message": "Cannot assign Contravariant[int] to variable 'contra1' of type Contravariant[float]"
},
{
"type": "Error",
"location": {
"start": [
51,
0
],
"end": [
51,
11
]
},
"message": "Cannot assign Invariant[int] to variable 'inv1' of type Invariant[float]"
},
{
"type": "Error",
"location": {
"start": [
52,
0
],
"end": [
52,
11
]
},
"message": "Cannot assign Invariant[float] to variable 'inv2' of type Invariant[int]"
}
],
"judgments": [
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{
"location": {
"from": "L27:4",
"to": "L27:10"
},
"expr": {
"_type": "VariableExpr",
"name": "unused"
},
"type": {
"name": "Unused",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L28:4",
"to": "L28:13"
},
"expr": {
"_type": "VariableExpr",
"name": "covariant"
},
"type": {
"name": "Covariant",
"params": [
{
"name": "T",
"bound": null,
"variance": "COVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L29:4",
"to": "L29:17"
},
"expr": {
"_type": "VariableExpr",
"name": "contravariant"
},
"type": {
"name": "Contravariant",
"params": [
{
"name": "T",
"bound": null,
"variance": "CONTRAVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L30:4",
"to": "L30:13"
},
"expr": {
"_type": "VariableExpr",
"name": "invariant"
},
"type": {
"name": "Invariant",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L31:4",
"to": "L31:8"
},
"expr": {
"_type": "VariableExpr",
"name": "coco"
},
"type": {
"name": "Coco",
"params": [
{
"name": "T",
"bound": null,
"variance": "COVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L32:4",
"to": "L32:12"
},
"expr": {
"_type": "VariableExpr",
"name": "cocontra"
},
"type": {
"name": "Cocontra",
"params": [
{
"name": "T",
"bound": null,
"variance": "CONTRAVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L33:4",
"to": "L33:12"
},
"expr": {
"_type": "VariableExpr",
"name": "contraco"
},
"type": {
"name": "Contraco",
"params": [
{
"name": "T",
"bound": null,
"variance": "CONTRAVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L34:4",
"to": "L34:16"
},
"expr": {
"_type": "VariableExpr",
"name": "contracontra"
},
"type": {
"name": "Contracontra",
"params": [
{
"name": "T",
"bound": null,
"variance": "COVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L35:4",
"to": "L35:6"
},
"expr": {
"_type": "VariableExpr",
"name": "t1"
},
"type": {
"name": "T1",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L36:4",
"to": "L36:6"
},
"expr": {
"_type": "VariableExpr",
"name": "t2"
},
"type": {
"name": "T2",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L26:0",
"to": "L37:1"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "print"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "unused"
},
{
"_type": "VariableExpr",
"name": "covariant"
},
{
"_type": "VariableExpr",
"name": "contravariant"
},
{
"_type": "VariableExpr",
"name": "invariant"
},
{
"_type": "VariableExpr",
"name": "coco"
},
{
"_type": "VariableExpr",
"name": "cocontra"
},
{
"_type": "VariableExpr",
"name": "contraco"
},
{
"_type": "VariableExpr",
"name": "contracontra"
},
{
"_type": "VariableExpr",
"name": "t1"
},
{
"_type": "VariableExpr",
"name": "t2"
}
],
"keywords": {}
},
"type": {}
},
{
"location": {
"from": "L41:7",
"to": "L41:11"
},
"expr": {
"_type": "VariableExpr",
"name": "cov2"
},
"type": {
"name": "Covariant",
"args": [
{
"name": "int"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L42:7",
"to": "L42:11"
},
"expr": {
"_type": "VariableExpr",
"name": "cov1"
},
"type": {
"name": "Covariant",
"args": [
{
"name": "float"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L46:10",
"to": "L46:17"
},
"expr": {
"_type": "VariableExpr",
"name": "contra2"
},
"type": {
"name": "Contravariant",
"args": [
{
"name": "int"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L47:10",
"to": "L47:17"
},
"expr": {
"_type": "VariableExpr",
"name": "contra1"
},
"type": {
"name": "Contravariant",
"args": [
{
"name": "float"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L51:7",
"to": "L51:11"
},
"expr": {
"_type": "VariableExpr",
"name": "inv2"
},
"type": {
"name": "Invariant",
"args": [
{
"name": "int"
}
],
"body": {
"name": "object"
}
}
},
{
"location": {
"from": "L52:7",
"to": "L52:11"
},
"expr": {
"_type": "VariableExpr",
"name": "inv1"
},
"type": {
"name": "Invariant",
"args": [
{
"name": "float"
}
],
"body": {
"name": "object"
}
}
}
]
}

View File

@@ -0,0 +1,14 @@
def double(value: float) -> float:
return value * 2
def is_odd(value: int) -> bool:
return bool(value % 2)
floats: list[float] = [0.2, 0.5, 0.1, 1.2]
ints: list[int] = [1, 2, 6, -3]
doubled_floats = map(double, floats)
doubled_ints = map(double, ints)
odd_ints = map(is_odd, ints)

View File

@@ -0,0 +1,874 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
13,
15
],
"end": [
13,
32
]
},
"message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}"
}
],
"judgments": [
{
"location": {
"from": "L2:11",
"to": "L2:16"
},
"expr": {
"_type": "VariableExpr",
"name": "value"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L2:19",
"to": "L2:20"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:11",
"to": "L2:20"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "*",
"right": {
"_type": "LiteralExpr",
"value": 2
}
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L6:16",
"to": "L6:21"
},
"expr": {
"_type": "VariableExpr",
"name": "value"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:24",
"to": "L6:25"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:16",
"to": "L6:25"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "%",
"right": {
"_type": "LiteralExpr",
"value": 2
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:26"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "bool"
},
"arguments": [
{
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "%",
"right": {
"_type": "LiteralExpr",
"value": 2
}
}
],
"keywords": {}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:23",
"to": "L9:26"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.2
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:28",
"to": "L9:31"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.5
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:33",
"to": "L9:36"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.1
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:38",
"to": "L9:41"
},
"expr": {
"_type": "LiteralExpr",
"value": 1.2
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:22",
"to": "L9:42"
},
"expr": {
"_type": "ListExpr",
"items": [
{
"_type": "LiteralExpr",
"value": 0.2
},
{
"_type": "LiteralExpr",
"value": 0.5
},
{
"_type": "LiteralExpr",
"value": 0.1
},
{
"_type": "LiteralExpr",
"value": 1.2
}
]
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L10:19",
"to": "L10:20"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:22",
"to": "L10:23"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:25",
"to": "L10:26"
},
"expr": {
"_type": "LiteralExpr",
"value": 6
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:29",
"to": "L10:30"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:28",
"to": "L10:30"
},
"expr": {
"_type": "UnaryExpr",
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 3
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:18",
"to": "L10:31"
},
"expr": {
"_type": "ListExpr",
"items": [
{
"_type": "LiteralExpr",
"value": 1
},
{
"_type": "LiteralExpr",
"value": 2
},
{
"_type": "LiteralExpr",
"value": 6
},
{
"_type": "UnaryExpr",
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 3
}
}
]
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
"to": "L12:20"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
"to": "L12:36"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "double"
},
{
"_type": "VariableExpr",
"name": "floats"
}
],
"keywords": {}
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
"to": "L13:18"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
"to": "L13:32"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "double"
},
{
"_type": "VariableExpr",
"name": "ints"
}
],
"keywords": {}
},
"type": {}
},
{
"location": {
"from": "L14:11",
"to": "L14:14"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L14:11",
"to": "L14:28"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "is_odd"
},
{
"_type": "VariableExpr",
"name": "ints"
}
],
"keywords": {}
},
"type": {
"name": "list",
"args": [
{
"name": "bool"
}
],
"body": {
"name": "list"
}
}
}
]
}

View File

@@ -7,68 +7,14 @@ Module(
alias(name='Meter'),
alias(name='Second')],
level=0),
Assign(
targets=[
Name(id='__midas_alias_0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_alias_0__')),
Delete(
targets=[
Name(id='__midas_alias_0__')]),
Assign(
targets=[
Name(id='__midas_alias_1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_alias_1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_alias_1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
value=Constant(value=123.45)),
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_alias_1__')),
Delete(
targets=[
Name(id='__midas_alias_1__')]),
value=Constant(value=6.7)),
Assign(
targets=[
Name(id='speed')],

View File

@@ -0,0 +1,14 @@
// Inline
type T1 = float where _ > 0
// Named
predicate is_positive(v: float) = v > 0
type T2 = float where is_positive(_)
// Curried
predicate in_range(mn: float, mx: float)(v: float) = v >= mn & v < mx
type T3 = float where in_range(100, 200)(_)
// Alias
predicate minor = in_range(0, 18)
type T4 = float where minor(_)

View File

@@ -0,0 +1,8 @@
from midas import T1, T2, T3, T4, cast
t: float = 12.5
t1: T1 = cast(T1, t)
t2: T2 = cast(T2, t)
t3: T3 = cast(T3, t)
t4: T4 = cast(T4, t)

View File

@@ -0,0 +1,333 @@
Module(
body=[
FunctionDef(
name='__midas_p0__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Compare(
left=Name(id='_'),
ops=[
Gt()],
comparators=[
Constant(value=0.0)]))],
decorator_list=[],
returns=Constant(value='bool')),
FunctionDef(
name='__midas_is_positive__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='v',
annotation=Constant(value='float'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Compare(
left=Name(id='v'),
ops=[
Gt()],
comparators=[
Constant(value=0.0)]))],
decorator_list=[],
returns=Constant(value='bool')),
FunctionDef(
name='__midas_p1__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Call(
func=Name(id='__midas_is_positive__'),
args=[
Name(id='_')],
keywords=[]))],
decorator_list=[],
returns=Constant(value='bool')),
FunctionDef(
name='__midas_in_range__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='mn',
annotation=Constant(value='float')),
arg(
arg='mx',
annotation=Constant(value='float'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
FunctionDef(
name='inner0',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='v',
annotation=Constant(value='float'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=BoolOp(
op=And(),
values=[
Compare(
left=Name(id='v'),
ops=[
GtE()],
comparators=[
Name(id='mn')]),
Compare(
left=Name(id='v'),
ops=[
Lt()],
comparators=[
Name(id='mx')])]))],
decorator_list=[],
returns=Constant(value='bool')),
Return(
value=Name(id='inner0'))],
decorator_list=[],
returns=Constant(value='Callable[[float], bool]')),
FunctionDef(
name='__midas_p2__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Call(
func=Call(
func=Name(id='__midas_in_range__'),
args=[
Constant(value=100.0),
Constant(value=200.0)],
keywords=[]),
args=[
Name(id='_')],
keywords=[]))],
decorator_list=[],
returns=Constant(value='bool')),
Assign(
targets=[
Name(id='__midas_minor__')],
value=Call(
func=Name(id='__midas_in_range__'),
args=[
Constant(value=0.0),
Constant(value=18.0)],
keywords=[])),
FunctionDef(
name='__midas_p3__',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='_',
annotation=Constant(value='Any'))],
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=[
Return(
value=Call(
func=Name(id='__midas_minor__'),
args=[
Name(id='_')],
keywords=[]))],
decorator_list=[],
returns=Constant(value='bool')),
ImportFrom(
module='midas',
names=[
alias(name='T1'),
alias(name='T2'),
alias(name='T3'),
alias(name='T4'),
alias(name='cast')],
level=0),
Assign(
targets=[
Name(id='t')],
value=Constant(value=12.5)),
Assign(
targets=[
Name(id='__midas_a0__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L5:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p0__'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L5:10: ConstraintError: Value does not fit constraint '_ > 0.0'")),
Assign(
targets=[
Name(id='t1')],
value=Name(id='__midas_a0__')),
Delete(
targets=[
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_a1__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L6:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p1__'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L6:10: ConstraintError: Value does not fit constraint 'is_positive(_)'")),
Assign(
targets=[
Name(id='t2')],
value=Name(id='__midas_a1__')),
Delete(
targets=[
Name(id='__midas_a1__')]),
Assign(
targets=[
Name(id='__midas_a2__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a2__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L7:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a2__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p2__'),
args=[
Name(id='__midas_a2__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L7:10: ConstraintError: Value does not fit constraint 'in_range(100.0, 200.0)(_)'")),
Assign(
targets=[
Name(id='t3')],
value=Name(id='__midas_a2__')),
Delete(
targets=[
Name(id='__midas_a2__')]),
Assign(
targets=[
Name(id='__midas_a3__')],
value=Name(id='t')),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a3__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='02_constraints.py:L8:10: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a3__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assert(
test=Call(
func=Name(id='__midas_p3__'),
args=[
Name(id='__midas_a3__')],
keywords=[]),
msg=Constant(value="02_constraints.py:L8:10: ConstraintError: Value does not fit constraint 'minor(_)'")),
Assign(
targets=[
Name(id='t4')],
value=Name(id='__midas_a3__')),
Delete(
targets=[
Name(id='__midas_a3__')])],
type_ignores=[])

View File

@@ -2582,7 +2582,9 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
@@ -2592,8 +2594,9 @@
"required": true
}
],
"args": [],
"kw_args": [],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2673,7 +2676,9 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
@@ -2683,8 +2688,9 @@
"required": true
}
],
"args": [],
"kw_args": [],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2713,7 +2719,9 @@
"name": "__sub__",
"type": {
"_type": "FunctionType",
"pos_args": [
"params": {
"_type": "ParamSpec",
"pos": [
{
"name": null,
"type": {
@@ -2723,8 +2731,9 @@
"required": true
}
],
"args": [],
"kw_args": [],
"mixed": [],
"kw": []
},
"returns": {
"_type": "GenericType",
"type": {
@@ -2745,12 +2754,24 @@
{
"_type": "PredicateStmt",
"name": "Positive",
"subject": "v",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2766,12 +2787,24 @@
{
"_type": "PredicateStmt",
"name": "StrictlyPositive",
"subject": "v",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "v",
"type": {
"_type": "NamedType",
"name": "float"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
@@ -2787,12 +2820,24 @@
{
"_type": "PredicateStmt",
"name": "Equatorial",
"subject": "loc",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",
@@ -2827,12 +2872,24 @@
{
"_type": "PredicateStmt",
"name": "Arctic",
"subject": "loc",
"params": [
{
"_type": "ParamSpec",
"pos": [],
"mixed": [
{
"name": "loc",
"type": {
"_type": "NamedType",
"name": "GeoLocation"
},
"condition": {
"required": true
}
],
"kw": []
}
],
"body": {
"_type": "GroupingExpr",
"expr": {
"_type": "BinaryExpr",

View File

@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
typed_ast: TypedAST = checker.type_check(path)
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
generator = Generator(workdir=path.parent)
generator = Generator(workdir=path.parent, types=checker.types)
result.compiled_ast = generator.generate_ast(typed_ast, path)
return result

View File

@@ -2,6 +2,7 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
CallExpr,
ComplexType,
ConstraintType,
Expr,
@@ -15,6 +16,7 @@ from midas.ast.midas import (
LogicalExpr,
MemberStmt,
NamedType,
ParamSpec,
PredicateStmt,
Stmt,
Type,
@@ -78,9 +80,8 @@ class MidasAstJsonSerializer(
return {
"_type": "PredicateStmt",
"name": stmt.name.lexeme,
"subject": stmt.subject.lexeme,
"type": stmt.type.accept(self),
"condition": stmt.condition.accept(self),
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
"body": stmt.body.accept(self),
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
@@ -106,6 +107,14 @@ class MidasAstJsonSerializer(
"right": expr.right.accept(self),
}
def visit_call_expr(self, expr: CallExpr) -> dict:
return {
"_type": "CallExpr",
"callee": expr.callee.accept(self),
"arguments": self._serialize_list(expr.arguments),
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
}
def visit_get_expr(self, expr: GetExpr) -> dict:
return {
"_type": "GetExpr",
@@ -163,15 +172,21 @@ class MidasAstJsonSerializer(
def visit_function_type(self, type: FunctionType) -> dict:
return {
"_type": "FunctionType",
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
"args": [self._serialize_func_arg(arg) for arg in type.args],
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
"params": self._serialize_param_spec(type.params),
"returns": type.returns.accept(self),
}
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
return {
"_type": "ParamSpec",
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
}
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
return {
"name": arg.name,
"name": arg.name.lexeme if arg.name is not None else None,
"type": arg.type.accept(self),
"required": arg.required,
}

View File

@@ -263,6 +263,7 @@ class PythonAstJsonSerializer(
"_type": "CastExpr",
"type": expr.type.accept(self),
"expr": expr.expr.accept(self),
"unsafe": expr.unsafe,
}
def visit_ternary_expr(self, expr: TernaryExpr) -> dict: