118 Commits

Author SHA1 Message Date
2886ffe00b feat(checker): add slice overloads on lists 2026-06-14 17:04:29 +02:00
def38e720b fix(checker): handle generic overloads 2026-06-14 17:04:10 +02:00
49274be2f4 feat(checker): type check slice expressions 2026-06-14 16:58:20 +02:00
aec6b7aa7b feat(parser): add slice expression 2026-06-14 16:53:38 +02:00
530d93723e tests: update with new subscript and call checks
invalid function calls now return UnknownType even if the function has a return type
2026-06-14 16:45:53 +02:00
905132a18e feat(checker): resolve overloads with subtypes
try to find the most specific overload if multiple matches are found
2026-06-14 16:36:10 +02:00
9594c74952 doc(checker): add docstrings to new call checks 2026-06-14 16:08:34 +02:00
8df5607461 refactor(checker): unify call check for subscript 2026-06-14 16:08:13 +02:00
757054d7af chore: add examples for functions and overloads 2026-06-14 15:50:20 +02:00
25a96d20e1 feat(checker): handle overloaded function calls 2026-06-14 15:48:31 +02:00
04c0d683de fix(types): remove unused operation structures 2026-06-13 18:57:02 +02:00
ac620f318b feat(checker): type check subscripts 2026-06-13 18:48:53 +02:00
947e9f0149 feat(parser): add subscript expressions 2026-06-13 18:44:19 +02:00
c1ca254b51 feat(checker): handle unary operations 2026-06-13 18:19:50 +02:00
0bb862a1db fix(checker): report unsupported features 2026-06-13 18:11:14 +02:00
54919a3565 tests: update with newly reported judgements 2026-06-13 18:01:32 +02:00
a5f0140013 refactor(checker): replace all accept calls
make visitor accept calls more explicit with type_of(), resolve_type_expr() and process_stmt()
2026-06-13 18:01:02 +02:00
b0af01d906 tests: update tests 2026-06-13 17:46:15 +02:00
6048ee020f fix(checker): adapt comparison to lookup method 2026-06-13 17:44:40 +02:00
f815faa2f8 fix(checker): remove in.to_bytes 2026-06-13 14:00:50 +02:00
f7d5d36d44 fix(checker): handle members on base type 2026-06-13 14:00:23 +02:00
503f2b6a0a fix: remove unused op statement 2026-06-13 13:49:57 +02:00
778117664f fix(checker): forward parsing errors as diagnostics 2026-06-13 13:44:05 +02:00
afe3eefbbf fix(checker): gracefully handle unknown type 2026-06-13 13:43:33 +02:00
96495e9f79 fix(parser): make name required for mixed and keyword args 2026-06-13 13:43:16 +02:00
77263139f6 feat(parser): add mixed arguments in midas functions 2026-06-13 13:16:57 +02:00
4f5967a151 feat(checker): add top type (Any) 2026-06-13 12:45:40 +02:00
2a714a1021 fix: extend example of complex types 2026-06-13 12:40:26 +02:00
dafe0b471a feat(checker): define members on builtin types 2026-06-13 12:39:46 +02:00
a1f2937e16 feat(tests): update serializer 2026-06-12 17:01:19 +02:00
2063d94dce fix(checker): give warning on unknown variable 2026-06-12 17:01:02 +02:00
22fc8010d8 fix(checker): minor fix when using base type in generic 2026-06-12 16:56:03 +02:00
aff1097d91 fix(checker): update binary operation lookup 2026-06-12 16:55:01 +02:00
12d034fd1e fix(checker): handle nested generic members 2026-06-12 16:53:34 +02:00
200709cca6 feat(checker): implement lookup_member method 2026-06-12 16:52:09 +02:00
700284296c feat(checker): add members registry 2026-06-12 16:45:42 +02:00
0b53259b90 fix(cli): update highlighter 2026-06-12 16:42:25 +02:00
0461a4184c feat(parser): accept props and methods in extend 2026-06-12 16:41:33 +02:00
01d6e41893 feat(cli): add option to show type judgements 2026-06-12 14:44:02 +02:00
80e611e49c fix(cli): show diagnostics from different files 2026-06-12 14:43:27 +02:00
c00915966f fix(checker): improve error for recursive type ref 2026-06-11 17:15:28 +02:00
beaa4d95d8 feat(checker): adapt typers to members and extension type 2026-06-11 17:13:13 +02:00
bfa0bb3ee0 feat(parser): add new ast nodes to parser 2026-06-11 13:49:47 +02:00
31158df2a9 feat(parser): add extension type and rename properties 2026-06-11 13:42:19 +02:00
c6ead886ec feat: add function type to midas syntax 2026-06-09 23:48:06 +02:00
9de03bf2b5 feat(types): add type params to extend statement 2026-06-09 23:40:57 +02:00
a26b9293be refactor(types): extract TypeParams
also rename generic type params to type args (when calling a generic)
2026-06-09 15:30:45 +02:00
efa5454776 feat(types): add human-friendly string rep
add `__str__` methods on type structures to improve readability of diagnostics
2026-06-09 12:59:36 +02:00
b8bb8190c4 fix(resolver): define variable on assignment
if a variable is not already defined when an assignment is visited, it is then defined in the current scope
2026-06-09 08:06:46 +02:00
a4f5db7ece fix(checker): use reduce_types to infer return type 2026-06-09 08:05:31 +02:00
fc67f01f34 refactor(checker): extract reduce_types function 2026-06-09 08:04:45 +02:00
0a748a36a3 feat(types): WIP add AppliedType 2026-06-08 18:26:11 +02:00
89fdd1b47e feat(checker): WIP add lists 2026-06-08 18:25:37 +02:00
0cde53ac6e feat(types): add name to generic type 2026-06-08 18:21:40 +02:00
f3ec3606c2 fix: avoid circular import in builtins.py 2026-06-08 13:48:46 +02:00
67ec029529 refactor(resolver): move resolver to checker module 2026-06-08 13:45:48 +02:00
e2aef7a811 refactor(checker): unify builtins definitions 2026-06-08 13:44:26 +02:00
86ba4e658a refactor(checker): restructure around shared registry
restructure the type checker with a shared TypesRegistry used by MidasTyper and PythonTyper

this commit also relocates some methods in more appropriate places, such as is_subtype and apply_generic (now in TypesRegistry)
2026-06-08 13:41:42 +02:00
7eccf59558 feat(checker): add reporter class 2026-06-08 13:38:35 +02:00
9dd7801d2d feat(resolver): handle generic application 2026-06-08 10:59:01 +02:00
154cb8b314 refactor(checker): move is_subtype to resolver 2026-06-08 10:57:50 +02:00
c64ab434b5 refactor(checker): move unfold_type to types.py 2026-06-08 10:56:27 +02:00
25e6410546 feat(resolver): handle generics definition 2026-06-08 10:55:15 +02:00
8a22acc17c feat(checker): add generic type structure 2026-06-08 10:52:34 +02:00
e0179bc442 feat(checker): handle assignments to attributes 2026-06-07 17:50:56 +02:00
e665d03533 fix: remove unused SetExpr 2026-06-07 17:48:31 +02:00
b8cb2b4273 feat(checker): handle attribute getter 2026-06-07 15:07:24 +02:00
d278dc5f5b tests: update tests with operation overloads 2026-06-07 14:28:36 +02:00
59e73f0fd9 fix(checker): invert property subtype check 2026-06-07 14:00:02 +02:00
3e0dc60283 fix(checker): only unfold alias on subtype 2026-06-07 13:59:27 +02:00
c24eb5125e feat(checker): resolve operation overloads with subtypes 2026-06-07 13:43:43 +02:00
25bd895dde feat(cli): improve diagnostic printing 2026-06-07 13:42:15 +02:00
bccd75317e tests: add subtyping test 2026-06-06 16:59:49 +02:00
f0e3f7574f feat(tests): add judgements to test results
add type judgements to checker test results and update all tests (including the new subtyping rules)
2026-06-06 16:58:13 +02:00
5d44081847 feat(checker): implement function subtyping
the logic for checking function subtypes is a WIP and has not been fully tested, there may be some errors and unhandled edge cases
Claude helped lay out and verify the overall steps

Co-authored-by: Claude <noreply@anthropic.com>
2026-06-06 16:53:52 +02:00
2a2bb0aec7 feat(checker): store function param position 2026-06-06 16:50:42 +02:00
67c40a3909 feat(checker): add is_subtype method 2026-06-06 16:30:04 +02:00
1c30188122 feat(checker): record type judgements 2026-06-06 16:25:33 +02:00
82a0f13242 feat(cli): add verbose flag to compile 2026-06-05 14:17:24 +02:00
288d15a9bc Merge pull request 'Usage documentation' (#7) from feat/usage-documentation into main
Reviewed-on: #7
2026-06-05 10:29:42 +00:00
504703d0f7 fix(cli): remove print in main command 2026-06-05 12:26:09 +02:00
e48895d0af docs: add usage documentation in README 2026-06-05 12:25:02 +02:00
13d32d0d27 Merge pull request 'Basic type checker' (#6) from feat/basic-type-checker into main
Reviewed-on: #6
2026-06-05 09:31:53 +00:00
19b9fdd623 Merge pull request 'Improve syntax and types' (#5) from feat/improve-syntax-and-types into feat/basic-type-checker
Reviewed-on: #5
2026-06-05 09:20:56 +00:00
ddcaebb51a fix: remove outdated syntax definition 2026-06-05 11:19:29 +02:00
f182312cd2 fix: update midas syntax definitions 2026-06-05 11:14:53 +02:00
73b21789d5 fix(tests): remove custom imports 2026-06-05 10:48:46 +02:00
5d7c724bc8 fix(cli): add types files argument 2026-06-05 10:44:20 +02:00
74b297c89c feat(checker): remove custom midas import
remove custom import statement (`midas.using`) in favor of passing type definition files as arguments to the checker
2026-06-05 10:43:52 +02:00
822a74acce refactor(checker): rename methods
improve a couple of method names, namely evaluate → type_of and evaluate_block → process_block
2026-06-03 13:03:41 +02:00
9a934fabfd tests: remove union type 2026-06-02 17:22:19 +02:00
828ec9a3fa fix!: remove union type 2026-06-02 17:19:17 +02:00
63a43d79dd chore: update examples 2026-06-02 13:07:53 +02:00
029caf4526 fix(tests): update tests with new syntax 2026-06-02 13:05:38 +02:00
1c5c418f1c fix(tests): serialize ternary expressions 2026-06-02 13:05:06 +02:00
a4139d4652 feat(checker): handle logical expressions 2026-06-02 13:03:07 +02:00
2fd2071d40 feat(parser): parse pass statement and None 2026-06-02 13:02:45 +02:00
97b1ee8ab8 feat(cli): add format command 2026-06-02 13:00:43 +02:00
dee479def5 fix(checker): wrap type definitions in AliasType 2026-06-02 13:00:03 +02:00
c8536e20d2 feat(tests): update Midas serializer 2026-06-02 12:38:58 +02:00
d70137775f feat(cli): update highlighter with new nodes 2026-06-02 12:29:39 +02:00
35ceda99aa chore: tidy 2026-06-02 11:45:49 +02:00
7f3d74ee49 feat(checker)!: resolve new types 2026-06-02 11:44:31 +02:00
b9f378de6f feat(parser)!: update Midas parser with new nodes 2026-06-02 11:42:35 +02:00
ccb17c7290 feat(parser)!: add new Midas AST nodes 2026-06-02 11:41:53 +02:00
505779310a feat: add new midas syntax example 2026-06-02 11:40:42 +02:00
bea3f399ad feat(checker): handle ternary expression 2026-06-01 15:02:12 +02:00
55060bfecd feat(parser): add ternary statement 2026-06-01 15:00:21 +02:00
dd126f2559 fix(cli): improve diagnostic message popup 2026-06-01 14:48:24 +02:00
4151f5373d fix(checker): early define fully-typed function
to handle simple recursion cases where the function has an explicit return type hint, the function must be defined before evaluating its body
2026-06-01 14:40:42 +02:00
bd31713ab4 tests(checker): add control flow test 2026-06-01 14:22:03 +02:00
f4dc57cb96 chore: add control flow example 2026-06-01 14:15:10 +02:00
261fd47494 feat(cli): update highlighter 2026-06-01 14:14:10 +02:00
1b66a8553d fix(checker): handle paths with no returns in functions 2026-06-01 14:13:48 +02:00
65164abadb feat(checker): type check if statements 2026-06-01 14:13:17 +02:00
9d45163d9c feat(checker): handle comparisons 2026-06-01 14:12:22 +02:00
ab0fa1de1a feat(parser): add if statement 2026-06-01 14:11:12 +02:00
5d4df7978b fix(cli): ignore repeated visit of types 2026-06-01 14:10:07 +02:00
62 changed files with 6826 additions and 1851 deletions

View File

@@ -5,3 +5,82 @@
*Midas* aims to provide Python developers with a simple annotation system that enables compile-time integrity and data type checks and generates runtime assertions.
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
## Requirements
- Python 3.11+
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
## Installation
1. Clone the repository
```shell
git clone https://git.kb28.ch/HEL/midas.git
```
2. Go into the project directory
```shell
cd midas
```
3. Install the CLI as a user-wide tool
```shell
uv tool install .
```
4. You can now run the `midas` command from anywhere
```shell
midas --help
```
## Commands
### Compiling
> [!NOTE]
> In the current state of the project, the `compile` command doesn't generate any runnable code; it only runs the parsers and type checker on the provided files.
```shell
midas compile -t types.midas source.py
```
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
The `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting)).
### Highlighting
```shell
midas utils highlight source.py
# or
midas utils highlight types.midas
```
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed to stdout (equivalent to `-o -`).
### Dumping the AST
```shell
midas utils dump-ast source.py
# or
midas utils dump-ast types.midas
```
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flag lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed to stdout (equivalent to `-o -`).
## Tests
Several snapshot tests are available to verify the correct behaviour of the parsers and type checker. They can be run as follows:
```shell
uv run -m tests.midas run -a
uv run -m tests.python run -a
uv run -m tests.checker run -a
```
**Available subcommands:**
- Run all tests: `run -a`
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
- Update all tests: `update -a`
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821]
from __future__ import annotations
# Prototype of custom type import to use valid Python syntax
import midas
midas.using("02_custom_types.midas")
# A data-frame using a custom type
df: Frame[
location: GeoLocation

View File

@@ -0,0 +1,33 @@
type Foo1 = float
type Foo2 = float where (_ > 3)
type Foo3 = int | float
type Foo4 = int where (_ > 3) | float where (_ > 3)
type Foo5 = (int | float) where (_ > 3)
type Foo6 = {
foo: float
bar: float where (_ > 3)
}
type Foo7[T] = T where (_ > 3)
type Foo8[A, B<:int] = {
a: A
b: B
}
type Complex = {
a: int
b: int
}
type Complex2 = Complex where (_.a > 3 & _.b < 5)
predicate Positive(n: int) = n >= 0
extend Foo1 {
op __add__(Foo1) -> Foo1
}
extend Foo7[T] {
op __add__(Foo7[T]) -> Foo7[T]
}
type Optional[T] = None | T

View File

@@ -9,3 +9,5 @@ d = True
e = d + d
f: float = a
f = -f

View File

@@ -1,14 +1,14 @@
type Meter(float)
type Second(float)
type MeterPerSecond(float)
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter
op __sub__(Meter) -> Meter
op __truediv__(Second) -> MeterPerSecond
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
op __add__(Second) -> Second
op __sub__(Second) -> Second
def __add__: fn(Second, /) -> Second
def __sub__: fn(Second, /) -> Second
}

View File

@@ -1,8 +1,6 @@
# type: ignore
# ruff: disable [F821]
midas.using("02_simple_types.midas")
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -0,0 +1,23 @@
def minimum(x: int, y: int):
if x < y:
return x
else:
return y
a = 15
b = 72
c = minimum(a, b)
def factorial(n: int) -> int:
if n <= 1:
return 1
return n * factorial(n - 1)
category = "Category 1" if a < 10 else "Category 2"
def foo() -> None:
pass

View File

@@ -0,0 +1,21 @@
type Meter = float
extend Meter {
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
}
type Coordinate = object
extend Coordinate {
prop x: Meter
prop y: Meter
}
type Difference[T <: float] = T
type MeterDifference = Difference[Meter]
type CompDiff[T <: float] = {
prop d1: Difference[T]
prop d2: Difference[T]
}

View File

@@ -0,0 +1,37 @@
# type: ignore
# ruff: disable [F821]
p1: Coordinate
p2: Coordinate
diff_x = p2.x - p1.x
diff_y = p2.y - p1.y
dist = diff_x + diff_y
p2.x += cast(Meter, 1)
p2.y = True # invalid, wrong type
p2.z = 3 # invalid, no property 'z' on Coordinate
p2.x.a = 3 # invalid, no properties on Meter
foo: list[float] = []
append = foo.append
foo.append("") # invalid, must be float
foo.append(2)
append(True) # invalid, must be float
append(2)
bar: list[list[Meter]]
bar.append([p2.x])
foo2 = foo + foo
a = foo[0]
b = bar[0][1]
c = bar[0][1][2] # invalid, no method __getitem__ on Meter
c = bar[""] # invalid, wrong index type
d = foo[1:2]

View File

@@ -0,0 +1,28 @@
def incr(value: int):
return value + 1
def decr(value: int):
return value - 1
def foo(a: int, /, b: float, *, c: str):
return True
r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b'
r2 = foo(1) # foo() missing 1 required positional argument: 'b'
r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c'
r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c'
r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given
r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b'
r7 = foo(
a=1
) # foo() got some positional-only arguments passed as keyword arguments: 'a'
r8 = foo(g="test") # foo() got an unexpected keyword argument 'g'
r9a = foo(1, 2.0, c="test")
r9b = foo(1, b=2.0, c="test")
r9c = foo(1, c="test", b=2.0)
r10 = foo("a", 3, c=False) # wrong argument types

View File

@@ -0,0 +1,10 @@
type T1 = object
type T2 = object
type Foo = object
type T2b = T2
extend Foo {
def bar: fn(T1, /) -> int
def bar: fn(T2, /) -> float
def bar: fn(T2b, /) -> int
}

View File

@@ -0,0 +1,18 @@
# type: ignore
# ruff: disable [F821]
foo: Foo
t1: T1
t2: T2
a = foo.bar(t1)
b = foo.bar(t2)
func = foo.bar
c = func(t1)
d = func(t2)
t2b: T2b
e = foo.bar(t2b)

View File

@@ -30,6 +30,7 @@ from __future__ import annotations
T = TypeVar("T")
{preamble}
{sections}
"""
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
re.MULTILINE | re.DOTALL,
)
PREAMBLE_REGEX = re.compile(
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
re.MULTILINE | re.DOTALL,
)
def snake_case(text: str) -> str:
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
def make_section(full_name: str, base: str, param: str, body: str) -> str:
print(f" Generating {full_name}")
visitor_methods: list[str] = []
classes: list[str] = []
definitions: list[str] = body.strip("\n").split("\n\n\n")
for cls in definitions:
cls = cls.strip("\n")
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
print(f"Processing {name}")
print(f" Processing {name}")
visitor_methods.append(make_visitor_method(name, param))
classes.append(make_class(name, cls, base))
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
def generate(definitions_path: Path, out_path: Path):
print(f"Processing generating {out_path} from {definitions_path}")
root_dir: Path = Path(__file__).parent.parent
rel_path: Path = definitions_path.relative_to(root_dir)
src: str = definitions_path.read_text()
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
if m := IMPORTS_REGEX.search(src):
imports = m.group("body").strip("\n")
preamble: str = ""
if m := PREAMBLE_REGEX.search(src):
preamble = m.group("body")
for section_m in SECTION_REGEX.finditer(src):
full_name: str = section_m.group("name")
base: str = section_m.group("base")
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
gen_path=Path(__file__).relative_to(root_dir),
),
imports=imports,
preamble=preamble,
sections="\n\n\n".join(sections),
)
out_path.write_text(result)

View File

@@ -4,6 +4,7 @@
###> Imports
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
@@ -12,41 +13,45 @@ from midas.lexer.token import Token
###<
###> Preamble
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
class MemberKind(Enum):
PROPERTY = auto()
METHOD = auto()
###<
###> Stmt | Statements
class SimpleTypeStmt:
class TypeStmt:
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
params: list[TypeParam]
type: Type
class ComplexTypeStmt:
class MemberStmt:
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
class PropertyStmt:
name: Token
type: TypeExpr
constraint: Optional[Expr]
type: Type
kind: MemberKind
class ExtendStmt:
type: TypeExpr
operations: list[OpStmt]
class OpStmt:
name: Token
operand: TypeExpr
result: TypeExpr
params: list[TypeParam]
members: list[MemberStmt]
class PredicateStmt:
name: Token
subject: Token
type: TypeExpr
type: Type
condition: Expr
@@ -54,9 +59,6 @@ class PredicateStmt:
###> Expr | Expressions
class SimpleTypeExpr:
name: Token
optional: bool
class LogicalExpr:
@@ -97,14 +99,46 @@ class WildcardExpr:
token: Token
class TemplateExpr:
type: TypeExpr
###<
###> Type | Types
class TypeExpr:
class NamedType:
name: Token
template: Optional[TemplateExpr]
optional: bool
class GenericType:
type: Type
args: list[Type]
class ConstraintType:
type: Type
constraint: Expr
class ComplexType:
members: list[MemberStmt]
class ExtensionType:
base: Type
extension: ComplexType
class FunctionType:
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
###<

View File

@@ -76,6 +76,12 @@ class ReturnStmt:
value: Optional[Expr]
class IfStmt:
test: Expr
body: list[Stmt]
orelse: list[Stmt]
###<
@@ -122,15 +128,30 @@ class LogicalExpr:
right: Expr
class SetExpr:
object: Expr
name: str
value: Expr
class CastExpr:
type: MidasType
expr: Expr
class TernaryExpr:
test: Expr
if_true: Expr
if_false: Expr
class ListExpr:
items: list[Expr]
class SubscriptExpr:
object: Expr
index: Expr
class SliceExpr:
lower: Optional[Expr]
upper: Optional[Expr]
step: Optional[Expr]
###<

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Generic, Optional, TypeVar
from midas.ast.location import Location
@@ -14,6 +15,18 @@ from midas.lexer.token import Token
T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
class TypeParam:
location: Location
name: Token
bound: Optional[Type]
class MemberKind(Enum):
PROPERTY = auto()
METHOD = auto()
##############
# Statements #
##############
@@ -28,79 +41,53 @@ class Stmt(ABC):
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> T: ...
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
@abstractmethod
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> T: ...
@abstractmethod
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
@abstractmethod
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
@abstractmethod
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
@abstractmethod
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
@dataclass(frozen=True)
class SimpleTypeStmt(Stmt):
class TypeStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
base: TypeExpr
constraint: Optional[Expr]
params: list[TypeParam]
type: Type
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_simple_type_stmt(self)
return visitor.visit_type_stmt(self)
@dataclass(frozen=True)
class ComplexTypeStmt(Stmt):
class MemberStmt(Stmt):
name: Token
template: Optional[TemplateExpr]
properties: list[PropertyStmt]
type: Type
kind: MemberKind
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_complex_type_stmt(self)
@dataclass(frozen=True)
class PropertyStmt(Stmt):
name: Token
type: TypeExpr
constraint: Optional[Expr]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_property_stmt(self)
return visitor.visit_member_stmt(self)
@dataclass(frozen=True)
class ExtendStmt(Stmt):
type: TypeExpr
operations: list[OpStmt]
name: Token
params: list[TypeParam]
members: list[MemberStmt]
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_extend_stmt(self)
@dataclass(frozen=True)
class OpStmt(Stmt):
name: Token
operand: TypeExpr
result: TypeExpr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
return visitor.visit_op_stmt(self)
@dataclass(frozen=True)
class PredicateStmt(Stmt):
name: Token
subject: Token
type: TypeExpr
type: Type
condition: Expr
def accept(self, visitor: Stmt.Visitor[T]) -> T:
@@ -120,9 +107,6 @@ class Expr(ABC):
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> T: ...
@abstractmethod
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@@ -147,21 +131,6 @@ class Expr(ABC):
@abstractmethod
def visit_wildcard_expr(self, expr: WildcardExpr) -> T: ...
@abstractmethod
def visit_template_expr(self, expr: TemplateExpr) -> T: ...
@abstractmethod
def visit_type_expr(self, expr: TypeExpr) -> T: ...
@dataclass(frozen=True)
class SimpleTypeExpr(Expr):
name: Token
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_simple_type_expr(self)
@dataclass(frozen=True)
class LogicalExpr(Expr):
@@ -233,19 +202,94 @@ class WildcardExpr(Expr):
return visitor.visit_wildcard_expr(self)
@dataclass(frozen=True)
class TemplateExpr(Expr):
type: TypeExpr
#########
# Types #
#########
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_template_expr(self)
@dataclass(frozen=True, kw_only=True)
class Type(ABC):
location: Location
@abstractmethod
def accept(self, visitor: Visitor[T]) -> T: ...
class Visitor(ABC, Generic[T]):
@abstractmethod
def visit_named_type(self, type: NamedType) -> T: ...
@abstractmethod
def visit_generic_type(self, type: GenericType) -> T: ...
@abstractmethod
def visit_constraint_type(self, type: ConstraintType) -> T: ...
@abstractmethod
def visit_complex_type(self, type: ComplexType) -> T: ...
@abstractmethod
def visit_extension_type(self, type: ExtensionType) -> T: ...
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@dataclass(frozen=True)
class TypeExpr(Expr):
class NamedType(Type):
name: Token
template: Optional[TemplateExpr]
optional: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_type_expr(self)
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_named_type(self)
@dataclass(frozen=True)
class GenericType(Type):
type: Type
args: list[Type]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_generic_type(self)
@dataclass(frozen=True)
class ConstraintType(Type):
type: Type
constraint: Expr
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_constraint_type(self)
@dataclass(frozen=True)
class ComplexType(Type):
members: list[MemberStmt]
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_complex_type(self)
@dataclass(frozen=True)
class ExtensionType(Type):
base: Type
extension: ComplexType
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_extension_type(self)
@dataclass(frozen=True)
class FunctionType(Type):
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
@dataclass(frozen=True, kw_only=True)
class Argument:
location: Optional[Location] = None
name: Optional[Token]
type: Type
required: bool
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)

View File

@@ -85,67 +85,66 @@ class AstPrinter(Generic[T]):
child.accept(self)
class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
class MidasAstPrinter(
AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None], m.Type.Visitor[None]
):
# Statements
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
self._write_line("SimpleTypeStmt")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self._write_line("TypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("base")
with self._child_level(single=True):
stmt.base.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
self._write_line("ComplexTypeStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_optional_child("template", stmt.template)
self._write_line("properties", last=True)
self._write_line("params")
with self._child_level():
for i, prop in enumerate(stmt.properties):
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.properties) - 1:
if i == len(stmt.params) - 1:
self._mark_last()
prop.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt):
self._write_line("PropertyStmt")
with self._child_level():
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type")
self._print_type_param(param)
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
def _print_type_param(self, param: m.TypeParam) -> None:
self._write_line("Param")
with self._child_level():
self._write_line(f'name: "{param.name.lexeme}"')
self._write_optional_child("bound", param.bound, last=True)
def visit_member_stmt(self, stmt: m.MemberStmt):
self._write_line("MemberStmt")
with self._child_level():
self._write_line(f"kind: {stmt.kind.name}")
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("type", last=True)
with self._child_level(single=True):
stmt.type.accept(self)
self._write_optional_child("constraint", stmt.constraint, last=True)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self._write_line("ExtendStmt")
with self._child_level():
self._write_line("type")
with self._child_level(single=True):
stmt.type.accept(self)
self._write_line("operations", last=True)
self._write_line("params")
with self._child_level():
for i, op in enumerate(stmt.operations):
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.operations) - 1:
if i == len(stmt.params) - 1:
self._mark_last()
op.accept(self)
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self._write_line("OpStmt")
with self._child_level():
self._print_type_param(param)
self._write_line(f'name: "{stmt.name.lexeme}"')
self._write_line("operand")
with self._child_level(single=True):
stmt.operand.accept(self)
self._write_line("result", last=True)
with self._child_level(single=True):
stmt.result.accept(self)
self._write_line("params")
with self._child_level():
for i, param in enumerate(stmt.params):
self._idx = i
if i == len(stmt.params) - 1:
self._mark_last()
self._print_type_param(param)
self._write_line("members", last=True)
with self._child_level():
for i, member in enumerate(stmt.members):
self._idx = i
if i == len(stmt.members) - 1:
self._mark_last()
member.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
self._write_line("PredicateStmt")
@@ -161,12 +160,6 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
# Expressions
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
self._write_line("SimpleTypeExpr")
with self._child_level():
self._write_line(f'name: "{expr.name.lexeme}"')
self._write_line(f"optional: {expr.optional}", last=True)
def visit_logical_expr(self, expr: m.LogicalExpr):
self._write_line("LogicalExpr")
with self._child_level():
@@ -230,22 +223,101 @@ class MidasAstPrinter(AstPrinter, m.Expr.Visitor[None], m.Stmt.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
self._write_line("WildcardExpr")
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self._write_line("TemplateExpr")
with self._child_level(single=True):
def visit_named_type(self, type: m.NamedType) -> None:
    """Print a ``NamedType`` node with its quoted name as the last child line.

    Args:
        type: the named type node to print
    """
    self._write_line("NamedType")
    with self._child_level():
        name_line = 'name: "{}"'.format(type.name.lexeme)
        self._write_line(name_line, last=True)
def visit_generic_type(self, type: m.GenericType) -> None:
    """Print a ``GenericType`` node with its ``type`` and ``args`` children.

    Args:
        type: the generic type node to print
    """
    self._write_line("GenericType")
    with self._child_level():
        self._write_line("type")
        # Exactly one node is written under "type", so the level is opened
        # with single=True like every other single-child slot in this
        # printer (e.g. ConstraintType's "type", ExtensionType's "base").
        with self._child_level(single=True):
            type.type.accept(self)
        self._write_line("args", last=True)
        with self._child_level():
            final_index = len(type.args) - 1
            for i, arg in enumerate(type.args):
                self._idx = i
                if i == final_index:
                    self._mark_last()
                arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
    """Print a ``ConstraintType`` node with its ``type`` and ``constraint`` children.

    Args:
        type: the constraint type node to print
    """
    self._write_line("ConstraintType")
    with self._child_level():
        self._write_line("type")
        with self._child_level(single=True):
            # Only `type` is in scope in this method; a stray
            # `expr.type.accept(self)` call on an undefined name was dropped.
            type.type.accept(self)
        self._write_line("constraint", last=True)
        with self._child_level(single=True):
            type.constraint.accept(self)
def visit_type_expr(self, expr: m.TypeExpr):
self._write_line("TypeExpr")
def visit_complex_type(self, type: m.ComplexType) -> None:
    """Print a ``ComplexType`` node and each of its members.

    Args:
        type: the complex type node to print
    """
    self._write_line("ComplexType")
    with self._child_level():
        # "members" is the only child of this node. Lines that read
        # `expr.name`, `expr.template` and `expr.optional` were removed:
        # `expr` is not defined in this method.
        self._write_line("members", last=True)
        with self._child_level():
            final_index = len(type.members) - 1
            for i, member in enumerate(type.members):
                self._idx = i
                if i == final_index:
                    self._mark_last()
                member.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
    """Print an ``ExtensionType`` node with its ``base`` and ``extension`` children.

    Args:
        type: the extension type node to print
    """
    base_node = type.base
    extension_node = type.extension

    self._write_line("ExtensionType")
    with self._child_level():
        self._write_line("base")
        with self._child_level(single=True):
            base_node.accept(self)

        self._write_line("extension", last=True)
        with self._child_level(single=True):
            extension_node.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
    """Print a ``FunctionType`` node.

    The three argument groups are written in order (``pos_args``, ``args``,
    ``kw_args``), each with one ``Argument`` entry per argument, followed by
    the ``returns`` child.

    Args:
        type: the function type node to print
    """
    argument_groups = (
        ("pos_args", type.pos_args),
        ("args", type.args),
        ("kw_args", type.kw_args),
    )

    self._write_line("FunctionType")
    with self._child_level():
        for label, group in argument_groups:
            self._write_line(label)
            with self._child_level():
                final_index = len(group) - 1
                for position, argument in enumerate(group):
                    self._idx = position
                    if position == final_index:
                        self._mark_last()
                    self._print_function_arg(argument)

        self._write_line("returns", last=True)
        with self._child_level(single=True):
            type.returns.accept(self)
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
    """Print one function argument: its name, its type and its ``required`` flag.

    An argument without a name is printed as ``name: None``; otherwise the
    name lexeme is printed in double quotes.

    Args:
        arg: the function argument to print
    """
    self._write_line("Argument")
    with self._child_level():
        shown_name: str = "None" if arg.name is None else f'"{arg.name.lexeme}"'
        self._write_line(f"name: {shown_name}")

        self._write_line("type")
        with self._child_level(single=True):
            arg.type.accept(self)

        self._write_line(f"required: {arg.required}", last=True)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
self.indent: int = indent
self.level: int = 0
@@ -253,50 +325,46 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def indented(self, text: str) -> str:
    """Prefix ``text`` with ``level * indent`` spaces.

    Args:
        text: the text to indent

    Returns:
        the text preceded by the padding for the current level
    """
    width = self.level * self.indent
    padding = " " * width
    return padding + text
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
    """Render a node as text, starting from indentation level 0.

    A single signature is kept (the one accepting ``m.Type`` as well); the
    duplicated, body-less ``def print`` header was removed.

    Args:
        expr: the expression, statement or type node to render

    Returns:
        whatever the node's ``accept`` returns for this visitor
    """
    # Reset the level so a previous call cannot leak indentation.
    self.level = 0
    return expr.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt):
template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = f"type {stmt.name.lexeme}{template}({stmt.base.accept(self)})"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
    """Render a type statement as ``type Name[Params] = Type``.

    The bracketed parameter list is left out when the statement has no
    parameters. The result is indented to the current level.
    """
    template: str = ""
    if stmt.params:
        rendered_params = ", ".join(
            self._print_type_param(param) for param in stmt.params
        )
        template = "[" + rendered_params + "]"
    declaration: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
    return self.indented(declaration)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt):
template: str = stmt.template.accept(self) if stmt.template is not None else ""
res: str = self.indented(f"type {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for prop in stmt.properties:
res += prop.accept(self)
res += "\n"
self.level -= 1
res += self.indented("}")
def _print_type_param(self, param: m.TypeParam) -> str:
    """Render a type parameter as ``Name`` or ``Name<:Bound``.

    Args:
        param: the type parameter to render

    Returns:
        the parameter name, followed by ``<:`` and the bound when one is set
    """
    name: str = param.name.lexeme
    if param.bound is None:
        return name
    return name + ("<:" + param.bound.accept(self))
def visit_property_stmt(self, stmt: m.PropertyStmt):
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
if stmt.constraint is not None:
res += " where " + stmt.constraint.accept(self)
def visit_member_stmt(self, stmt: m.MemberStmt):
    """Render a member as ``<keyword> name: type`` at the current indent.

    The keyword is ``prop`` for properties and ``def`` for methods; any other
    kind yields an empty keyword.
    """
    keywords = {
        m.MemberKind.PROPERTY: "prop",
        m.MemberKind.METHOD: "def",
    }
    keyword: str = keywords.get(stmt.kind, "")
    return self.indented(f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}")
def visit_extend_stmt(self, stmt: m.ExtendStmt):
res: str = self.indented(f"extend {stmt.type.accept(self)}")
template: str = ""
if len(stmt.params) != 0:
params: list[str] = [self._print_type_param(param) for param in stmt.params]
template = f"[{', '.join(params)}]"
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
res += " {\n"
self.level += 1
for op in stmt.operations:
res += op.accept(self)
for member in stmt.members:
res += member.accept(self) + "\n"
self.level -= 1
res += "\n" + self.indented("}")
res += self.indented("}")
return res
def visit_op_stmt(self, stmt: m.OpStmt):
operand: str = stmt.operand.accept(self)
result: str = stmt.result.accept(self)
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
name: str = stmt.name.lexeme
subject: str = stmt.subject.lexeme
@@ -304,9 +372,6 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
condition: str = stmt.condition.accept(self)
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr):
return f"{expr.name.lexeme}{'?' if expr.optional else ''}"
def visit_logical_expr(self, expr: m.LogicalExpr):
left: str = expr.left.accept(self)
operator: str = expr.operator.lexeme
@@ -342,12 +407,58 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str]):
def visit_wildcard_expr(self, expr: m.WildcardExpr):
return "_"
def visit_template_expr(self, expr: m.TemplateExpr):
return f"[{expr.type.accept(self)}]"
def visit_named_type(self, type: m.NamedType) -> str:
    """Render a named type as its bare name lexeme."""
    name_token = type.name
    return name_token.lexeme
def visit_type_expr(self, expr: m.TypeExpr):
template: str = expr.template.accept(self) if expr.template is not None else ""
return f"{expr.name.lexeme}{template}{'?' if expr.optional else ''}"
def visit_generic_type(self, type: m.GenericType) -> str:
    """Render a generic type as ``Base[Arg, ...]``.

    The bracketed part is omitted when there are no type arguments.
    """
    base: str = type.type.accept(self)
    if not type.args:
        return base
    rendered_args = ", ".join(arg.accept(self) for arg in type.args)
    return base + "[" + rendered_args + "]"
def visit_constraint_type(self, type: m.ConstraintType) -> str:
    """Render a constrained type as ``<type> where <constraint>``."""
    constrained: str = type.type.accept(self)
    condition: str = type.constraint.accept(self)
    return constrained + " where " + condition
def visit_complex_type(self, type: m.ComplexType) -> str:
    """Render a complex type as a brace-delimited block of members.

    Each member is rendered one level deeper and placed on its own line;
    the closing brace is indented to the level in effect on entry.
    """
    self.level += 1
    member_lines = [member.accept(self) + "\n" for member in type.members]
    self.level -= 1
    return "{\n" + "".join(member_lines) + self.indented("}")
def visit_extension_type(self, type: m.ExtensionType) -> str:
    """Render an extension type as ``<base> & <extension>``."""
    base = type.base.accept(self)
    extension = type.extension.accept(self)
    return f"{base} & {extension}"
def visit_function_type(self, type: m.FunctionType) -> str:
    """Render a function type as ``fn (<args>) -> <returns>``.

    Positional-only arguments come first and are followed by a ``/`` marker
    when there is at least one; keyword-only arguments are preceded by a
    ``*`` marker when there is at least one.
    """
    positional_only = [self._print_arg(arg) for arg in type.pos_args]
    mixed = [self._print_arg(arg) for arg in type.args]
    keyword_only = [self._print_arg(arg) for arg in type.kw_args]

    parts: list[str] = []
    if positional_only:
        parts.extend(positional_only)
        parts.append("/")
    parts.extend(mixed)
    if keyword_only:
        parts.append("*")
        parts.extend(keyword_only)

    return f"fn ({', '.join(parts)}) -> {type.returns.accept(self)}"
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
    """Render one function argument as ``[name: ]type[?]``.

    The ``name: `` prefix is present only for named arguments, and the ``?``
    suffix only for arguments that are not required.
    """
    prefix: str = "" if arg.name is None else arg.name.lexeme + ": "
    rendered: str = prefix + arg.type.accept(self)
    if arg.required:
        return rendered
    return rendered + "?"
class PythonAstPrinter(
@@ -419,7 +530,14 @@ class PythonAstPrinter(
self._mark_last()
self._print_argument(arg)
self._write_optional_child("returns", stmt.returns, last=True)
self._write_optional_child("returns", stmt.returns)
self._write_line("body", last=True)
with self._child_level():
for i, body_stmt in enumerate(stmt.body):
self._idx = i
if i == len(stmt.body) - 1:
self._mark_last()
body_stmt.accept(self)
def _print_argument(self, arg: p.Function.Argument) -> None:
self._write_line("FunctionArgument")
@@ -454,6 +572,27 @@ class PythonAstPrinter(
with self._child_level():
self._write_optional_child("value", stmt.value, last=True)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
    """Print an ``IfStmt`` node with its ``test``, ``body`` and ``orelse`` children.

    Args:
        stmt: the if statement to print
    """

    def write_statements(statements) -> None:
        # Print every statement of a branch inside its own child level,
        # flagging the final one before it is visited.
        with self._child_level():
            final_index = len(statements) - 1
            for position, child in enumerate(statements):
                self._idx = position
                if position == final_index:
                    self._mark_last()
                child.accept(self)

    self._write_line("IfStmt")
    with self._child_level():
        self._write_line("test")
        with self._child_level(single=True):
            stmt.test.accept(self)

        self._write_line("body")
        write_statements(stmt.body)

        self._write_line("orelse", last=True)
        write_statements(stmt.orelse)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self._write_line("BinaryExpr")
with self._child_level():
@@ -525,7 +664,7 @@ class PythonAstPrinter(
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
    """Print a ``LiteralExpr`` node with the ``repr`` of its value.

    Args:
        expr: the literal expression to print
    """
    self._write_line("LiteralExpr")
    with self._child_level(single=True):
        # The value is written once, using repr so that string literals keep
        # their quotes (the str "1" and the int 1 print differently).
        self._write_line(f"value: {expr.value!r}")
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
self._write_line("VariableExpr")
@@ -545,17 +684,6 @@ class PythonAstPrinter(
with self._child_level(single=True):
expr.right.accept(self)
def visit_set_expr(self, expr: p.SetExpr) -> None:
self._write_line("SetExpr")
with self._child_level():
self._write_line("object")
with self._child_level(single=True):
expr.object.accept(self)
self._write_line(f"name: {expr.name}")
self._write_line("value", last=True)
with self._child_level(single=True):
expr.value.accept(self)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self._write_line("CastExpr")
with self._child_level():
@@ -565,3 +693,46 @@ class PythonAstPrinter(
self._write_line("expr", last=True)
with self._child_level(single=True):
expr.expr.accept(self)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
    """Print a ``TernaryExpr`` node with its ``test``, ``if_true`` and ``if_false`` children.

    Args:
        expr: the ternary expression to print
    """
    condition = expr.test
    when_true = expr.if_true
    when_false = expr.if_false

    self._write_line("TernaryExpr")
    with self._child_level():
        self._write_line("test")
        with self._child_level(single=True):
            condition.accept(self)

        self._write_line("if_true")
        with self._child_level(single=True):
            when_true.accept(self)

        self._write_line("if_false", last=True)
        with self._child_level(single=True):
            when_false.accept(self)
def visit_list_expr(self, expr: p.ListExpr) -> None:
    """Print a ``ListExpr`` node and each of its items under ``items``.

    Args:
        expr: the list expression to print
    """
    elements = expr.items
    self._write_line("ListExpr")
    with self._child_level():
        self._write_line("items", last=True)
        with self._child_level():
            final_index = len(elements) - 1
            for position, element in enumerate(elements):
                self._idx = position
                if position == final_index:
                    self._mark_last()
                element.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
    """Print a ``SubscriptExpr`` node with its ``object`` and ``index`` children.

    Args:
        expr: the subscript expression to print
    """
    subscripted = expr.object
    index = expr.index

    self._write_line("SubscriptExpr")
    with self._child_level():
        self._write_line("object")
        with self._child_level(single=True):
            subscripted.accept(self)

        self._write_line("index", last=True)
        with self._child_level(single=True):
            index.accept(self)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
    """Print a ``SliceExpr`` node with its optional ``lower``, ``upper`` and ``step`` children.

    Args:
        expr: the slice expression to print
    """
    self._write_line("SliceExpr")
    with self._child_level():
        leading_bounds = (("lower", expr.lower), ("upper", expr.upper))
        for label, bound in leading_bounds:
            self._write_optional_child(label, bound)
        # "step" is written last and flagged as such.
        self._write_optional_child("step", expr.step, last=True)

View File

@@ -14,6 +14,7 @@ from midas.ast.location import Location
T = TypeVar("T")
####################
# Type annotations #
####################
@@ -103,6 +104,9 @@ class Stmt(ABC):
@abstractmethod
def visit_return_stmt(self, stmt: ReturnStmt) -> T: ...
@abstractmethod
def visit_if_stmt(self, stmt: IfStmt) -> T: ...
@dataclass(frozen=True)
class ExpressionStmt(Stmt):
@@ -164,6 +168,16 @@ class ReturnStmt(Stmt):
return visitor.visit_return_stmt(self)
@dataclass(frozen=True)
class IfStmt(Stmt):
    """An ``if`` statement node: a test expression and two statement lists."""

    # Expression deciding which branch applies
    test: Expr
    # Statements stored for the `body` branch
    body: list[Stmt]
    # Statements stored for the `orelse` branch (presumably else/elif; confirm with the parser)
    orelse: list[Stmt]

    def accept(self, visitor: Stmt.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_if_stmt`` and return its result."""
        return visitor.visit_if_stmt(self)
###############
# Expressions #
###############
@@ -202,10 +216,19 @@ class Expr(ABC):
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
@abstractmethod
def visit_set_expr(self, expr: SetExpr) -> T: ...
def visit_cast_expr(self, expr: CastExpr) -> T: ...
@abstractmethod
def visit_cast_expr(self, expr: CastExpr) -> T: ...
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
@abstractmethod
def visit_list_expr(self, expr: ListExpr) -> T: ...
@abstractmethod
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
@abstractmethod
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
@dataclass(frozen=True)
@@ -282,16 +305,6 @@ class LogicalExpr(Expr):
return visitor.visit_logical_expr(self)
@dataclass(frozen=True)
class SetExpr(Expr):
object: Expr
name: str
value: Expr
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_set_expr(self)
@dataclass(frozen=True)
class CastExpr(Expr):
type: MidasType
@@ -299,3 +312,40 @@ class CastExpr(Expr):
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)
@dataclass(frozen=True)
class TernaryExpr(Expr):
    """A conditional expression node with a test and two alternative expressions."""

    # Condition expression
    test: Expr
    # Expression stored for the true case
    if_true: Expr
    # Expression stored for the false case
    if_false: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_ternary_expr`` and return its result."""
        return visitor.visit_ternary_expr(self)
@dataclass(frozen=True)
class ListExpr(Expr):
    """A list expression node holding its item expressions."""

    # Item expressions, in the order they were given to the node
    items: list[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_list_expr`` and return its result."""
        return visitor.visit_list_expr(self)
@dataclass(frozen=True)
class SubscriptExpr(Expr):
    """A subscript expression node: an object expression and an index expression."""

    # Expression being subscripted
    object: Expr
    # Index expression
    index: Expr

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_subscript_expr`` and return its result."""
        return visitor.visit_subscript_expr(self)
@dataclass(frozen=True)
class SliceExpr(Expr):
    """A slice expression node; each of its three parts may be ``None``."""

    # Lower bound expression, or None
    lower: Optional[Expr]
    # Upper bound expression, or None
    upper: Optional[Expr]
    # Step expression, or None
    step: Optional[Expr]

    def accept(self, visitor: Expr.Visitor[T]) -> T:
        """Dispatch to ``visitor.visit_slice_expr`` and return its result."""
        return visitor.visit_slice_expr(self)

View File

@@ -0,0 +1,152 @@
extend float {
def hex: fn() -> str
def is_integer: fn() -> bool
prop real: float
prop imag: float
def conjugate: fn() -> float
def __add__: fn(value: float, /) -> float
def __sub__: fn(value: float, /) -> float
def __mul__: fn(value: float, /) -> float
def __floordiv__: fn(value: float, /) -> float
def __truediv__: fn(value: float, /) -> float
def __mod__: fn(value: float, /) -> float
// def __divmod__: fn(value: float, /) -> tuple[float, float]
def __pow__: fn(value: int, /) -> float
// positive value -> float; negative value -> complex
// return type must be Any as `float | complex` causes too many false-positive errors
def __pow__: fn(value: float, /) -> Any
def __radd__: fn(value: float, /) -> float
def __rsub__: fn(value: float, /) -> float
def __rmul__: fn(value: float, /) -> float
def __rfloordiv__: fn(value: float, /) -> float
def __rtruediv__: fn(value: float, /) -> float
def __rmod__: fn(value: float, /) -> float
// def __rdivmod__: fn(value: float, /) -> tuple[float, float]
// def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float
// def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex
// Returning `complex` for the general case gives too many false-positive errors.
// def __rpow__: fn(value: float, mod: None = None, /) -> Any
// def __getnewargs__: fn() -> tuple[float]
def __trunc__: fn() -> int
def __ceil__: fn() -> int
def __floor__: fn() -> int
def __round__: fn(ndigits: None?, /) -> int
def __round__: fn(ndigits: int, /) -> float
def __eq__: fn(value: object, /) -> bool
def __ne__: fn(value: object, /) -> bool
def __lt__: fn(value: float, /) -> bool
def __le__: fn(value: float, /) -> bool
def __gt__: fn(value: float, /) -> bool
def __ge__: fn(value: float, /) -> bool
def __neg__: fn() -> float
def __pos__: fn() -> float
def __int__: fn() -> int
def __float__: fn() -> float
def __abs__: fn() -> float
def __hash__: fn() -> int
def __bool__: fn() -> bool
def __format__: fn(format_spec: str, /) -> str
}
extend int {
prop real: int
prop imag: int
prop numerator: int
prop denominator: int
def conjugate: fn() -> int
def bit_length: fn() -> int
def bit_count: fn() -> int
// def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes
def __add__: fn(value: int, /) -> int
def __sub__: fn(value: int, /) -> int
def __mul__: fn(value: int, /) -> int
def __floordiv__: fn(value: int, /) -> int
def __truediv__: fn(value: int, /) -> float
def __mod__: fn(value: int, /) -> int
// def __divmod__: fn(value: int, /) -> tuple[int, int]
def __radd__: fn(value: int, /) -> int
def __rsub__: fn(value: int, /) -> int
def __rmul__: fn(value: int, /) -> int
def __rfloordiv__: fn(value: int, /) -> int
def __rtruediv__: fn(value: int, /) -> float
def __rmod__: fn(value: int, /) -> int
// def __rdivmod__: fn(value: int, /) -> tuple[int, int]
def __pow__: fn(value: int, /) -> int
// def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int
// def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float
// positive value -> int; negative value -> float
// return type must be Any as `int | float` causes too many false-positive errors
// def __pow__: fn(value: int, mod: None = None, /) -> Any
// def __pow__: fn(value: int, mod: int, /) -> int
def __rpow__: fn(value: int, /) -> Any
def __and__: fn(value: int, /) -> int
def __or__: fn(value: int, /) -> int
def __xor__: fn(value: int, /) -> int
def __lshift__: fn(value: int, /) -> int
def __rshift__: fn(value: int, /) -> int
def __rand__: fn(value: int, /) -> int
def __ror__: fn(value: int, /) -> int
def __rxor__: fn(value: int, /) -> int
def __rlshift__: fn(value: int, /) -> int
def __rrshift__: fn(value: int, /) -> int
def __neg__: fn() -> int
def __pos__: fn() -> int
def __invert__: fn() -> int
def __trunc__: fn() -> int
def __ceil__: fn() -> int
def __floor__: fn() -> int
def __round__: fn(ndigits: None?, /) -> int
def __round__: fn(ndigits: int, /) -> int
// def __getnewargs__: fn() -> tuple[int]
def __eq__: fn(value: object, /) -> bool
def __ne__: fn(value: object, /) -> bool
def __lt__: fn(value: int, /) -> bool
def __le__: fn(value: int, /) -> bool
def __gt__: fn(value: int, /) -> bool
def __ge__: fn(value: int, /) -> bool
def __float__: fn() -> float
def __int__: fn() -> int
def __abs__: fn() -> int
def __hash__: fn() -> int
def __bool__: fn() -> bool
def __index__: fn() -> int
def __format__: fn(format_spec: str, /) -> str
}
extend list[T] {
def copy: fn () -> list[T]
def append: fn (object: T, /) -> None
def extend: fn (iterable: list[T], /) -> None
def pop: fn (index: int?, /) -> T
def index: fn (value: T, start: int?, stop: int?, /) -> int
def count: fn (value: T, /) -> int
def insert: fn (index: int, object: T, /) -> None
def remove: fn (value: T, /) -> None
def sort: fn (*, reverse: bool?) -> None
def __len__: fn () -> int
// def __iter__: fn () -> Iterator[T]
def __getitem__: fn (i: int, /) -> T
def __getitem__: fn (s: slice, /) -> list[T]
def __setitem__: fn (key: int, value: T, /) -> None
def __setitem__: fn (key: slice, value: list[T], /) -> None
def __delitem__: fn (key: int, /) -> None
def __delitem__: fn (key: slice, /) -> None
// def __add__: fn[S <: T] (value: list[S], /) -> list[T]
def __add__: fn (value: list[T], /) -> list[T]
def __iadd__: fn (value: list[T], /) -> list[T]
def __mul__: fn (value: int, /) -> list[T]
def __rmul__: fn (value: int, /) -> list[T]
def __imul__: fn (value: int, /) -> list[T]
def __contains__: fn (key: object, /) -> bool
// def __reversed__: fn () -> Iterator[T]
def __gt__: fn (value: list[T], /) -> bool
def __ge__: fn (value: list[T], /) -> bool
def __lt__: fn (value: list[T], /) -> bool
def __le__: fn (value: list[T], /) -> bool
def __eq__: fn (value: object, /) -> bool
prop __doc__: str
}

41
midas/checker/builtins.py Normal file
View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import (
BaseType,
GenericType,
TopType,
TypeVar,
UnitType,
)
if TYPE_CHECKING:
from midas.checker.registry import TypesRegistry
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"float": {"int"},
"int": {"bool"},
}
def define_builtins(reg: TypesRegistry) -> None:
    """Define the builtin types in a registry.

    Registers ``Any``, ``None``, the base types ``object``, ``bool``, ``int``,
    ``float``, ``str`` and ``slice``, and the generic ``list[T]``, in that
    order.

    Args:
        reg (TypesRegistry): the registry receiving the definitions
    """
    # The values returned by define_type are deliberately not bound: the
    # previous local names (any, object, bool, int, float, str, slice, list)
    # were never read and shadowed the Python builtins inside this function.
    reg.define_type("Any", TopType())
    reg.define_type("None", UnitType())

    for name in ("object", "bool", "int", "float", "str", "slice"):
        reg.define_type(name, BaseType(name=name))

    reg.define_type(
        "list",
        GenericType(
            name="list",
            params=[TypeVar(name="T", bound=None)],
            body=BaseType(name="list"),
        ),
    )

View File

@@ -1,467 +1,35 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.environment import Environment
from midas.checker.operators import OPERATOR_METHODS
from midas.checker.types import Function, Type, UnitType, UnknownType
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
from midas.resolver.midas import MidasResolver
from midas.checker.diagnostic import Diagnostic
from midas.checker.midas import MidasTyper
from midas.checker.python import PythonTyper
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import Reporter
class ReturnException(Exception):
pass
class TypeChecker:
def __init__(self):
self.types: TypesRegistry = TypesRegistry()
self.reporter: Reporter = Reporter()
self.midas_typer = MidasTyper(self.types, self.reporter)
self.python_typer = PythonTyper(self.types, self.reporter)
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: p.Expr
type: Type
argument: Function.Argument
def import_midas(self, path: Path):
    """Read a Midas definition file and import its contents.

    Args:
        path (Path): the file to read

    Returns:
        whatever ``import_midas_source`` returns
    """
    # Decode explicitly as UTF-8: without an encoding, read_text() falls back
    # to the locale's preferred encoding, so the same file could be decoded
    # differently from one machine to another.
    source: str = path.read_text(encoding="utf-8")
    return self.import_midas_source(source, path=str(path))
def import_midas_source(self, source: str, path: Optional[str] = None):
self.midas_typer.process(source, path)
class Checker(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
p.MidasType.Visitor[Type],
):
"""A type checker which can use custom type definitions"""
def type_check(self, path: Path):
    """Read a Python source file and type check its contents.

    Args:
        path (Path): the file to read

    Returns:
        whatever ``type_check_source`` returns
    """
    # Python source is UTF-8 by default, so decode it as such rather than
    # with the locale's preferred encoding that read_text() uses otherwise.
    source: str = path.read_text(encoding="utf-8")
    return self.type_check_source(source, path=str(path))
def __init__(self, locals: dict[p.Expr, int], file_path: Path):
self.logger: logging.Logger = logging.getLogger("Checker")
self.file_path: Path = file_path
self.ctx: MidasResolver = MidasResolver()
self.global_env: Environment = Environment()
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = locals
self.diagnostics: list[Diagnostic] = []
def type_check_source(self, source: str, path: Optional[str] = None):
self.python_typer.process(source, path)
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
self.diagnostics.append(
Diagnostic(
file_path=self.file_path,
location=location,
type=type,
message=message,
)
)
def error(self, location: Location, message: str):
self.diagnostic(
type=DiagnosticType.ERROR,
location=location,
message=message,
)
def warning(self, location: Location, message: str):
self.diagnostic(
type=DiagnosticType.WARNING,
location=location,
message=message,
)
def info(self, location: Location, message: str):
self.diagnostic(
type=DiagnosticType.INFO,
location=location,
message=message,
)
def evaluate(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression
Args:
expr (p.Expr): the expression to evaluate
Returns:
Type: the type of the given expression
"""
return expr.accept(self)
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> None:
    """Evaluate a sequence of statements

    Statements are visited in order inside ``env``. A ``ReturnException``
    raised by a statement stops the evaluation of the remaining ones.

    Args:
        block (list[p.Stmt]): the statements to evaluate
        env (Environment): the environment in which to evaluate
    """
    previous_env: Environment = self.env
    self.env = env
    try:
        for stmt in block:
            stmt.accept(self)
    except ReturnException:
        # A return statement ends the block; nothing else to do
        pass
    finally:
        # Always restore the caller's environment, even when a statement
        # raises something other than ReturnException; otherwise the checker
        # would stay stuck in the inner scope.
        self.env = previous_env
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
"""Type check a sequence of statements and returns diagnostics
Args:
statements (list[p.Stmt]): the statements to evaluate and check
Returns:
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
"""
self.diagnostics = []
for stmt in statements:
stmt.accept(self)
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
return self.diagnostics
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
    """Look up the type of a variable

    When a scope distance is recorded for ``expr`` in ``self.locals``, the
    variable is read at that distance from the current environment;
    otherwise it is read from the global environment.

    Args:
        name (str): the name of the variable
        expr (p.Expr): the variable expression, used to lookup the scope distance

    Returns:
        Optional[Type]: the result of the environment lookup
    """
    distance: Optional[int] = self.locals.get(expr)
    if distance is None:
        return self.global_env.get(name)
    return self.env.get_at(distance, name)
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
"""Parse a Midas import statement
The statement should be written as `midas.using("path/to/types.midas")`
Args:
expr (p.CallExpr): the import call expression
Returns:
Optional[Path]: the path to the imported file, or None if the expression is malformed
"""
match expr:
case p.CallExpr(
callee=p.GetExpr(
object=p.VariableExpr(name="midas"),
name="using",
),
arguments=[
p.LiteralExpr(value=path),
],
):
return Path(path)
return None
def import_midas(self, path: Path) -> None:
"""Import Midas definitions from a path
Args:
path (Path): the import path
"""
self.logger.debug(f"Importing type definitions from {path}")
path = (self.file_path.parent / path).resolve()
lexer: MidasLexer = MidasLexer(path.read_text())
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
stmts: list[m.Stmt] = parser.parse()
self.ctx.resolve(stmts)
self.logger.debug(f"Midas types: {self.ctx._types}")
self.logger.debug(f"Midas operations: {self.ctx._operations}")
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.evaluate(stmt.expr)
def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env)
pos_args: list[Function.Argument] = []
args: list[Function.Argument] = []
kw_args: list[Function.Argument] = []
def eval_arg_type(arg: p.Function.Argument) -> Type:
if arg.type is not None:
return arg.type.accept(self)
if arg.default is not None:
return arg.default.accept(self)
return UnknownType()
for arg in stmt.posonlyargs:
pos_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in stmt.args:
args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in stmt.kwonlyargs:
kw_args.append(
Function.Argument(
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
self.evaluate_block(stmt.body, env)
inferred_return: Type = UnknownType()
if len(env.return_types) == 1:
inferred_return = list(env.return_types)[0]
elif len(env.return_types) > 1:
self.error(
stmt.location,
f"Mixed return types: {env.return_types}",
)
returns: Type = UnknownType()
if stmt.returns is not None:
returns = stmt.returns.accept(self)
if returns != inferred_return:
self.error(
stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
)
else:
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
name=stmt.name,
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
# TODO check not yet defined locally
type: Type = stmt.type.accept(self)
self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
    """Check an assignment against each of its targets

    The value type is evaluated once. Targets that are not plain variables
    are logged and reported as a warning. An unknown variable is defined
    with the value type; a known one is reported as an error when its type
    differs from the value type.

    Args:
        stmt (p.AssignStmt): the assignment statement
    """
    value: Type = self.evaluate(stmt.value)
    for target in stmt.targets:
        if isinstance(target, p.VariableExpr):
            name: str = target.name
            declared: Optional[Type] = self.look_up_variable(name, target)
            if declared is None:
                self.env.define(name, value)
            # TODO: implement real comparison method
            elif declared != value:
                self.error(
                    stmt.location,
                    f"Cannot assign {value} to {name} of type {declared}",
                )
            continue

        self.logger.warning(f"Unsupported assignment to {target}")
        self.warning(target.location, f"Unsupported assignment to {target}")
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
self.env.return_types.append(type)
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
    """Evaluate the type of a binary operation

    The operator class is mapped to a method name through
    ``OPERATOR_METHODS``; the result type is then requested from the
    resolver for the two operand types.

    Args:
        expr (p.BinaryExpr): the binary expression

    Returns:
        Type: the operation result, or UnknownType if the operator is
            unsupported or the operation is undefined
    """
    method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
    if method is None:
        self.logger.warning(f"Unsupported operator {expr.operator}")
        self.warning(expr.location, f"Unsupported operator {expr.operator}")
        return UnknownType()

    left: Type = self.evaluate(expr.left)
    right: Type = self.evaluate(expr.right)
    result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
    if result is not None:
        return result

    self.error(
        expr.location,
        f"Undefined operation {method} between {left} and {right}",
    )
    return UnknownType()
def visit_compare_expr(self, expr: p.CompareExpr) -> Type: ...
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
def visit_call_expr(self, expr: p.CallExpr) -> Type:
    """Evaluate the type of a call expression

    A `midas.using(...)` call triggers an import and yields UnknownType.
    Otherwise the callee must evaluate to a Function; every mapped argument
    whose type differs from its parameter type is reported as an error.

    Args:
        expr (p.CallExpr): the call expression

    Returns:
        Type: the function's return type, or UnknownType for imports and
            non-function callees
    """
    import_path = self.parse_midas_import(expr)
    if import_path:
        self.import_midas(import_path)
        return UnknownType()

    callee: Type = self.evaluate(expr.callee)
    if not isinstance(callee, Function):
        self.error(expr.callee.location, "Callee is not a function")
        return UnknownType()

    for mapped in self.map_call_arguments(callee, expr):
        if mapped.type == mapped.argument.type:
            continue
        self.error(
            mapped.expr.location,
            f"Wrong type for argument '{mapped.argument.name}', expected {mapped.argument.type}, got {mapped.type}",
        )
    return callee.returns
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
    """Evaluate the type of a literal from the Python type of its value

    Args:
        expr (p.LiteralExpr): the literal expression

    Returns:
        Type: the matching builtin type, or UnknownType (with a warning)
            for any other kind of value
    """
    # bool must be tested before int: bool is a subclass of int
    literal_types = (
        (bool, "bool"),
        (int, "int"),
        (float, "float"),
        (str, "str"),
    )
    value = expr.value
    for python_type, type_name in literal_types:
        if isinstance(value, python_type):
            return self.ctx.get_type(type_name)

    self.warning(expr.location, f"Unknown literal {expr}")
    return UnknownType()
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
return self.look_up_variable(expr.name, expr) or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return expr.type.accept(self)
def visit_base_type(self, node: p.BaseType) -> Type:
return self.ctx.get_type(node.base)
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
def visit_frame_type(self, node: p.FrameType) -> Type: ...
def map_call_arguments(
    self, function: Function, call: p.CallExpr
) -> list[MappedArgument]:
    """Map call arguments to function parameters as defined in its signature

    This method maps positional-only, keyword-only and mixed parameter
    definitions with the arguments passed at the call site.
    Any mismatched, missing or unexpected argument is reported as a diagnostic.

    Args:
        function (Function): the function definition
        call (p.CallExpr): the call expression

    Returns:
        list[MappedArgument]: the list of mapped arguments
    """
    # Evaluate every argument up front: positional ones first, then keywords
    typed_positional: list[tuple[p.Expr, Type]] = []
    for node in call.arguments:
        typed_positional.append((node, self.evaluate(node)))
    typed_keywords: dict[str, tuple[p.Expr, Type]] = {}
    for key, node in call.keywords.items():
        typed_keywords[key] = (node, self.evaluate(node))

    # Names of the required parameters which have not received a value yet
    missing_positional: list[str] = [
        param.name for param in function.pos_args + function.args if param.required
    ]
    missing_keyword: list[str] = [
        param.name for param in function.kw_args if param.required
    ]
    bound_names: set[str] = set()
    result: list[MappedArgument] = []

    def bind(param: Function.Argument, node: p.Expr, node_type: Type) -> None:
        # Record that `param` receives the argument `node`
        for missing in (missing_positional, missing_keyword):
            if param.name in missing:
                missing.remove(param.name)
        bound_names.add(param.name)
        result.append(MappedArgument(expr=node, type=node_type, argument=param))

    def quoted_list(names: list[str]) -> str:
        # 'a' / 'a' and 'b' / 'a', 'b' and 'c'
        quoted = [f"'{name}'" for name in names]
        if len(quoted) < 2:
            return "".join(quoted)
        return ", ".join(quoted[:-1]) + " and " + quoted[-1]

    pending_pos_only: list[Function.Argument] = list(function.pos_args)
    pending_mixed: list[Function.Argument] = list(function.args)

    # TODO: handle *args and **kwargs sinks
    for node, node_type in typed_positional:
        # Positional-only parameters are filled before the mixed ones
        queue = pending_pos_only or pending_mixed
        if not queue:
            self.error(node.location, "Too many positional arguments")
            break
        bind(queue.pop(0), node, node_type)

    # Mixed parameters left unfilled can still be passed by keyword
    by_keyword: dict[str, Function.Argument] = {
        param.name: param for param in function.kw_args
    }
    for param in pending_mixed:
        by_keyword[param.name] = param

    for key, (node, node_type) in typed_keywords.items():
        if key in by_keyword:
            bind(by_keyword.pop(key), node, node_type)
        elif key in bound_names:
            self.error(node.location, f"Multiple values for argument '{key}'")
        else:
            self.error(node.location, f"Unknown keyword argument '{key}'")

    if missing_positional:
        plural: str = "s" if len(missing_positional) != 1 else ""
        args: str = quoted_list(missing_positional)
        self.error(
            call.location,
            f"Missing required positional argument{plural}: {args}",
        )
    if missing_keyword:
        plural = "s" if len(missing_keyword) != 1 else ""
        args = quoted_list(missing_keyword)
        self.error(
            call.location,
            f"Missing required keyword argument{plural}: {args}",
        )
    return result
@property
def diagnostics(self) -> list[Diagnostic]:
    """The diagnostics held by this checker's reporter."""
    reporter = self.reporter
    return reporter.diagnostics

View File

@@ -1,6 +1,5 @@
from dataclasses import dataclass
from enum import StrEnum
from pathlib import Path
from typing import Optional
from midas.ast.location import Location
@@ -14,12 +13,13 @@ class DiagnosticType(StrEnum):
@dataclass(frozen=True)
class Diagnostic:
file_path: Path
file_path: Optional[str]
location: Location
type: DiagnosticType
message: str
def __str__(self) -> str:
@property
def location_str(self) -> str:
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
end_loc: Optional[str] = ""
if (
@@ -27,7 +27,16 @@ class Diagnostic:
and self.location.end_col_offset is not None
):
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
loc: str = (
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
)
return f"{self.type} in {self.file_path} {loc}: {self.message}"
loc: str = ""
if self.file_path is not None:
loc += f" in {self.file_path}"
if end_loc is None:
loc += f" at {start_loc}"
else:
loc += f" from {start_loc} to {end_loc}"
return f"{self.type}{loc}"
def __str__(self) -> str:
return f"{self.location_str}: {self.message}"

206
midas/checker/midas.py Normal file
View File

@@ -0,0 +1,206 @@
import logging
from pathlib import Path
from typing import Optional
import midas.ast.midas as m
from midas.checker.builtins import define_builtins
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
ComplexType,
ExtensionType,
Function,
GenericType,
Type,
TypeVar,
UnknownType,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token
from midas.parser.midas import MidasParser
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
    """A resolver which evaluates Midas type definitions and builds a registry.

    Statements are visited in order. Type definitions and members are stored
    in the shared `TypesRegistry`; problems are reported through the reporter
    instead of being raised.
    """

    def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
        """Create the typer and load the bundled builtin definitions.

        Args:
            types (TypesRegistry): the registry receiving the defined types
            reporter (Reporter): the reporter from which per-file reporters
                are derived
        """
        self.logger: logging.Logger = logging.getLogger("MidasTyper")
        self.reporter: FileReporter = reporter.for_file(None)
        self.types: TypesRegistry = types
        # Type variables in scope for the statement currently being resolved
        self._local_variables: dict[str, TypeVar] = {}
        # Name of the type being defined, used to explain self references
        self._current_name: Optional[str] = None

        define_builtins(self.types)
        builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
        self.process(builtins_path.read_text(), str(builtins_path))

    def process(self, source: str, path: Optional[str]) -> None:
        """Lex, parse and resolve a Midas source.

        Parser errors are forwarded to the reporter, then the statements
        returned by the parser are resolved.

        Args:
            source (str): the Midas source code
            path (Optional[str]): the path attached to reported diagnostics
        """
        self.reporter = self.reporter.for_file(path)
        lexer: MidasLexer = MidasLexer(source)
        tokens: list[Token] = lexer.process()
        parser: MidasParser = MidasParser(tokens)
        stmts: list[m.Stmt] = parser.parse()
        for error in parser.errors:
            self.reporter.error(error.token.get_location(), error.message)
        self.resolve(stmts)

    def get_type(self, name: str) -> Type:
        """Get a type from its name

        Type variables of the statement being resolved take precedence over
        the types stored in the registry.

        Args:
            name (str): the name of the type

        Raises:
            NameError: if the type is not defined

        Returns:
            Type: the type
        """
        if name in self._local_variables:
            return self._local_variables[name]
        return self.types.get_type(name)

    def resolve(self, stmts: list[m.Stmt]) -> None:
        """Process a sequence of statements

        Args:
            stmts (list[m.Stmt]): the statements
        """
        for stmt in stmts:
            stmt.accept(self)

    def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
        """Define a named type, generic if the statement has type parameters.

        Args:
            stmt (m.TypeStmt): the type definition statement
        """
        name: str = stmt.name.lexeme
        self._current_name = name
        try:
            params: list[TypeVar] = self._resolve_type_params(stmt.params)
            body: Type = stmt.type.accept(self)
            defined: Type
            if params:
                defined = GenericType(name=name, params=params, body=body)
            else:
                defined = AliasType(name=name, type=body)
            self.types.define_type(name, defined)
        finally:
            # Per-statement state must never outlive the statement, even if
            # resolution fails with an exception
            self._local_variables.clear()
            self._current_name = None

    # Placeholder visitor: the body is only `...`, so it does nothing.
    def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        """Register the members of an extend block on their base type.

        An unknown base type is reported, after which the members are still
        passed to the registry.

        Args:
            stmt (m.ExtendStmt): the extend statement
        """
        try:
            self._resolve_type_params(stmt.params)
            base_name: str = stmt.name.lexeme
            try:
                self.get_type(base_name)
            except NameError:
                self.reporter.error(
                    stmt.name.get_location(), f"Unknown type '{base_name}'"
                )

            for member in stmt.members:
                member_type: Type = member.type.accept(self)
                self.types.define_member(
                    base_name,
                    member.name.lexeme,
                    member_type,
                    member.kind == m.MemberKind.METHOD,
                )
        finally:
            # The block's type parameters must not leak into later statements,
            # where they would shadow the types of the registry
            self._local_variables.clear()

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
        """Report predicate statements as unsupported."""
        self.reporter.warning(stmt.location, "PredicateStmt not yet supported")

    def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
        """Report logical expressions as unsupported."""
        self.reporter.warning(expr.location, "LogicalExpr not yet supported")

    def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
        """Report binary expressions as unsupported."""
        self.reporter.warning(expr.location, "BinaryExpr not yet supported")

    def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
        """Report unary expressions as unsupported."""
        self.reporter.warning(expr.location, "UnaryExpr not yet supported")

    def visit_get_expr(self, expr: m.GetExpr) -> None:
        """Report get expressions as unsupported."""
        self.reporter.warning(expr.location, "GetExpr not yet supported")

    def visit_variable_expr(self, expr: m.VariableExpr) -> None:
        """Report variable expressions as unsupported."""
        self.reporter.warning(expr.location, "VariableExpr not yet supported")

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
        """Visit the expression wrapped by the grouping."""
        return expr.expr.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
        """Report literal expressions as unsupported."""
        self.reporter.warning(expr.location, "LiteralExpr not yet supported")

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
        """Report wildcard expressions as unsupported."""
        self.reporter.warning(expr.location, "WildcardExpr not yet supported")

    def visit_named_type(self, type: m.NamedType) -> Type:
        """Resolve a type referenced by its name.

        Args:
            type (m.NamedType): the named type node

        Returns:
            Type: the resolved type, or `UnknownType` (with an error) if the
                name is not defined
        """
        name: str = type.name.lexeme
        try:
            return self.get_type(name)
        except NameError:
            msg: str = f"Undefined type {name}"
            if self._current_name == name:
                # The type refers to itself while it is being defined
                msg += ". Recursive types are not supported, use an extend block"
            self.reporter.error(type.name.get_location(), msg)
            return UnknownType()

    def visit_generic_type(self, type: m.GenericType) -> Type:
        """Apply type arguments to a generic type.

        Args:
            type (m.GenericType): the generic application node

        Returns:
            Type: the applied type, or `UnknownType` (with an error) if the
                registry cannot apply the arguments
        """
        type_: Type = type.type.accept(self)
        args: list[Type] = [arg.accept(self) for arg in type.args]
        try:
            return self.types.apply_generic(type_, args)
        except Exception as e:
            self.reporter.error(type.location, f"Cannot apply generic type: {e}")
            return UnknownType()

    def visit_constraint_type(self, type: m.ConstraintType) -> Type:
        """Visit a constrained type; constraints are not evaluated yet.

        The base type and the constraint are still visited so that their own
        diagnostics are reported.

        Args:
            type (m.ConstraintType): the constrained type node

        Returns:
            Type: always `UnknownType` for now
        """
        type.type.accept(self)
        type.constraint.accept(self)
        # TODO
        return UnknownType()

    def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
        """Build a complex type from its named members.

        Args:
            type (m.ComplexType): the complex type node

        Returns:
            ComplexType: the type mapping each member name to its type
        """
        return ComplexType(
            members={
                member.name.lexeme: member.type.accept(self) for member in type.members
            }
        )

    def visit_extension_type(self, type: m.ExtensionType) -> Type:
        """Build the extension of a base type with additional members.

        Args:
            type (m.ExtensionType): the extension type node

        Returns:
            Type: the extension type
        """
        return ExtensionType(
            base=type.base.accept(self),
            extension=self.visit_complex_type(type.extension),
        )

    def visit_function_type(self, type: m.FunctionType) -> Type:
        """Build a function signature from a function type node.

        Parameters are numbered across the positional-only, mixed and
        keyword-only groups, in this order. An unnamed parameter is named
        after its position.

        Args:
            type (m.FunctionType): the function type node

        Returns:
            Type: the function signature
        """
        n_pos_args: int = len(type.pos_args)
        n_args: int = len(type.args)

        def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
            return Function.Argument(
                pos=i,
                name=arg.name.lexeme if arg.name is not None else str(i),
                type=arg.type.accept(self),
                required=arg.required,
            )

        return Function(
            pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
            args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
            kw_args=[
                process_arg(arg, i + n_pos_args + n_args)
                for i, arg in enumerate(type.kw_args)
            ],
            returns=type.returns.accept(self),
        )

    def _resolve_type_params(self, params: list[m.TypeParam]) -> list[TypeVar]:
        """Create the type variables of a statement and bring them in scope.

        Each variable is registered as soon as it is created, so the bound of
        a parameter can refer to the parameters declared before it.

        Args:
            params (list[m.TypeParam]): the declared type parameters

        Returns:
            list[TypeVar]: the type variables, in declaration order
        """
        type_vars: list[TypeVar] = []
        for param in params:
            name: str = param.name.lexeme
            bound: Optional[Type] = None
            if param.bound is not None:
                bound = param.bound.accept(self)
            var = TypeVar(name=name, bound=bound)
            self._local_variables[name] = var
            type_vars.append(var)
        return type_vars

View File

@@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
# ast.In: "__in__",
# ast.NotIn: "__notin__",
}
# Maps each supported `ast` unary operator class to the name of the special
# method implementing it. `ast.Not` has no entry (it is commented out below).
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
    ast.Invert: "__invert__",
    # ast.Not: "",
    ast.UAdd: "__pos__",
    ast.USub: "__neg__",
}

859
midas/checker/python.py Normal file
View File

@@ -0,0 +1,859 @@
import ast
import logging
from dataclasses import dataclass
from typing import Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.environment import Environment
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter, Reporter
from midas.checker.resolver import Resolver
from midas.checker.types import (
Function,
OverloadedFunction,
Type,
UnitType,
UnknownType,
unfold_type,
)
from midas.parser.python import PythonParser
# An expression paired with the type evaluated for it
TypedExpr = tuple[p.Expr, Type]
class ReturnException(Exception):
    """Raised by a return statement to stop the evaluation of its block."""
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
    """A call-site argument matched with the parameter receiving it."""

    # The argument expression at the call site
    expr: p.Expr
    # The type evaluated for the argument expression
    type: Type
    # The parameter of the function signature the argument is bound to
    argument: Function.Argument
@dataclass(frozen=True, kw_only=True)
class OverloadCandidate:
    """An overload accepting the call arguments, with their mapping."""

    # The signature of the matching overload
    function: Function
    # The call arguments mapped to the parameters of `function`
    mapped: list[MappedArgument]
class PythonTyper(
p.Stmt.Visitor[None],
p.Expr.Visitor[Type],
p.MidasType.Visitor[Type],
):
"""A type checker which can use custom type definitions"""
def __init__(
self,
types: TypesRegistry,
reporter: Reporter,
):
self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.global_env: Environment = Environment()
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
self.env = self.global_env
self.locals = resolver.locals
self.judgements = []
self.check(stmts)
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression
Args:
expr (p.Expr): the expression to evaluate
Returns:
Type: the type of the given expression
"""
type: Type = expr.accept(self)
self.judgements.append((expr, type))
return type
def resolve_type_expr(self, expr: p.MidasType) -> Type:
return expr.accept(self)
def process_stmt(self, stmt: p.Stmt) -> None:
stmt.accept(self)
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
"""Evaluate a sequence of statements
Args:
block (list[p.Stmt]): the statements to evaluate
env (Environment): the environment in which to evaluate
Returns:
bool: whether a return statement is present in the block
"""
previous_env: Environment = self.env
self.env = env
returned: bool = False
for i, stmt in enumerate(block):
try:
self.process_stmt(stmt)
except ReturnException:
returned = True
if i < len(block) - 1:
self.reporter.warning(
block[i + 1].location, "Unreachable statement"
)
break
self.env = previous_env
return returned
def check(self, statements: list[p.Stmt]) -> None:
"""Type check a sequence of statements and returns diagnostics
Args:
statements (list[p.Stmt]): the statements to evaluate and check
"""
for stmt in statements:
self.process_stmt(stmt)
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
"""Look up a variable in the environment it was declared
Args:
name (str): the name of the variable
expr (p.Expr): the variable expression, used to lookup the scope distance
Returns:
Optional[Type]: the type of the variable, or None if it was not found
"""
distance: Optional[int] = self.locals.get(expr)
if distance is not None:
return self.env.get_at(distance, name)
return self.global_env.get(name)
def is_subtype(self, type1: Type, type2: Type) -> bool:
return self.types.is_subtype(type1, type2)
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
self.type_of(stmt.expr)
def visit_function(self, stmt: p.Function) -> None:
env: Environment = Environment(self.env)
pos_args: list[Function.Argument] = []
args: list[Function.Argument] = []
kw_args: list[Function.Argument] = []
def eval_arg_type(arg: p.Function.Argument) -> Type:
if arg.type is not None:
return self.resolve_type_expr(arg.type)
if arg.default is not None:
return self.type_of(arg.default)
return UnknownType()
pos: int = 0
for arg in stmt.posonlyargs:
pos_args.append(
Function.Argument(
pos=pos,
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
pos += 1
for arg in stmt.args:
args.append(
Function.Argument(
pos=pos,
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
pos += 1
for arg in stmt.kwonlyargs:
kw_args.append(
Function.Argument(
pos=pos, # not relevant
name=arg.name,
type=eval_arg_type(arg),
required=arg.default is None,
)
)
pos += 1
for arg in pos_args + args + kw_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
if stmt.returns is not None:
returns_hint = self.resolve_type_expr(stmt.returns)
# Early define to handle simple fully-typed recursion
inside_function: Function = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns_hint,
)
self.env.define(stmt.name, inside_function)
returned: bool = self.process_block(stmt.body, env)
inferred_return: Type = UnknownType()
if not returned:
env.return_types.append(UnitType())
return_types: list[Type] = self.types.reduce_types(env.return_types)
if len(return_types) == 1:
inferred_return = return_types[0]
elif len(return_types) > 1:
self.reporter.error(
stmt.location,
f"Mixed return types: {return_types}",
)
returns: Type = UnknownType()
if returns_hint is not None:
assert stmt.returns is not None
returns = returns_hint
if returns != inferred_return:
self.reporter.error(
stmt.returns.location,
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
)
else:
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
# TODO check not yet defined locally
type: Type = self.resolve_type_expr(stmt.type)
self.env.define(stmt.name, type)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
value_type: Type = self.type_of(stmt.value)
for target in stmt.targets:
self._assign(stmt.location, target, value_type)
def _assign(self, location: Location, target: p.Expr, value_type: Type):
match target:
case p.VariableExpr():
self._assign_var(location, target, value_type)
case p.GetExpr(object=object, name=name):
self._assign_attr(location, object, name, value_type)
case _:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
self.reporter.warning(
target.location, f"Unsupported assignment to {target}"
)
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
name: str = target.name
var_type: Optional[Type] = self.look_up_variable(name, target)
if var_type is None:
self.env.define(name, value_type)
else:
# S <: T
# Γ, x: T v: S
# x = v
if not self.is_subtype(value_type, var_type):
self.reporter.error(
location,
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
)
def _assign_attr(
self, location: Location, object: p.Expr, name: str, value_type: Type
):
object_type: Type = self.type_of(object)
member: Optional[Type] = self.types.lookup_member(object_type, name)
if member is None:
self.reporter.error(location, f"Unknown member '{name}' of {object_type}")
return
self.logger.debug(f"Member '{name}' of {object_type} has type {member}")
if not self.is_subtype(value_type, member):
self.reporter.error(
location,
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
self.env.return_types.append(type)
raise ReturnException()
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
# Not evaluated in sub-environment because assignments in the test leak out of the if
# For example:
# if (m := 1 + 1) < 2:
# ...
# print(m) # <- m is still defined
test_type: Type = self.type_of(stmt.test)
# TODO Allow subtypes or any type
if test_type != self.types.get_type("bool"):
self.reporter.error(
stmt.test.location, f"If test must be a boolean, got {test_type}"
)
env: Environment = Environment(self.env)
body_returned: bool = self.process_block(stmt.body, env)
else_returned: bool = self.process_block(stmt.orelse, env)
self.env.return_types.extend(env.return_types)
if body_returned and else_returned:
raise ReturnException()
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator}"
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator}"
)
return UnknownType()
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
def _visit_binary_expr(
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
) -> Type:
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method)
if operation is None:
self.reporter.error(
location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
return self._get_call_result(location, operation, [(right_expr, right)], {})
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
if method is None:
self.logger.warning(f"Unsupported operator {expr.operator}")
self.reporter.warning(
expr.location, f"Unsupported operator {expr.operator}"
)
return UnknownType()
operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None:
self.reporter.error(
expr.location,
f"Undefined operation {method} for {operand}",
)
return UnknownType()
return self._get_call_result(
expr.location, operation, [(expr.right, operand)], {}
)
def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee)
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
return self._get_call_result(
location=expr.location,
callee=callee,
positional=positional,
keywords=keywords,
)
def visit_get_expr(self, expr: p.GetExpr) -> Type:
object: Type = self.type_of(expr.object)
member: Optional[Type] = self.types.lookup_member(object, expr.name)
if member is None:
self.reporter.error(
expr.location, f"Unknown member '{expr.name}' of {object}"
)
return UnknownType()
self.logger.debug(f"Member '{expr.name}' of {object} has type {member}")
return member
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
match expr.value:
case bool(): # Must be before int
return self.types.get_type("bool")
case int():
return self.types.get_type("int")
case float():
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
type: Optional[Type] = self.look_up_variable(expr.name, expr)
if type is None:
self.logger.debug(f"Unknown variable {expr.name} in {self.env.flat_dict()}")
self.reporter.warning(expr.location, "Unknown variable")
return type or UnknownType()
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
left: Type = self.type_of(expr.left)
right: Type = self.type_of(expr.right)
if self.is_subtype(left, right):
return right
if self.is_subtype(right, left):
return left
self.reporter.error(
expr.location,
f"Incompatible operand types, {left=} and {right=}",
)
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return self.resolve_type_expr(expr.type)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = self.type_of(expr.test)
# TODO Allow subtypes or any type
if test_type != self.types.get_type("bool"):
self.reporter.error(
expr.test.location, f"If test must be a boolean, got {test_type}"
)
true_type: Type = self.type_of(expr.if_true)
false_type: Type = self.type_of(expr.if_false)
if self.is_subtype(true_type, false_type):
return false_type
if self.is_subtype(false_type, true_type):
return true_type
self.reporter.error(
expr.location,
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
)
return UnknownType()
def visit_list_expr(self, expr: p.ListExpr) -> Type:
list_type: Type = self.types.get_type("list")
item_types: list[Type] = [self.type_of(item) for item in expr.items]
item_types = self.types.reduce_types(item_types)
if len(item_types) == 0:
return list_type
if len(item_types) == 1:
item_type: Type = item_types[0]
return self.types.apply_generic(list_type, [item_type])
self.reporter.error(
expr.location,
f"Heterogeneous list items: {item_types}",
)
return self.types.apply_generic(list_type, [UnknownType()])
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
object: Type = self.type_of(expr.object)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None:
self.reporter.error(
expr.location,
f"Undefined method __getitem__ on {object}",
)
return UnknownType()
index: Type = self.type_of(expr.index)
return self._get_call_result(
expr.location, operation, [(expr.index, index)], {}
)
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
return self.types.get_type("slice")
def visit_base_type(self, node: p.BaseType) -> Type:
base: Type
try:
base = self.types.get_type(node.base)
except NameError:
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
return UnknownType()
if node.param is not None:
param: Type = self.resolve_type_expr(node.param)
return self.types.apply_generic(base, [param])
return base
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
self.reporter.warning(node.location, "ConstraintType not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> Type:
self.reporter.warning(node.location, "FrameColumn not yet supported")
return UnknownType()
def visit_frame_type(self, node: p.FrameType) -> Type:
self.reporter.warning(node.location, "FrameType not yet supported")
return UnknownType()
def _get_call_result(
self,
location: Location,
callee: Type,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
"""Get the result type of a function call
If the function has overloads, the function will try to resolve the
appropriate signature.
Argument types are matched to the defined parameters.
The function doesn't take the raw expression as a parameter to accomodate
for desugared calls such as for operators.
Args:
location (Location): the call location
callee (Type): the called function
positional (list[TypedExpr]): the list positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
Returns:
Type: the return type of the call, or `UnknownType` if either
the call is invalid or no overload matched the arguments uniquely
"""
match callee:
case Function() as function:
valid: bool
mapped: list[MappedArgument]
valid, mapped = self.map_call_arguments(
function, location, positional, keywords
)
valid = valid and self._are_arguments_valid(mapped)
if not valid:
return UnknownType()
return function.returns
case OverloadedFunction(overloads=overloads):
function = self._match_overload(
overloads, location, positional, keywords
)
if function is None:
return UnknownType()
return function.returns
case _:
self.reporter.error(location, f"{callee} is not callable")
return UnknownType()
def _are_arguments_valid(
self,
arguments: list[MappedArgument],
report_errors: bool = True,
) -> bool:
"""Check whether the passed argument types correspond to their matched parameter definitions
Args:
arguments (list[MappedArgument]): the list of argument/parameter pairs
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
bool: True if all arguments fit the matching parameter definitions, False otherwise
"""
valid: bool = True
for arg in arguments:
if not self.is_subtype(arg.type, arg.argument.type):
if report_errors:
self.reporter.error(
arg.expr.location,
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
)
valid = False
return valid
def _match_overload(
self,
overloads: list[Type],
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Function]:
"""Try and resolve the appropriate overload for the given arguments
Args:
overloads (list[Type]): the list of possible overloads
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keywords arguments
Returns:
Optional[Function]: the resolved function signature if it can be
determined unambigously, or `None`.
"""
candidates: list[OverloadCandidate] = []
for overload in overloads:
function: Type = unfold_type(overload)
if not isinstance(function, Function):
self.logger.error(
f"Overload is not a function: {overload} is {function}"
)
continue
valid, mapped = self.map_call_arguments(
function=function,
location=location,
positional=positional,
keywords=keywords,
report_errors=False,
)
if valid and self._are_arguments_valid(mapped, report_errors=False):
candidates.append(
OverloadCandidate(
function=function,
mapped=mapped,
)
)
pos_types: str = ", ".join(str(type) for _, type in positional)
kw_types: str = ", ".join(
f"{name}: {type}" for name, (_, type) in keywords.items()
)
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
n_candidates: int = len(candidates)
# Exactly 1 match -> return it
if n_candidates == 1:
return candidates[0].function
# No match -> invalid call
if n_candidates == 0:
overloads_str: str = ", ".join(map(str, overloads))
self.reporter.error(
location,
f"No matching overload in [{overloads_str}] {for_args}",
)
return None
# Multiple matches -> see if one <: all others (more specific)
for i1, c1 in enumerate(candidates):
mapped1: list[MappedArgument] = c1.mapped
best_match: bool = True
for i2, c2 in enumerate(candidates):
if i1 == i2:
continue
mapped2: list[MappedArgument] = c2.mapped
if not self._are_mapped_subtypes(mapped1, mapped2):
best_match = False
break
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
if best_match:
return c1.function
candidates_str: str = ", ".join(
str(candidate.function) for candidate in candidates
)
self.reporter.error(
location,
f"Multiple matching overloads {for_args}: {candidates_str}",
)
return None
def map_call_arguments(
self,
function: Function,
location: Location,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
report_errors: bool = True,
) -> tuple[bool, list[MappedArgument]]:
"""Map call arguments to a function's parameters as defined in its signature
This method maps positional-only, keyword-only and mixed parameter definitions
with the arguments passed at the call site
Any mismatched, missing or unexpected argument is reported as a diagnostic,
unless `report_errors` is set to `False`
Args:
function (Function): the function definition
location (Location): the call location
positional (list[TypedExpr]): the list of positional arguments
keywords (dict[str, TypedExpr]): the map of keyword arguments
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
Returns:
tuple[bool, list[MappedArgument]]: a boolean reporting whether
the call is valid and the list of mapped arguments
"""
set_args: set[str] = set()
required_positional: list[str] = [
arg.name for arg in function.pos_args + function.args if arg.required
]
required_keyword: list[str] = [
arg.name for arg in function.kw_args if arg.required
]
mapped: list[MappedArgument] = []
pos_params: list[Function.Argument] = list(function.pos_args)
mixed_params: list[Function.Argument] = list(function.args)
kw_params: dict[str, Function.Argument] = {
arg.name: arg for arg in function.kw_args
}
valid_call: bool = True
# TODO: handle *args and **kwargs sinks
for arg in positional:
param: Function.Argument
if len(pos_params) != 0:
param = pos_params.pop(0)
elif len(mixed_params) != 0:
param = mixed_params.pop(0)
else:
if report_errors:
self.reporter.error(
arg[0].location, "Too many positional arguments"
)
valid_call = False
break
name: str = param.name
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
kw_params.update({arg.name: arg for arg in mixed_params})
for name, arg in keywords.items():
param: Function.Argument
if name not in kw_params:
if report_errors:
if name in set_args:
self.reporter.error(
arg[0].location, f"Multiple values for argument '{name}'"
)
else:
self.reporter.error(
arg[0].location, f"Unknown keyword argument '{name}'"
)
valid_call = False
continue
param = kw_params.pop(name)
if name in required_positional:
required_positional.remove(name)
if name in required_keyword:
required_keyword.remove(name)
set_args.add(name)
mapped.append(
MappedArgument(
expr=arg[0],
type=arg[1],
argument=param,
)
)
def join_args(args: list[str]) -> str:
args = list(map(lambda a: f"'{a}'", args))
if len(args) == 0:
return ""
if len(args) == 1:
return args[0]
return ", ".join(args[:-1]) + " and " + args[-1]
if len(required_positional) != 0:
plural: str = "" if len(required_positional) == 1 else "s"
args: str = join_args(required_positional)
if report_errors:
self.reporter.error(
location,
f"Missing required positional argument{plural}: {args}",
)
valid_call = False
if len(required_keyword) != 0:
plural: str = "" if len(required_keyword) == 1 else "s"
args: str = join_args(required_keyword)
if report_errors:
self.reporter.error(
location,
f"Missing required keyword argument{plural}: {args}",
)
valid_call = False
return valid_call, mapped
def _are_mapped_subtypes(
    self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
) -> bool:
    """Tell whether an argument mapping is parameter-wise a subtype of another

    Parameters of the two mappings are paired through the argument
    expression they were mapped from. The result is `True` only if every
    parameter type in `mapped1` is a subtype of the paired parameter type in
    `mapped2`.

    This is used to check whether a given overload is a more specific
    function (i.e. a subtype) than another one.

    Args:
        mapped1 (list[MappedArgument]): the first argument mappings (subtype)
        mapped2 (list[MappedArgument]): the second argument mappings (supertype)

    Raises:
        KeyError: if an expression of `mapped2` has no entry in `mapped1`

    Returns:
        bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
    """
    # Parameter type chosen by the first mapping for each argument expression
    sub_types: dict[p.Expr, Type] = {
        entry.expr: entry.argument.type for entry in mapped1
    }
    return all(
        self.is_subtype(sub_types[entry.expr], entry.argument.type)
        for entry in mapped2
    )

347
midas/checker/registry.py Normal file
View File

@@ -0,0 +1,347 @@
import logging
from typing import Optional
from midas.checker.builtins import BUILTIN_SUBTYPES
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ComplexType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
Type,
TypeVar,
UnknownType,
substitute_typevars,
)
class TypesRegistry:
def __init__(self) -> None:
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
self._types: dict[str, Type] = {}
self._members: dict[str, dict[str, Type]] = {}
def get_type(self, name: str) -> Type:
"""Get a type from its name
Args:
name (str): the name of the type
Raises:
NameError: if the type is not defined
Returns:
Type: the type
"""
if name in self._types:
return self._types[name]
raise NameError(f"Undefined type {name}")
def define_type(self, name: str, type: Type) -> Type:
"""Define a type in the registry
Args:
name (str): the name of the type
type (Type): the type to define
Raises:
ValueError: if a type is already defined with that name
Returns:
Type: the defined type
"""
if name in self._types:
raise ValueError(f"Type {name} already defined")
self._types[name] = type
return type
def define_member(
self, type_name: str, member_name: str, member_type: Type, is_method: bool
):
members: dict[str, Type] = self._members.setdefault(type_name, {})
if member_name in members:
if not is_method:
self.logger.error(
f"Member '{member_name}' already defined for type {type_name}"
)
return
current: Type = members[member_name]
combined: Type
match current:
case OverloadedFunction(overloads=overloads):
combined = OverloadedFunction(overloads=overloads + [member_type])
case _:
combined = OverloadedFunction(overloads=[current, member_type])
members[member_name] = combined
else:
members[member_name] = member_type
def is_subtype(self, type1: Type, type2: Type) -> bool:
"""Check whether `type1` is a subtype of `type2`
For more details on the rules checked here, see TAPL Chap. 15-16-17
Args:
type1 (Type): the potential subtype
type2 (Type): the potential supertype
Returns:
bool: whether `type1` is a subtype of `type2`
"""
if type1 == type2:
return True
match (type1, type2):
case (_, TopType()):
return True
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
case (BaseType(name=name1), BaseType(name=name2)):
return name1 in BUILTIN_SUBTYPES.get(name2, set())
case (ComplexType(properties=props1), ComplexType(properties=props2)):
for k, t in props2.items():
if k not in props1:
return False
if not self.is_subtype(props1[k], t):
return False
return True
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
return False
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another
Args:
func1 (Function): the potential function subtype
func2 (Function): the potential function supertype
Returns:
bool: whether `func1` is a subtype of `func2`
"""
if not self.is_subtype(func1.returns, func2.returns):
return False
pos1: list[Function.Argument] = func1.pos_args
mixed1: list[Function.Argument] = func1.args
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
pos2: list[Function.Argument] = func2.pos_args
mixed2: list[Function.Argument] = func2.args
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
if not self.is_subtype(sub.type, sup.type):
return False
if not sup.required and sub.required:
return False
return True
for arg1 in pos1:
arg2: Function.Argument
if arg1.pos < len(pos2):
arg2 = pos2[arg1.pos]
elif arg1.pos in mixed_by_pos:
arg2 = mixed_by_pos[arg1.pos]
elif not arg1.required:
continue
else:
return False
if not is_arg_subtype(arg2, arg1):
return False
for name, arg1 in kw1.items():
arg2: Function.Argument
if name in kw2:
arg2 = kw2[name]
elif name in mixed_by_name:
arg2 = mixed_by_name[name]
elif not arg1.required:
continue
else:
return False
if not is_arg_subtype(arg2, arg1):
return False
for arg1 in mixed1:
pos_arg2: Optional[Function.Argument] = None
kw_arg2: Optional[Function.Argument] = None
if arg1.name in kw2:
kw_arg2 = kw2[arg1.name]
elif arg1.name in mixed_by_name:
kw_arg2 = mixed_by_name[arg1.name]
if arg1.pos < len(pos2):
pos_arg2 = pos2[arg1.pos]
elif arg1.pos in mixed_by_pos:
pos_arg2 = mixed_by_pos[arg1.pos]
# No match in func2 and arg is required
if pos_arg2 is None and kw_arg2 is None and arg1.required:
return False
# Matching keyword argument
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
return False
# Matching positional argument
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
return False
mixed_positions: set[int] = {a.pos for a in mixed1}
mixed_names: set[str] = {a.name for a in mixed1}
for arg2 in pos2:
if not arg2.required:
continue
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
return False
for name, arg2 in kw2.items():
if not arg2.required:
continue
if name not in kw1 and name not in mixed_names:
return False
for arg2 in mixed2:
if arg2.required:
continue
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
if not pos_match or not kw_match:
return False
return True
def apply_generic(self, type: Type, args: list[Type]) -> Type:
match type:
case AliasType(name=name, type=base):
return AliasType(name=name, type=self.apply_generic(base, args))
case GenericType(name=name, params=type_vars, body=body):
n_args: int = len(args)
n_type_vars: int = len(type_vars)
if n_args < n_type_vars:
raise ValueError(
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
)
if n_args > n_type_vars:
raise ValueError(
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
)
substitutions: dict[str, Type] = {}
for arg, type_var in zip(args, type_vars):
if type_var.bound is not None and not self.is_subtype(
arg, type_var.bound
):
raise ValueError(
f"Type argument {arg} is not a subtype of {type_var.bound}"
)
substitutions[type_var.name] = arg
return AppliedType(
name=name,
args=args,
body=substitute_typevars(body, substitutions),
)
case _:
raise ValueError(f"{type} is not a generic type")
def reduce_types(self, types: list[Type]) -> list[Type]:
"""Reduce a list of types to remove subtypes and only keep the highest types
Args:
types (list[Type]): the types to reduce
Returns:
list[Type]: the reduced list of types
"""
reduced: bool = True
keep: list[int] = list(range(len(types)))
while reduced:
reduced = False
for i, i1 in enumerate(keep):
type1: Type = types[i1]
for i2 in keep[i + 1 :]:
type2 = types[i2]
if self.is_subtype(type1, type2):
keep.remove(i1)
elif self.is_subtype(type2, type1):
keep.remove(i2)
else:
continue
reduced = True
break
return [types[i] for i in keep]
def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
match type:
case BaseType(name=name):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name]
return None
case AliasType(name=name, type=base):
if name in self._members:
if member_name in self._members[name]:
return self._members[name][member_name]
return self.lookup_member(base, member_name)
case AppliedType(name=name, body=body, args=args):
generic: Type = self.get_type(name)
if not isinstance(generic, GenericType):
raise ValueError("AppliedType not derived from a GenericType")
substitutions = {
type_var.name: arg for arg, type_var in zip(args, generic.params)
}
if name in self._members:
if member_name in self._members[name]:
member_type: Type = self._members[name][member_name]
return substitute_typevars(member_type, substitutions)
member_type2: Optional[Type] = self.lookup_member(body, member_name)
if member_type2 is not None:
member_type2 = substitute_typevars(member_type2, substitutions)
return member_type2
case ComplexType(members=members):
if member_name in members:
return members[member_name]
self.logger.debug(f"No member '{member_name}' in {type}")
return None
case ExtensionType(base=base, extension=ComplexType(members=members)):
if member_name in members:
return members[member_name]
self.logger.debug(
f"No member '{member_name}' on {type}, looking up in base"
)
return self.lookup_member(base, member_name)
case UnknownType():
return UnknownType()
case _:
self.logger.debug(f"Can't get member on {type}")
return None

63
midas/checker/reporter.py Normal file
View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Optional
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic, DiagnosticType
class Reporter:
    """Collects the diagnostics reported for every file"""

    def __init__(self):
        # All diagnostics, in reporting order
        self.diagnostics: list[Diagnostic] = []

    def report(
        self,
        path: Optional[str],
        type: DiagnosticType,
        location: Location,
        message: str,
    ):
        """Record a new diagnostic

        Args:
            path (Optional[str]): the path of the concerned file, if any
            type (DiagnosticType): the kind of diagnostic
            location (Location): the location the diagnostic refers to
            message (str): the diagnostic message
        """
        diagnostic = Diagnostic(
            file_path=path,
            location=location,
            type=type,
            message=message,
        )
        self.diagnostics.append(diagnostic)

    def for_file(self, path: Optional[str]) -> FileReporter:
        """Create a reporter bound to `path` which records into this one"""
        return FileReporter(base_reporter=self, path=path)
class FileReporter:
    """Reporter bound to one file, forwarding to a shared base reporter"""

    def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
        self.base_reporter: Reporter = base_reporter
        self.path: Optional[str] = path

    def for_file(self, path: Optional[str]) -> FileReporter:
        """Create a sibling reporter for another file, sharing the same base"""
        return FileReporter(self.base_reporter, path)

    def report(self, type: DiagnosticType, location: Location, message: str):
        """Forward a diagnostic to the base reporter, tagged with this file's path"""
        self.base_reporter.report(self.path, type, location, message)

    def error(self, location: Location, message: str):
        """Report an error diagnostic"""
        self.report(DiagnosticType.ERROR, location, message)

    def warning(self, location: Location, message: str):
        """Report a warning diagnostic"""
        self.report(DiagnosticType.WARNING, location, message)

    def info(self, location: Location, message: str):
        """Report an informational diagnostic"""
        self.report(DiagnosticType.INFO, location, message)

View File

@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
def __init__(self):
self.locals: dict[p.Expr, int] = {}
self.scopes: list[dict[str, bool]] = []
self.scopes: list[dict[str, bool]] = [{}]
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
"""Resolve the given statements or expressions"""
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.locals[expr] = i
return
def is_defined(self, name: str) -> bool:
    """Tell whether `name` appears in any of the currently tracked scopes"""
    return any(name in scope for scope in self.scopes)
def resolve_function(self, function: p.Function) -> None:
"""Resolve a function definition
@@ -112,8 +118,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
for target in stmt.targets:
match target:
case p.VariableExpr(name=name):
self.resolve_local(target, name)
# TODO: declare if not found
if not self.is_defined(name):
self.declare(name)
self.define(name)
target.accept(self)
case p.GetExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")
@@ -121,6 +132,24 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
if stmt.value is not None:
self.resolve(stmt.value)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
    """Resolve an `if` statement: its test, then each branch in its own scope"""
    # The test is resolved in the current scope, not in a sub-environment,
    # because assignments in the test leak out of the if. For example:
    #     if (m := 1 + 1) < 2:
    #         ...
    #     print(m)  # <- m is still defined
    self.resolve(stmt.test)

    # Body first, then else
    for branch in (stmt.body, stmt.orelse):
        self.begin_scope()
        self.resolve(*branch)
        self.end_scope()
def visit_binary_expr(self, expr: p.BinaryExpr) -> None:
self.resolve(expr.left)
self.resolve(expr.right)
@@ -156,9 +185,26 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
self.resolve(expr.left)
self.resolve(expr.right)
def visit_set_expr(self, expr: p.SetExpr) -> None:
self.resolve(expr.value)
self.resolve(expr.object)
def visit_cast_expr(self, expr: p.CastExpr) -> None:
self.resolve(expr.expr)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self.resolve(expr.test)
self.resolve(expr.if_true)
self.resolve(expr.if_false)
def visit_list_expr(self, expr: p.ListExpr) -> None:
for item in expr.items:
self.resolve(item)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
self.resolve(expr.object)
self.resolve(expr.index)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
    """Resolve the lower bound, upper bound and step of a slice, skipping omitted parts"""
    for part in (expr.lower, expr.upper, expr.step):
        if part is not None:
            self.resolve(part)

View File

@@ -1,42 +1,233 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
@dataclass(frozen=True, kw_only=True)
class TopType:
def __str__(self) -> str:
return "Any"
@dataclass(frozen=True, kw_only=True)
class BaseType:
name: str
def __str__(self) -> str:
return self.name
@dataclass(frozen=True, kw_only=True)
class SimpleType:
class AliasType:
name: str
base: BaseType | SimpleType
type: Type
def __str__(self) -> str:
return self.name
@dataclass(frozen=True, kw_only=True)
class UnknownType:
pass
def __str__(self) -> str:
return "<Unknown>"
@dataclass(frozen=True, kw_only=True)
class UnitType:
pass
def __str__(self) -> str:
return "None"
@dataclass(frozen=True, kw_only=True)
class Function:
name: str
pos_args: list[Argument]
args: list[Argument]
kw_args: list[Argument]
returns: Type
def __str__(self) -> str:
args: list[str] = []
if len(self.pos_args) != 0:
args += list(map(str, self.pos_args))
if len(self.args) + len(self.kw_args) != 0:
args.append("/")
if len(self.args) != 0:
args += list(map(str, self.args))
if len(self.kw_args) != 0:
if len(args) != 0:
args.append("*")
args += list(map(str, self.kw_args))
return f"({', '.join(args)}) -> {self.returns}"
@dataclass(frozen=True, kw_only=True)
class Argument:
pos: int
name: str
type: Type
required: bool
def __str__(self) -> str:
opt: str = "" if self.required else "?"
return f"{self.name}: {self.type}{opt}"
Type = BaseType | SimpleType | UnknownType | UnitType | Function
@dataclass(frozen=True, kw_only=True)
class OverloadedFunction:
overloads: list[Type]
def __str__(self) -> str:
return "<overloaded function>"
@dataclass(frozen=True, kw_only=True)
class ComplexType:
members: dict[str, Type]
def __str__(self) -> str:
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
return f"{{{', '.join(props)}}}"
@dataclass(frozen=True, kw_only=True)
class ExtensionType:
base: Type
extension: ComplexType
def __str__(self) -> str:
return f"{self.base} & {self.extension}"
@dataclass(frozen=True, kw_only=True)
class TypeVar:
name: str
bound: Optional[Type]
def __str__(self) -> str:
if self.bound is not None:
return f"{self.name} <: {self.bound}"
return self.name
@dataclass(frozen=True, kw_only=True)
class GenericType:
name: str
params: list[TypeVar]
body: Type
def __str__(self) -> str:
return f"{self.name}[{', '.join(map(str, self.params))}]"
@dataclass(frozen=True, kw_only=True)
class AppliedType:
name: str
args: list[Type]
body: Type
def __str__(self) -> str:
return f"{self.name}[{', '.join(map(str, self.args))}]"
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
pos=arg.pos,
name=arg.name,
type=substitute_typevars(arg.type, substitutions),
required=arg.required,
)
match type:
case BaseType(name=name) if name in substitutions:
return substitutions[name]
case BaseType():
return type
case AliasType(name=name, type=type2):
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
case Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
):
return Function(
pos_args=list(map(sub_argument, pos_args)),
args=list(map(sub_argument, args)),
kw_args=list(map(sub_argument, kw_args)),
returns=substitute_typevars(returns, substitutions),
)
case OverloadedFunction(overloads=overloads):
return OverloadedFunction(
overloads=[
substitute_typevars(overload, substitutions)
for overload in overloads
]
)
case ComplexType(members=members):
members2: dict[str, Type] = {
name: substitute_typevars(prop, substitutions)
for name, prop in members.items()
}
return ComplexType(members=members2)
case ExtensionType(base=base, extension=ComplexType(members=members)):
return ExtensionType(
base=substitute_typevars(base, substitutions),
extension=ComplexType(
members={
name: substitute_typevars(prop, substitutions)
for name, prop in members.items()
}
),
)
case AppliedType(name=name, args=args, body=body):
return AppliedType(
name=name,
args=[substitute_typevars(arg, substitutions) for arg in args],
body=substitute_typevars(body, substitutions),
)
case TypeVar(name=name):
if name in substitutions:
return substitutions[name]
raise ValueError(f"Missing TypeVar substitution for {name}")
case UnknownType() | UnitType():
return type
case _:
raise NotImplementedError(f"Unsupported type {type}")
def unfold_type(type: Type) -> Type:
    """Follow a chain of aliases down to the first type which is not an alias"""
    current: Type = type
    while isinstance(current, AliasType):
        current = current.type
    return current
Type = (
TopType
| BaseType
| AliasType
| UnknownType
| UnitType
| Function
| OverloadedFunction
| ComplexType
| ExtensionType
| TypeVar
| GenericType
| AppliedType
)

41
midas/cli/ansi.py Normal file
View File

@@ -0,0 +1,41 @@
class Ansi:
    """ANSI escape sequences used to style terminal output"""

    # Prefix shared by every sequence below
    CTRL = "\x1b["

    # Text attributes
    RESET = f"{CTRL}0m"
    BOLD = f"{CTRL}1m"
    DIM = f"{CTRL}2m"
    ITALIC = f"{CTRL}3m"
    UNDERLINE = f"{CTRL}4m"

    # Color offsets, to be passed to `FG` / `BG`
    BLACK = 0
    RED = 1
    GREEN = 2
    YELLOW = 3
    BLUE = 4
    MAGENTA = 5
    CYAN = 6
    WHITE = 7

    # Bright variants: same colors, shifted by 60
    BRIGHT_BLACK = 60 + BLACK
    BRIGHT_RED = 60 + RED
    BRIGHT_GREEN = 60 + GREEN
    BRIGHT_YELLOW = 60 + YELLOW
    BRIGHT_BLUE = 60 + BLUE
    BRIGHT_MAGENTA = 60 + MAGENTA
    BRIGHT_CYAN = 60 + CYAN
    BRIGHT_WHITE = 60 + WHITE

    @classmethod
    def FG(cls, col: int) -> str:
        """Sequence setting the foreground to the given color offset"""
        code = 30 + col
        return f"{cls.CTRL}{code}m"

    @classmethod
    def BG(cls, col: int) -> str:
        """Sequence setting the background to the given color offset"""
        code = 40 + col
        return f"{cls.CTRL}{code}m"

    @classmethod
    def FG_RGB(cls, r: int, g: int, b: int) -> str:
        """Sequence setting the foreground to an RGB color"""
        return cls.CTRL + f"38;2;{r};{g};{b}m"

    @classmethod
    def BG_RGB(cls, r: int, g: int, b: int) -> str:
        """Sequence setting the background to an RGB color"""
        return cls.CTRL + f"48;2;{r};{g};{b}m"

View File

@@ -53,5 +53,6 @@ span {
&.keyword {
color: rgb(211, 72, 9);
pointer-events: none;
}
}

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Optional, Protocol, TextIO, TypeVar
@@ -8,6 +9,7 @@ import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.diagnostic import Diagnostic
from midas.lexer.token import Token
H = TypeVar("H", bound="Highlighter", contravariant=True)
@@ -22,6 +24,15 @@ class Locatable(Protocol):
def location(self) -> Optional[Location]: ...
@dataclass(frozen=True)
class LocatableToken:
    """Wrapper exposing a token's location as a `location` property"""

    # The wrapped token
    token: Token

    @property
    def location(self) -> Location:
        wrapped: Token = self.token
        return wrapped.get_location()
class Highlighter(ABC):
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
EXTRA_CSS_PATH: Optional[Path] = None
@@ -148,6 +159,8 @@ class PythonHighlighter(
self.wrap(stmt, "function")
for arg in stmt.posonlyargs + stmt.args + stmt.kwonlyargs:
self._highlight_function_argument(arg)
for body_stmt in stmt.body:
body_stmt.accept(self)
def _highlight_function_argument(self, arg: p.Function.Argument) -> None:
self.wrap(arg, "argument")
@@ -157,9 +170,23 @@ class PythonHighlighter(
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
stmt.type.accept(self)
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None: ...
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
    """Highlight the targets of an assignment, then its value if there is one"""
    for target in stmt.targets:
        target.accept(self)
    # NOTE(review): the resolver's assignment handling treats `value` as
    # optional, so it is guarded here as well instead of failing on `None`
    if stmt.value is not None:
        stmt.value.accept(self)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None: ...
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
self.wrap(stmt, "return")
if stmt.value is not None:
stmt.value.accept(self)
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
self.wrap(stmt, "if")
stmt.test.accept(self)
for body_stmt in stmt.body:
body_stmt.accept(self)
for else_stmt in stmt.orelse:
else_stmt.accept(self)
def visit_binary_expr(self, expr: p.BinaryExpr) -> None: ...
@@ -167,7 +194,13 @@ class PythonHighlighter(
def visit_unary_expr(self, expr: p.UnaryExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None: ...
def visit_call_expr(self, expr: p.CallExpr) -> None:
self.wrap(expr, "call")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: p.GetExpr) -> None: ...
@@ -177,59 +210,55 @@ class PythonHighlighter(
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_list_expr(self, expr: p.ListExpr) -> None:
for item in expr.items:
item.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self)
expr.index.accept(self)
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
if expr.lower is not None:
expr.lower.accept(self)
if expr.upper is not None:
expr.upper.accept(self)
if expr.step is not None:
expr.step.accept(self)
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
def highlight(self, node: Highlightable[MidasHighlighter]):
node.accept(self)
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
self.wrap(stmt, "simple-type")
if stmt.template is not None:
stmt.template.accept(self)
stmt.base.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
self.wrap(stmt, "complex-type")
if stmt.template is not None:
stmt.template.accept(self)
for prop in stmt.properties:
prop.accept(self)
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
self.wrap(stmt, "property")
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
self.wrap(stmt, "type-stmt")
self.wrap(LocatableToken(stmt.name), "type-name")
stmt.type.accept(self)
def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
self.wrap(stmt, "member")
stmt.type.accept(self)
if stmt.constraint is not None:
self.wrap(stmt.constraint, "constraint")
stmt.constraint.accept(self)
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
self.wrap(stmt, "extend")
stmt.type.accept(self)
for op in stmt.operations:
op.accept(self)
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
self.wrap(stmt, "op")
stmt.operand.accept(self)
stmt.result.accept(self)
for member in stmt.members:
member.accept(self)
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self)
stmt.condition.accept(self)
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
self.wrap(expr, "simple-type-expr")
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
expr.left.accept(self)
@@ -258,14 +287,35 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
self.wrap(expr, "template")
expr.type.accept(self)
def visit_named_type(self, type: m.NamedType) -> None:
self.wrap(type, "named-type")
def visit_type_expr(self, expr: m.TypeExpr) -> None:
self.wrap(expr, "type")
if expr.template is not None:
expr.template.accept(self)
def visit_generic_type(self, type: m.GenericType) -> None:
self.wrap(type, "generic-type")
type.type.accept(self)
for arg in type.args:
arg.accept(self)
def visit_constraint_type(self, type: m.ConstraintType) -> None:
self.wrap(type, "constraint-type")
type.type.accept(self)
type.constraint.accept(self)
def visit_complex_type(self, type: m.ComplexType) -> None:
self.wrap(type, "complex-type")
for member in type.members:
member.accept(self)
def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function")
for arg in type.pos_args + type.args + type.kw_args:
arg.type.accept(self)
type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
self.wrap(type, "extension")
type.base.accept(self)
type.extension.accept(self)
class DiagnosticsHighlighter(Highlighter):

View File

@@ -1,4 +1,6 @@
span {
--opacity: 0.4;
&.error {
--col: 255, 0, 0;
}
@@ -11,9 +13,14 @@ span {
&.with-msg {
position: relative;
&:not(:hover) {
.message {
display: none;
}
&:hover:not(:has(.with-msg:hover)) {
.message {
display: none;
display: inline-block;
}
}

View File

@@ -5,12 +5,11 @@ span {
font-style: italic;
}
&.simple-type {
--col: 108, 233, 108;
}
&.named-type,
&.generic-type,
&.constraint-type,
&.complex-type {
--col: 233, 206, 108;
--col: 150, 150, 150;
}
&.constraint {
@@ -33,10 +32,6 @@ span {
--col: 193, 108, 233;
}
&.simple-type-expr {
--col: 150, 150, 150;
}
&.logical-expr,
&.binary-expr,
&.unary-expr,
@@ -48,7 +43,9 @@ span {
--col: 163, 117, 71;
}
&.type {
&.type-name,
&.op-name,
&.predicate-name {
--col: 200, 200, 200;
font-weight: bold;
}

View File

@@ -1,56 +1,150 @@
import ast
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TextIO
from typing import Optional, TextIO, get_args
import click
import midas.ast.midas as m
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
from midas.checker.checker import Checker
from midas.checker.diagnostic import Diagnostic
from midas.cli.highlighter import DiagnosticsHighlighter, Highlighter, MidasHighlighter, PythonHighlighter
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic, DiagnosticType
from midas.checker.types import Type
from midas.cli.ansi import Ansi
from midas.cli.highlighter import (
DiagnosticsHighlighter,
Highlighter,
LocatableToken,
MidasHighlighter,
PythonHighlighter,
)
from midas.lexer.midas import MidasLexer
from midas.lexer.token import Token, TokenType
from midas.parser.midas import MidasParser
from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver
from midas.utils import UniversalJSONDumper
@click.group()
def midas():
click.echo("Welcome to Midas!")
pass
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
    """Pretty-print a diagnostic, showing some context if possible

    If the diagnostic concerns a specific part of one line, the line is shown
    with the affected part highlighted. The message is clearly printed under the
    line with an underline further indicating the target expression.

    If multiple lines are concerned, or if the concerned line is not available
    in `lines` (for example because the source file could not be read), no
    context is shown, only the plain diagnostic

    Args:
        lines (list[str]): source code lines
        diagnostic (Diagnostic): the diagnostic to print
        indent (int, optional): the number of spaces added before the target line to indent it from the location header. Defaults to 4.
    """
    loc: Location = diagnostic.location
    # Context needs a single line which actually exists in `lines`; callers
    # may pass an empty list when the file is unknown or unreadable
    if loc.lineno != loc.end_lineno or not 1 <= loc.lineno <= len(lines):
        print(diagnostic)
        return

    start_offset: int = loc.col_offset
    # Without an end column, highlight a single character
    end_offset: int = loc.end_col_offset or (start_offset + 1)

    line: str = lines[loc.lineno - 1]
    before: str = line[:start_offset]
    after: str = line[end_offset:]

    color: int = {
        DiagnosticType.ERROR: Ansi.RED,
        DiagnosticType.WARNING: Ansi.YELLOW,
        DiagnosticType.INFO: Ansi.CYAN,
    }.get(diagnostic.type, Ansi.WHITE)

    subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
    # Underline aligned with the subject, followed by the message
    cursor: str = (
        " " * start_offset
        + Ansi.FG(color)
        + "~" * (end_offset - start_offset)
        + "> "
        + diagnostic.message
        + Ansi.RESET
    )
    indent_str: str = " " * indent
    print(diagnostic.location_str + ":")
    print(indent_str + before + subject + after)
    print(indent_str + cursor)
    print()
@midas.command()
@click.option("-l", "--highlight", type=click.File("w"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("-v", "--verbose", is_flag=True)
@click.option("-j", "--show-judgements", is_flag=True)
@click.argument("file", type=click.File("r"))
def compile(highlight: Optional[TextIO], file: TextIO):
logging.basicConfig(level=logging.DEBUG)
def compile(
highlight: Optional[TextIO],
types: tuple[TextIO],
verbose: bool,
show_judgements: bool,
file: TextIO,
):
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
source: str = file.read()
tree: ast.Module = ast.parse(source, filename=file.name)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
checker = Checker(resolver.locals, file_path=Path(file.name).resolve())
diagnostics: list[Diagnostic] = checker.check(stmts)
for diagnostic in diagnostics:
print(diagnostic)
source_path: Path = Path(file.name).resolve()
print(
json.dumps(
UniversalJSONDumper.dump(
checker.global_env, [("Environment", "_children")]
),
indent=4,
checker = TypeChecker()
for types_file in types:
checker.import_midas(Path(types_file.name).resolve())
checker.type_check_source(source, str(source_path))
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
lines: list[str] = source.split("\n")
files: dict[Optional[str], list[str]] = {None: []}
if show_judgements:
for expr, type in checker.python_typer.judgements:
print(f"Judged that {expr} at {expr.location} is of type {type}")
diagnostics.append(
Diagnostic(
file_path=str(source_path),
location=expr.location,
type=DiagnosticType.INFO,
message=f"Type: {type}",
)
)
for diagnostic in diagnostics:
filename: Optional[str] = diagnostic.file_path
if filename is not None and filename not in files:
path: Path = Path(filename)
if path.exists() and path.is_file():
files[filename] = path.read_text().split("\n")
else:
files[filename] = []
lines: list[str] = files[filename]
print_diagnostic(lines, diagnostic)
if verbose:
print(
json.dumps(
UniversalJSONDumper.dump(
checker.python_typer.global_env,
[("Environment", "_children")],
lambda obj: isinstance(obj, get_args(Type)),
),
indent=4,
)
)
)
if highlight is not None:
highlighter = DiagnosticsHighlighter(source)
highlighter.highlight(diagnostics)
@@ -134,14 +228,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
for err in parser.errors:
print(err.get_report())
@dataclass(frozen=True)
class LocatableToken:
    """Immutable wrapper exposing a token's location as a `location` attribute"""

    # The wrapped token
    token: Token

    @property
    def location(self) -> Location:
        """Location returned by the wrapped token's `get_location()`"""
        wrapped: Token = self.token
        return wrapped.get_location()
for stmt in stmts:
highlighter.highlight(stmt)
for token in tokens:
@@ -168,5 +254,21 @@ def highlight(output: TextIO, file: TextIO):
highlighter.dump(output)
@midas.command()
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.argument("file", type=click.File("r"))
def format(output: TextIO, file: TextIO):
    """Pretty-print a Midas source file

    The source is lexed and parsed, every parser error report is printed,
    then each parsed statement is written to `output` followed by a newline.

    Args:
        output (TextIO): stream receiving the printed statements
            (`-o` / `--output`, defaults to `-`)
        file (TextIO): stream holding the Midas source to format
    """
    text: str = file.read()
    formatter = MidasPrinter()

    token_stream: list[Token] = MidasLexer(text, file=file.name).process()
    midas_parser = MidasParser(token_stream)
    parsed: list[m.Stmt] = midas_parser.parse()

    # Parsing problems are reported, but formatting still goes on with
    # whatever statements the parser returned
    for problem in midas_parser.errors:
        print(problem.get_report())

    output.writelines(formatter.print(node) + "\n" for node in parsed)
# Invoke the `midas` CLI entry point when this module is run as a script.
if __name__ == "__main__":
midas()

View File

@@ -40,8 +40,8 @@ class MidasLexer(Lexer):
self.add_token(TokenType.AND)
case "?":
self.add_token(TokenType.QMARK)
# case ",":
# self.add_token(TokenType.COMMA)
case ",":
self.add_token(TokenType.COMMA)
case "_" if not self.is_identifier_char(self.peek_next(), start=False):
self.add_token(TokenType.UNDERSCORE)
case "-" if self.match(">"):
@@ -50,12 +50,14 @@ class MidasLexer(Lexer):
# self.add_token(TokenType.PLUS)
case "-":
self.add_token(TokenType.MINUS)
# case "*":
# self.add_token(TokenType.STAR)
case "*":
self.add_token(TokenType.STAR)
case "/" if self.match("/"):
self.scan_comment()
case "/" if self.match("*"):
self.scan_comment_multiline()
case "/":
self.add_token(TokenType.SLASH)
case "\n":
self.add_token(TokenType.NEWLINE)
case " " | "\r" | "\t":

View File

@@ -17,7 +17,7 @@ class TokenType(Enum):
LEFT_BRACE = auto()
RIGHT_BRACE = auto()
COLON = auto()
# COMMA = auto()
COMMA = auto()
UNDERSCORE = auto()
ARROW = auto()
AND = auto()
@@ -27,8 +27,8 @@ class TokenType(Enum):
# Operators
# PLUS = auto()
MINUS = auto()
# STAR = auto()
# SLASH = auto()
STAR = auto()
SLASH = auto()
GREATER = auto()
GREATER_EQUAL = auto()
LESS = auto()
@@ -46,10 +46,12 @@ class TokenType(Enum):
# Keywords
TYPE = auto()
OP = auto()
PREDICATE = auto()
EXTEND = auto()
WHERE = auto()
PROP = auto()
DEF = auto()
FUNC = auto()
# Misc
COMMENT = auto()
@@ -60,13 +62,15 @@ class TokenType(Enum):
# Maps each keyword lexeme to the token type that represents it.
# Note that the lexeme `fn` maps to `TokenType.FUNC`.
KEYWORDS: dict[str, TokenType] = {
"type": TokenType.TYPE,
"op": TokenType.OP,
"predicate": TokenType.PREDICATE,
"extend": TokenType.EXTEND,
"where": TokenType.WHERE,
"true": TokenType.TRUE,
"false": TokenType.FALSE,
"none": TokenType.NONE,
"prop": TokenType.PROP,
"def": TokenType.DEF,
"fn": TokenType.FUNC,
}

View File

@@ -3,26 +3,30 @@ from typing import Optional
from midas.ast.location import Location
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
ExtensionType,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
OpStmt,
MemberKind,
MemberStmt,
NamedType,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
from midas.lexer.token import Token, TokenType
from midas.lexer.token import KEYWORDS, Token, TokenType
from midas.parser.base import Parser
from midas.parser.errors import ParsingError
@@ -32,9 +36,10 @@ class MidasParser(Parser):
# Keyword token types treated as synchronization boundaries.
# NOTE(review): presumably `synchronize()` skips tokens until one of these
# is found after a parse error - confirm against the base Parser.
SYNC_BOUNDARY: set[TokenType] = {
TokenType.TYPE,
TokenType.OP,
TokenType.EXTEND,
TokenType.PREDICATE,
TokenType.PROP,
TokenType.FUNC,
}
def parse(self) -> list[Stmt]:
@@ -81,7 +86,7 @@ class MidasParser(Parser):
self.synchronize()
return None
def type_declaration(self) -> SimpleTypeStmt | ComplexTypeStmt:
def type_declaration(self) -> TypeStmt:
"""Parse a type declaration
A type declaration can either be a simple type alias or a new complex type.
@@ -106,51 +111,52 @@ class MidasParser(Parser):
TypeStmt: the parsed type declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params()
if self.match(TokenType.LEFT_PAREN):
base: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Unclosed base type parenthesis")
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return SimpleTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
base=base,
constraint=constraint,
)
else:
properties: list[PropertyStmt] = self.type_properties()
return ComplexTypeStmt(
location=keyword.location_to(self.previous()),
name=name,
template=template,
properties=properties,
)
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
def template_expr(self) -> TemplateExpr:
"""Parse a generic template expression
type: Type = self.type_expr()
A template is written `[TypeExpr]`
return TypeStmt(
location=keyword.location_to(self.previous()),
name=name,
params=params,
type=type,
)
def type_params(self) -> list[TypeParam]:
"""Parse a list of type parameters
Type parameters are a comma-separated list of type variables wrapped in brackets.
Each type variable is either a simple variable, or a bounded variable written `S <: T`
Returns:
TemplateExpr: the parsed template expression
list[TypeParam]: the list of type parameters, if any, or an empty list
"""
left: Token = self.consume(
TokenType.LEFT_BRACKET, "Missing '[' before template expression"
)
type: TypeExpr = self.type_expr()
right: Token = self.consume(
TokenType.RIGHT_BRACKET, "Missing ']' after template expression"
)
return TemplateExpr(location=left.location_to(right), type=type)
if not self.match(TokenType.LEFT_BRACKET):
return []
def type_expr(self) -> TypeExpr:
params: list[TypeParam] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
name: Token = self.consume_identifier("Expected type variable")
bound: Optional[Type] = None
if self.match(TokenType.LESS):
self.consume(TokenType.COLON, "Expected ':' after '<'")
bound = self.type_expr()
params.append(
TypeParam(
location=name.location_to(self.previous()),
name=name,
bound=bound,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
return params
def type_expr(self) -> Type:
"""Parse a type expression
A type is an identifier, optionally followed by a template expression.
@@ -159,30 +165,96 @@ class MidasParser(Parser):
Returns:
TypeExpr: the parsed type expression
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
template: Optional[TemplateExpr] = None
base: Type
if self.match(TokenType.FUNC):
base = self.function()
else:
base = self.constraint_type()
if self.match(TokenType.AND):
extension: ComplexType = self.complex_type()
return ExtensionType(
location=Location.span(base.location, extension.location),
base=base,
extension=extension,
)
return base
def constraint_type(self) -> Type:
    """Parse a base type optionally followed by a `where` constraint

    Returns:
        Type: the base type alone, or a ConstraintType wrapping it
            when a `where` clause follows
    """
    base: Type = self.base_type()
    if not self.match(TokenType.WHERE):
        return base

    condition: Expr = self.constraint()
    span: Location = Location.span(base.location, condition.location)
    return ConstraintType(location=span, type=base, constraint=condition)
def base_type(self) -> Type:
    """Parse a grouped type, a complex type or a generic type

    A grouped type is any type expression wrapped in parentheses, a complex
    type starts with '{', and anything else is parsed as a generic type.

    Returns:
        Type: the parsed type
    """
    if self.match(TokenType.LEFT_PAREN):
        inner: Type = self.type_expr()
        self.consume(TokenType.RIGHT_PAREN, "Unclosed parenthesis")
        return inner

    starts_body: bool = self.check(TokenType.LEFT_BRACE)
    parse_rest = self.complex_type if starts_body else self.generic_type
    return parse_rest()
def generic_type(self) -> Type:
type: Type = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
template = self.template_expr()
optional: bool = self.match(TokenType.QMARK)
return TypeExpr(
location=name.location_to(self.previous()),
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
type=type,
args=args,
)
return type
def type_args(self) -> list[Type]:
    """Parse a bracketed, comma-separated list of type arguments

    Returns:
        list[Type]: the parsed type arguments, possibly empty
    """
    self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")

    parsed: list[Type] = []
    while not (self.is_at_end() or self.check(TokenType.RIGHT_BRACKET)):
        parsed.append(self.type_expr())
        if self.match(TokenType.COMMA):
            continue
        # No separator: the list must end here
        break

    self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
    return parsed
def named_type(self) -> Type:
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
name=name,
template=template,
optional=optional,
)
def simple_type_expr(self) -> SimpleTypeExpr:
"""Parse a simple type expression
def complex_type(self) -> ComplexType:
"""Parse a type definition body
A simple type is just an identifier optionally followed by a '?'
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
Returns:
SimpleTypeExpr: the parsed simple type expression
ComplexType: the parsed complex type
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
optional: bool = self.match(TokenType.QMARK)
return SimpleTypeExpr(
location=name.location_to(self.previous()), name=name, optional=optional
left: Token = self.consume(
TokenType.LEFT_BRACE, "Expected '{' to start type body"
)
members: list[MemberStmt] = []
# TODO: add keyword to differentiate properties and methods,
# and allow multiple methods with the same name but not properties
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
member: MemberStmt = self.member_stmt()
# if member.name.lexeme in names:
# raise self.error(member.name, "Duplicate property")
# names.add(member.name.lexeme)
members.append(member)
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return ComplexType(
location=left.location_to(right),
members=members,
)
def constraint(self) -> Expr:
@@ -269,9 +341,7 @@ class MidasParser(Parser):
"""
expr: Expr = self.primary()
while self.match(TokenType.DOT):
name: Token = self.consume(
TokenType.IDENTIFIER, "Expected property name after '.'"
)
name: Token = self.consume_identifier("Expected property name after '.'")
location: Location = Location.span(expr.location, name.get_location())
expr = GetExpr(location=location, expr=expr, name=name)
return expr
@@ -295,7 +365,7 @@ class MidasParser(Parser):
if self.match(TokenType.NUMBER):
return LiteralExpr(location=token.get_location(), value=token.value)
if self.match(TokenType.IDENTIFIER):
if self.match_identifier():
return VariableExpr(location=token.get_location(), name=token)
if self.match(TokenType.UNDERSCORE):
@@ -308,89 +378,70 @@ class MidasParser(Parser):
raise self.error(self.peek(), "Expected expression")
def type_properties(self) -> list[PropertyStmt]:
"""Parse a type definition body
def consume_identifier(self, message: str = "Expected identifier") -> Token:
    """Consume the current token as an identifier or fail

    Args:
        message (str): the error message used when the current token
            is not accepted by `match_identifier()`

    Returns:
        Token: the matched token
    """
    if self.match_identifier():
        return self.previous()
    raise self.error(self.peek(), message)
A type definition body is a set of whitespace-separated
property statements enclosed in curly braces
def match_identifier(self) -> bool:
    """Match a token usable as an identifier

    Both plain identifiers and every keyword token type are accepted.

    Returns:
        bool: the result of `match()` for these token types
    """
    accepted = (TokenType.IDENTIFIER, *KEYWORDS.values())
    return self.match(*accepted)
def check_identifier(self) -> bool:
    """Check whether the current token is usable as an identifier

    Both plain identifiers and every keyword token type are accepted.

    Returns:
        bool: True if `check()` succeeds for one of these token types
    """
    candidates = (TokenType.IDENTIFIER, *KEYWORDS.values())
    return any(self.check(kind) for kind in candidates)
def member_stmt(self) -> MemberStmt:
"""Parse a member statement
A type member statement is written `prop name: Type` or `def name: Type`
Returns:
list[PropertyStmt]: the parsed type properties
MemberStmt: the parsed member statement
"""
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start type body")
properties: list[PropertyStmt] = []
names: set[str] = set()
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
prop: PropertyStmt = self.property_stmt()
if prop.name.lexeme in names:
raise self.error(prop.name, "Duplicate property")
names.add(prop.name.lexeme)
properties.append(prop)
self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
return properties
kind: MemberKind
if self.match(TokenType.PROP):
kind = MemberKind.PROPERTY
elif self.match(TokenType.DEF):
kind = MemberKind.METHOD
else:
raise self.error(self.peek(), "Expected 'prop' or 'def'")
def property_stmt(self) -> PropertyStmt:
"""Parse a property statement
name: Token = self.consume_identifier("Expected member name")
self.consume(TokenType.COLON, "Expected ':' after member name")
A type property statement is written `name: Type` or `name: Type where Condition`
Returns:
PropertyStmt: the parsed property statement
"""
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
self.consume(TokenType.COLON, "Expected ':' after property name")
type: TypeExpr = self.type_expr()
constraint: Optional[Expr] = None
if self.match(TokenType.WHERE):
constraint = self.constraint()
return PropertyStmt(
type: Type = self.type_expr()
return MemberStmt(
location=name.location_to(self.previous()),
name=name,
type=type,
constraint=constraint,
kind=kind,
)
def extend_declaration(self) -> ExtendStmt:
"""Parse an extension definition
An extension is written `extend Type { operations }`
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
Returns:
ExtendStmt: the parsed extension statement
"""
keyword: Token = self.previous()
type: TypeExpr = self.type_expr()
name: Token = self.consume_identifier("Expected type name")
params: list[TypeParam] = self.type_params()
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
operations: list[OpStmt] = []
members: list[MemberStmt] = []
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
operations.append(self.op_declaration())
members.append(self.member_stmt())
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
location: Location = keyword.location_to(self.previous())
return ExtendStmt(location=location, type=type, operations=operations)
def op_declaration(self) -> OpStmt:
"""Parse an operation definition
An operation is written `op name(Type) -> Type`
Returns:
OpStmt: the parsed operation statement
"""
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
operand: TypeExpr = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
self.consume(TokenType.ARROW, "Expected '->' before result type")
result: TypeExpr = self.type_expr()
return OpStmt(
location=keyword.location_to(self.previous()),
return ExtendStmt(
location=location,
name=name,
operand=operand,
result=result,
params=params,
members=members,
)
def predicate_declaration(self) -> PredicateStmt:
@@ -402,11 +453,11 @@ class MidasParser(Parser):
PredicateStmt: the parsed predicate declaration statement
"""
keyword: Token = self.previous()
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
name: Token = self.consume_identifier("Expected predicate name")
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
subject: Token = self.consume_identifier("Expected subject name")
self.consume(TokenType.COLON, "Expected ':' after subject name")
type: TypeExpr = self.type_expr()
type: Type = self.type_expr()
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
condition: Expr = self.constraint()
@@ -417,3 +468,72 @@ class MidasParser(Parser):
type=type,
condition=condition,
)
def function(self) -> FunctionType:
    """Parse a function type, starting at its parameter list

    The parameter list is split in up to three sections: arguments written
    before a `/` separator are positional, arguments written after a `*`
    separator are keyword arguments, and the remaining ones are mixed
    arguments. Keyword arguments must be written `name: Type`, other
    arguments may omit the name. A `?` after an argument type marks the
    argument as not required. The list is followed by `->` and the result type.

    An error is reported, but not raised, for each unnamed mixed argument.

    Returns:
        FunctionType: the parsed function type
    """
    opening: Token = self.consume(
        TokenType.LEFT_PAREN, "Expected '(' before function parameters"
    )

    positional: list[FunctionType.Argument] = []
    mixed: list[FunctionType.Argument] = []
    keyword: list[FunctionType.Argument] = []
    mixed_starts: list[Token] = []

    # 0: no separator seen yet, 1: after '/', 2: after '*'
    stage: int = 0
    while not (self.is_at_end() or self.check(TokenType.RIGHT_PAREN)):
        if stage == 0 and self.match(TokenType.SLASH):
            # Everything collected so far turns out to be positional
            positional, mixed = mixed, []
            mixed_starts = []
            stage = 1
        elif stage in (0, 1) and self.match(TokenType.STAR):
            stage = 2
        else:
            keyword_only: bool = stage == 2
            if not keyword_only:
                # Record first token of mixed argument for errors if unnamed
                mixed_starts.append(self.peek())

            label: Optional[Token] = None
            if keyword_only:
                label = self.consume_identifier("Expected keyword argument name")
                self.consume(TokenType.COLON, "Expected ':' after argument name")
            elif self.check_identifier() and self.check_next(TokenType.COLON):
                label = self.advance()
                # Skip the ':' that was just checked
                self.advance()

            arg_type: Type = self.type_expr()
            is_optional: bool = self.match(TokenType.QMARK)
            parsed = FunctionType.Argument(
                location=None,
                name=label,
                type=arg_type,
                required=not is_optional,
            )
            (keyword if keyword_only else mixed).append(parsed)

        if not self.match(TokenType.COMMA):
            break

    for candidate, start in zip(mixed, mixed_starts):
        if candidate.name is None:
            # Not raised because we can keep parsing
            self.error(start, "Unnamed mixed argument")

    self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
    self.consume(TokenType.ARROW, "Expected '->' before result type")
    result: Type = self.type_expr()

    return FunctionType(
        location=opening.location_to(self.previous()),
        pos_args=positional,
        args=mixed,
        kw_args=keyword,
        returns=result,
    )

View File

@@ -16,11 +16,16 @@ from midas.ast.python import (
FrameType,
Function,
GetExpr,
IfStmt,
ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
SliceExpr,
Stmt,
SubscriptExpr,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -82,6 +87,12 @@ class PythonParser:
value=self.parse_expr(value) if value is not None else None,
)
case ast.If():
return self.parse_if(node)
case ast.Pass():
return None
case _:
print(f"Unsupported statement: {ast.unparse(node)}")
return None
@@ -147,6 +158,30 @@ class PythonParser:
),
)
def parse_if(self, node: ast.If) -> IfStmt:
    """Convert an `ast.If` node into an IfStmt

    Both branches are parsed statement by statement with `parse_stmt()`,
    which may yield a single statement, several statements, or None.

    Args:
        node (ast.If): the Python AST node

    Returns:
        IfStmt: the converted statement
    """

    def parse_block(nodes: list[ast.stmt]) -> list[Stmt]:
        # Flatten the results of parse_stmt(), dropping None
        collected: list[Stmt] = []
        for child in nodes:
            parsed = self.parse_stmt(child)
            if isinstance(parsed, Stmt):
                collected.append(parsed)
            elif parsed is not None:
                collected.extend(parsed)
        return collected

    then_branch: list[Stmt] = parse_block(node.body)
    else_branch: list[Stmt] = parse_block(node.orelse)

    return IfStmt(
        location=Location.from_ast(node),
        test=self.parse_expr(node.test),
        body=then_branch,
        orelse=else_branch,
    )
def parse_function(self, node: ast.FunctionDef) -> Function:
loc: Location = Location.from_ast(node)
match node:
@@ -282,6 +317,13 @@ class PythonParser:
constraint=right_expr,
)
case ast.Constant(value=None):
return BaseType(
location=loc,
base="None",
param=None,
)
case _:
raise UnsupportedSyntaxError(type_expr)
@@ -361,6 +403,9 @@ class PythonParser:
case ast.Call():
return self.parse_call(node)
case ast.IfExp():
return self.parse_ternary(node)
case ast.Constant(value=value):
return LiteralExpr(location=location, value=value)
@@ -374,6 +419,27 @@ class PythonParser:
case ast.Name(id=name):
return VariableExpr(location=location, name=name)
case ast.List(elts=items):
return ListExpr(
location=location,
items=[self.parse_expr(item) for item in items],
)
case ast.Subscript(value=value, slice=index):
return SubscriptExpr(
location=location,
object=self.parse_expr(value),
index=self.parse_expr(index),
)
case ast.Slice(lower=lower, upper=upper, step=step):
return SliceExpr(
location=location,
lower=self.parse_expr(lower) if lower is not None else None,
upper=self.parse_expr(upper) if upper is not None else None,
step=self.parse_expr(step) if step is not None else None,
)
case _:
raise UnsupportedSyntaxError(node)
@@ -450,3 +516,11 @@ class PythonParser:
if arg.arg is not None # Should always be True, type checker happy
},
)
def parse_ternary(self, node: ast.IfExp) -> TernaryExpr:
    """Convert an `ast.IfExp` node into a TernaryExpr

    Args:
        node (ast.IfExp): the Python AST node

    Returns:
        TernaryExpr: the converted expression
    """
    location = Location.from_ast(node)
    condition = self.parse_expr(node.test)
    when_true = self.parse_expr(node.body)
    when_false = self.parse_expr(node.orelse)
    return TernaryExpr(
        location=location,
        test=condition,
        if_true=when_true,
        if_false=when_false,
    )

View File

@@ -1,39 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from midas.checker.types import BaseType, Type
if TYPE_CHECKING:
from midas.resolver.midas import MidasResolver
def basic_op(ctx: MidasResolver, type: Type, op: str):
    """Define an operation whose operands and result all share one type

    Args:
        ctx (MidasResolver): the resolver in which the operation is defined
        type (Type): the type of both operands and of the result
        op (str): the operation name
    """
    signature = {"left": type, "operator": op, "right": type, "result": type}
    ctx.define_operation(**signature)
def define_builtins(ctx: MidasResolver):
    """Define the builtin types and their basic operations

    Args:
        ctx (MidasResolver): the resolver receiving the definitions
    """
    # All types are defined first, in this order, before any operation
    builtin_types = {
        name: ctx.define_type(name, BaseType(name=name))
        for name in ("bool", "int", "float", "str")
    }

    operations = (
        (
            "int",
            (
                "__add__",
                "__sub__",
                "__mul__",
                "__pow__",
                "__mod__",
                "__and__",
                "__or__",
                "__xor__",
            ),
        ),
        ("float", ("__add__", "__sub__", "__mul__", "__truediv__")),
        ("str", ("__add__",)),
    )
    for type_name, op_names in operations:
        for op_name in op_names:
            basic_op(ctx, builtin_types[type_name], op_name)

View File

@@ -1,153 +0,0 @@
from typing import Optional
import midas.ast.midas as m
from midas.checker.types import BaseType, SimpleType, Type
from midas.resolver.builtin import define_builtins
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[Type]):
    """Visitor walking Midas statements to fill a registry of types and operations

    Builtin types and operations are defined when the resolver is created.
    """

    def __init__(self) -> None:
        self._types: dict[str, Type] = {}
        self._operations: dict[tuple[Type, str, Type], Type] = {}
        define_builtins(self)

    def get_type(self, name: str) -> Type:
        """Look up a type by name

        Args:
            name (str): the name of the type

        Raises:
            NameError: if no type is defined under that name

        Returns:
            Type: the matching type
        """
        found: Optional[Type] = self._types.get(name)
        if found is not None:
            return found
        raise NameError(f"Undefined type {name}")

    def get_operation_result(
        self, left: Type, operator: str, right: Type
    ) -> Optional[Type]:
        """Look up the result type of an operation

        Args:
            left (Type): the type of the left operand
            operator (str): the operation name
            right (Type): the type of the right operand

        Returns:
            Optional[Type]: the result type, or None if the operation is not defined
        """
        return self._operations.get((left, operator, right))

    def define_type(self, name: str, type: Type) -> Type:
        """Add a type to the registry

        Args:
            name (str): the name of the type
            type (Type): the type to define

        Raises:
            ValueError: if the name is already taken by another type

        Returns:
            Type: the type that was just defined
        """
        if name in self._types:
            raise ValueError(f"Type {name} already defined")
        self._types[name] = type
        return type

    def define_operation(self, left: Type, operator: str, right: Type, result: Type):
        """Add an operation to the registry

        Args:
            left (Type): the type of the left operand
            operator (str): the operation name
            right (Type): the type of the right operand
            result (Type): the result type

        Raises:
            ValueError: if the same operation is already defined for these operands
        """
        key: tuple[Type, str, Type] = (left, operator, right)
        if key in self._operations:
            raise ValueError(
                f"Operation {operator} already defined between {left} and {right}"
            )
        self._operations[key] = result

    def resolve(self, stmts: list[m.Stmt]):
        """Visit each statement in order

        Args:
            stmts (list[m.Stmt]): the statements
        """
        for stmt in stmts:
            stmt.accept(self)

    def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
        # TODO generics, optional, constraint
        base: Type = self.get_type(stmt.base.name.lexeme)
        if not isinstance(base, (BaseType, SimpleType)):
            raise TypeError(f"Invalid base {base} for simple type")

        alias = SimpleType(name=stmt.name.lexeme, base=base)
        self.define_type(alias.name, alias)

    def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None: ...

    def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...

    def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
        # The extended type is the left operand of every declared operation
        target: Type = stmt.type.accept(self)
        for operation in stmt.operations:
            operand: Type = operation.operand.accept(self)
            outcome: Type = operation.result.accept(self)
            self.define_operation(
                left=target,
                operator=operation.name.lexeme,
                right=operand,
                result=outcome,
            )

    def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...

    def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...

    def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> Type:
        return self.get_type(expr.name.lexeme)

    def visit_logical_expr(self, expr: m.LogicalExpr) -> Type: ...

    def visit_binary_expr(self, expr: m.BinaryExpr) -> Type: ...

    def visit_unary_expr(self, expr: m.UnaryExpr) -> Type: ...

    def visit_get_expr(self, expr: m.GetExpr) -> Type: ...

    def visit_variable_expr(self, expr: m.VariableExpr) -> Type: ...

    def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
        inner: m.Expr = expr.expr
        return inner.accept(self)

    def visit_literal_expr(self, expr: m.LiteralExpr) -> Type: ...

    def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type: ...

    def visit_template_expr(self, expr: m.TemplateExpr) -> Type: ...

    def visit_type_expr(self, expr: m.TypeExpr) -> Type:
        return self.get_type(expr.name.lexeme)

View File

@@ -1,18 +1,27 @@
from typing import Any, Optional
from typing import Any, Callable, Optional
AllowRepeat = Callable[[object], bool]
class UniversalJSONDumper:
@classmethod
def dump(
cls, obj: Any, include_keys: Optional[list[str | tuple[str, str]]] = None
cls,
obj: Any,
include_keys: Optional[list[str | tuple[str, str]]] = None,
allow_repeat: Optional[AllowRepeat] = None,
) -> Any:
if include_keys is None:
include_keys = []
return cls._dump(obj, include_keys, [])
return cls._dump(obj, include_keys, allow_repeat, [])
@classmethod
def _dump(
cls, obj: Any, include_keys: list[str | tuple[str, str]], visited: list[Any]
cls,
obj: Any,
include_keys: list[str | tuple[str, str]],
allow_repeat: Optional[AllowRepeat],
visited: list[Any],
) -> Any:
if obj in visited:
return None
@@ -20,17 +29,22 @@ class UniversalJSONDumper:
case str() | int() | float() | None:
return obj
case list() | set() | tuple():
return [cls._dump(child, include_keys, visited) for child in obj]
return [
cls._dump(child, include_keys, allow_repeat, visited)
for child in obj
]
case dict():
return {
str(k): cls._dump(v, include_keys, visited) for k, v in obj.items()
str(k): cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.items()
}
case object():
visited.append(obj)
if allow_repeat is None or not allow_repeat(obj):
visited.append(obj)
return {
"_type": obj.__class__.__name__,
} | {
k: cls._dump(v, include_keys, visited)
k: cls._dump(v, include_keys, allow_repeat, visited)
for k, v in obj.__dict__.items()
if not k.startswith("_")
or k in include_keys

View File

@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
Equality ::= Comparison (EqualityOp Comparison)*
Constraint ::= Equality ("&" Equality)*
SimpleType ::= Identifier "?"?
Template ::= "[" Type "]"
Type ::= Identifier Template? "?"?
TemplateParam ::= Identifier ("<:" Type)?
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
TypeProperty ::= Identifier ":" Type
ComplexType ::= "{" TypeProperty* "}"
NamedType ::= Identifier
TypeParams ::= "[" (Type ("," Type)*)? "]"
GenericType ::= NamedType TypeParams?
GroupedType ::= "(" Type ")"
BaseType ::= GroupedType | ComplexType | GenericType
ConstraintType ::= BaseType ("where" Constraint)?
Type ::= ConstraintType
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
ComplexTypeBody ::= "{" TypeProperty* "}"
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
ExtendBody ::= "{" OpDefinition* "}"
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
TypeStatement ::= "type" Identifier Template? "=" Type
ExtendStatement ::= "extend" Type ExtendBody
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint

View File

@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
{[`constraint` 'equality'*"&"]}
```
#let simple-type = ```
{[`simple-type` 'identifier' <!, "?">]}
#let template-param = ```
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
```
#let template = ```
{[`template` "[" 'type' "]"]}
```
#let type = ```
{[`type` 'identifier' <!, 'template'> <!, "?">]}
{[`template` "[" <!, 'template-param'*","> "]"]}
```
#let type-property = ```
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
{[`type-property` 'identifier' ":" 'type']}
```
#let type-body = ```
{[`type-body` "{" <!, 'type-property'*!> "}"]}
#let complex-type = ```
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
```
#let named-type = ```
{[`named-type` 'identifier']}
```
#let type-params = ```
{[`type-params` "[" <!, 'type'*","> "]"]}
```
#let generic-type = ```
{[`generic-type` 'named-type' <!, 'type-params'>]}
```
#let grouped-type = ```
{[`grouped-type` "(" 'type' ")"]}
```
#let base-type = ```
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
```
#let constraint-type = ```
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
```
#let type = ```
{[`type` 'constraint-type']}
```
#let type-statement = ```
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
```
#let op-definition = ```
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
comparison: comparison,
equality: equality,
constraint: constraint,
simple-type: simple-type,
template-param: template-param,
template: template,
type: type,
type-property: type-property,
type-body: type-body,
complex-type: complex-type,
named-type: named-type,
type-params: type-params,
generic-type: generic-type,
grouped-type: grouped-type,
base-type: base-type,
constraint-type: constraint-type,
type: type,
type-statement: type-statement,
op-definition: op-definition,
extend-statement: extend-statement,
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
#let inline = (
"grouping",
"value",
"template-param",
"template",
"simple-type",
"type-property",
"type-body",
"complex-type",
"type-params",
"named-type",
"grouped-type",
"generic-type",
"base-type",
"constraint-type",
"op-definition",
"type-statement",
"extend-statement",

View File

@@ -29,7 +29,7 @@ class Tester(ABC):
def _list_tests(self) -> list[Path]: ...
def run_all_tests(self) -> bool:
paths: list[Path] = self._list_tests()
paths: list[Path] = sorted(self._list_tests())
return self.run_tests(paths)
def run_tests(self, tests: list[Path]) -> bool:
@@ -40,7 +40,7 @@ class Tester(ABC):
print(rule)
for i, test in enumerate(tests):
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
success: bool = self._run_test(test)
if success:
successes += 1
@@ -78,7 +78,7 @@ class Tester(ABC):
def _exec_case(self, path: Path) -> CaseResult: ...
def update_all_tests(self):
paths: list[Path] = self._list_tests()
paths: list[Path] = sorted(self._list_tests())
return self.update_tests(paths)
def update_tests(self, tests: list[Path]):
@@ -141,3 +141,9 @@ class Tester(ABC):
success = tester.run_tests(args.FILE)
if not success:
sys.exit(1)
case None:
print("No subcommand provided. Available subcommands: run, update")
sys.exit(1)
case _:
print(f"Unknown subcommand '{args.subcommand}'")
sys.exit(1)

View File

@@ -1,3 +1,19 @@
{
"diagnostics": []
"diagnostics": [
{
"type": "Warning",
"location": {
"start": [
6,
4
],
"end": [
13,
5
]
},
"message": "FrameType not yet supported"
}
],
"judgments": []
}

View File

@@ -12,7 +12,7 @@
13
]
},
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
"message": "Cannot assign str to variable 'c' of type int"
},
{
"type": "Error",
@@ -26,21 +26,166 @@
9
]
},
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
"message": "Undefined operation __add__ between bool and bool"
}
],
"judgments": [
{
"location": {
"from": "L1:9",
"to": "L1:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"type": "Error",
"location": {
"start": [
11,
0
],
"end": [
11,
12
]
"from": "L2:9",
"to": "L2:10"
},
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
"expr": {
"_type": "LiteralExpr",
"value": 4
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:4",
"to": "L4:5"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:8",
"to": "L4:9"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L4:4",
"to": "L4:9"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:4",
"to": "L6:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "invalid"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L8:4",
"to": "L8:8"
},
"expr": {
"_type": "LiteralExpr",
"value": true
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:4",
"to": "L9:5"
},
"expr": {
"_type": "VariableExpr",
"name": "d"
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:8",
"to": "L9:9"
},
"expr": {
"_type": "VariableExpr",
"name": "d"
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:4",
"to": "L9:9"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "d"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "d"
}
},
"type": {}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
}
]
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,14 @@
type Meter(float)
type Second(float)
type MeterPerSecond(float)
type Meter = float
type Second = float
type MeterPerSecond = float
extend Meter {
op __add__(Meter) -> Meter
op __sub__(Meter) -> Meter
op __truediv__(Second) -> MeterPerSecond
def __add__: fn(Meter, /) -> Meter
def __sub__: fn(Meter, /) -> Meter
def __truediv__: fn(Second, /) -> MeterPerSecond
}
extend Second {
op __add__(Second) -> Second
op __sub__(Second) -> Second
def __add__: fn(Second, /) -> Second
def __sub__: fn(Second, /) -> Second
}

View File

@@ -1,8 +1,6 @@
# type: ignore
# ruff: disable [F821]
midas.using("04_custom_types.midas")
distance: Meter = cast(Meter, 123.45)
time: Second = cast(Second, 6.7)
speed = distance / time

View File

@@ -1,3 +1,109 @@
{
"diagnostics": []
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L4:18",
"to": "L4:37"
},
"expr": {
"_type": "CastExpr",
"type": {
"_type": "BaseType",
"base": "Meter",
"param": null
},
"expr": {
"_type": "LiteralExpr",
"value": 123.45
}
},
"type": {
"name": "Meter",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L5:15",
"to": "L5:32"
},
"expr": {
"_type": "CastExpr",
"type": {
"_type": "BaseType",
"base": "Second",
"param": null
},
"expr": {
"_type": "LiteralExpr",
"value": 6.7
}
},
"type": {
"name": "Second",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:8",
"to": "L6:16"
},
"expr": {
"_type": "VariableExpr",
"name": "distance"
},
"type": {
"name": "Meter",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:19",
"to": "L6:23"
},
"expr": {
"_type": "VariableExpr",
"name": "time"
},
"type": {
"name": "Second",
"type": {
"name": "float"
}
}
},
{
"location": {
"from": "L6:8",
"to": "L6:23"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "distance"
},
"operator": "/",
"right": {
"_type": "VariableExpr",
"name": "time"
}
},
"type": {
"name": "MeterPerSecond",
"type": {
"name": "float"
}
}
}
]
}

View File

@@ -0,0 +1,25 @@
def valid(a: int, b: int) -> int:
return a + b
def with_if(a: int, b: int) -> int:
if a < b:
return b - a
else:
return a - b
def unreachable1():
return
a = 0
def unreachable2(a: int) -> int:
if a > 10:
return a - 10
else:
return a
b = 0
def mixed(a: int, b: int):
if a < b:
return b - a
else:
return "oops"

View File

@@ -0,0 +1,450 @@
{
"diagnostics": [
{
"type": "Warning",
"location": {
"start": [
12,
4
],
"end": [
12,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Warning",
"location": {
"start": [
19,
4
],
"end": [
19,
9
]
},
"message": "Unreachable statement"
},
{
"type": "Error",
"location": {
"start": [
21,
0
],
"end": [
25,
21
]
},
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
}
],
"judgments": [
{
"location": {
"from": "L2:11",
"to": "L2:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:15",
"to": "L2:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:11",
"to": "L2:16"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:7",
"to": "L5:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:11",
"to": "L5:12"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L5:7",
"to": "L5:12"
},
"expr": {
"_type": "CompareExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "<",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L6:15",
"to": "L6:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:19",
"to": "L6:20"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:15",
"to": "L6:20"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "b"
},
"operator": "-",
"right": {
"_type": "VariableExpr",
"name": "a"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:15",
"to": "L8:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:19",
"to": "L8:20"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:15",
"to": "L8:20"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "-",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:7",
"to": "L15:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:11",
"to": "L15:13"
},
"expr": {
"_type": "LiteralExpr",
"value": 10
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:7",
"to": "L15:13"
},
"expr": {
"_type": "CompareExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": ">",
"right": {
"_type": "LiteralExpr",
"value": 10
}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:19",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": 10
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 10
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:7",
"to": "L22:8"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:11",
"to": "L22:12"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L22:7",
"to": "L22:12"
},
"expr": {
"_type": "CompareExpr",
"left": {
"_type": "VariableExpr",
"name": "a"
},
"operator": "<",
"right": {
"_type": "VariableExpr",
"name": "b"
}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L23:15",
"to": "L23:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L23:19",
"to": "L23:20"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L23:15",
"to": "L23:20"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "b"
},
"operator": "-",
"right": {
"_type": "VariableExpr",
"name": "a"
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L25:15",
"to": "L25:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "oops"
},
"type": {
"name": "str"
}
}
]
}

View File

@@ -0,0 +1,12 @@
v1: int = 3
v2: float = 4
def maximum(a: float, b: float):
if b > a:
return b
return a
v3 = maximum(v1, v2)
v3 = v2 + v1

View File

@@ -0,0 +1,239 @@
{
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L1:10",
"to": "L1:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:12",
"to": "L2:13"
},
"expr": {
"_type": "LiteralExpr",
"value": 4
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:7",
"to": "L6:8"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:7",
"to": "L6:12"
},
"expr": {
"_type": "CompareExpr",
"left": {
"_type": "VariableExpr",
"name": "b"
},
"operator": ">",
"right": {
"_type": "VariableExpr",
"name": "a"
}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L7:15",
"to": "L7:16"
},
"expr": {
"_type": "VariableExpr",
"name": "b"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L8:11",
"to": "L8:12"
},
"expr": {
"_type": "VariableExpr",
"name": "a"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
"to": "L11:12"
},
"expr": {
"_type": "VariableExpr",
"name": "maximum"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "a",
"type": {
"name": "float"
},
"required": true
},
{
"pos": 1,
"name": "b",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
"to": "L11:20"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "maximum"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "v1"
},
{
"_type": "VariableExpr",
"name": "v2"
}
],
"keywords": {}
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L12:5",
"to": "L12:7"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L12:10",
"to": "L12:12"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L12:5",
"to": "L12:12"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "v2"
},
"operator": "+",
"right": {
"_type": "VariableExpr",
"name": "v1"
}
},
"type": {
"name": "float"
}
}
]
}

View File

@@ -1,17 +1,17 @@
// Simple custom type derived from float
type Custom(float)
type Custom = float
// Simple custom types with constraints
type Latitude(float) where (-90 <= _ <= 90)
type Longitude(float) where (-180 <= _ <= 180)
type Latitude = float where (-90 <= _ <= 90)
type Longitude = float where (-180 <= _ <= 180)
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
type Difference[T](T)
type Difference[T] = T
// Complex custom type, containing two values accessible through properties
type GeoLocation {
lat: Latitude
lon: Longitude
type GeoLocation = {
prop lat: Latitude
prop lon: Longitude
}
// Define operations on our custom type
@@ -19,23 +19,23 @@ extend GeoLocation {
// This type is compatible with the `-` operation with another GeoLocation
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
// in a Difference of GeoLocations
op __sub__(GeoLocation) -> Difference[GeoLocation]
def __sub__: fn(GeoLocation, /) -> Difference[GeoLocation]
}
// For complex generics, you need to specify how the genericity the properties
// are handled
type Difference[GeoLocation] {
lat: Difference[Latitude]
lon: Difference[Longitude]
type Difference[GeoLocation] = {
prop lat: Difference[Latitude]
prop lon: Difference[Longitude]
}
// Simple operation defined on our custom types
extend Latitude {
op __sub__(Latitude) -> Difference[Latitude]
def __sub__: fn(Latitude, /) -> Difference[Latitude]
}
extend Longitude {
op __sub__(Longitude) -> Difference[Longitude]
def __sub__: fn(Longitude, /) -> Difference[Longitude]
}
// Predefined custom predicates that can be referenced in other definitions
@@ -44,14 +44,14 @@ predicate StrictlyPositive(v: float) = v > 0
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
type Person {
name: str
type Person = {
prop name: str
// Property with an inline constraint
age: int? where (0 <= _ < 150)
prop age: Optional[int where (0 <= _ < 150)]
// Property referencing a predicate
height: float where StrictlyPositive
prop height: float where StrictlyPositive
home: GeoLocation
prop home: GeoLocation
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,6 @@
# ruff: disable[F821]
from __future__ import annotations
import midas
midas.using("02_custom_types.midas")
df: Frame[
location: GeoLocation
]

View File

@@ -1,26 +1,5 @@
{
"stmts": [
{
"_type": "ExpressionStmt",
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "GetExpr",
"object": {
"_type": "VariableExpr",
"name": "midas"
},
"name": "using"
},
"arguments": [
{
"_type": "LiteralExpr",
"value": "02_custom_types.midas"
}
],
"keywords": {}
}
},
{
"_type": "TypeAssign",
"name": "df",
@@ -39,6 +18,80 @@
]
}
},
{
"_type": "TypeAssign",
"name": "lat",
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
}
},
{
"_type": "AssignStmt",
"targets": [
{
"_type": "VariableExpr",
"name": "lat"
}
],
"value": {
"_type": "GetExpr",
"object": {
"_type": "SubscriptExpr",
"object": {
"_type": "VariableExpr",
"name": "df"
},
"index": {
"_type": "LiteralExpr",
"value": "location"
}
},
"name": "lat"
}
},
{
"_type": "TypeAssign",
"name": "lon",
"type": {
"_type": "BaseType",
"base": "Column",
"param": {
"_type": "BaseType",
"base": "GeoLocation",
"param": null
}
}
},
{
"_type": "AssignStmt",
"targets": [
{
"_type": "VariableExpr",
"name": "lon"
}
],
"value": {
"_type": "GetExpr",
"object": {
"_type": "SubscriptExpr",
"object": {
"_type": "VariableExpr",
"name": "df"
},
"index": {
"_type": "LiteralExpr",
"value": "location"
}
},
"name": "lon"
}
},
{
"_type": "ExpressionStmt",
"expr": {
@@ -54,6 +107,64 @@
}
}
},
{
"_type": "TypeAssign",
"name": "lat1",
"type": {
"_type": "BaseType",
"base": "Latitude",
"param": null
}
},
{
"_type": "AssignStmt",
"targets": [
{
"_type": "VariableExpr",
"name": "lat1"
}
],
"value": {
"_type": "SubscriptExpr",
"object": {
"_type": "VariableExpr",
"name": "lat"
},
"index": {
"_type": "LiteralExpr",
"value": 0
}
}
},
{
"_type": "TypeAssign",
"name": "lat2",
"type": {
"_type": "BaseType",
"base": "Latitude",
"param": null
}
},
{
"_type": "AssignStmt",
"targets": [
{
"_type": "VariableExpr",
"name": "lat2"
}
],
"value": {
"_type": "SubscriptExpr",
"object": {
"_type": "VariableExpr",
"name": "lat"
},
"index": {
"_type": "LiteralExpr",
"value": 1
}
}
},
{
"_type": "TypeAssign",
"name": "lat_diff",

View File

@@ -1,19 +1,19 @@
import ast
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
import midas.ast.python as p
from midas.checker.checker import Checker
from midas.checker.checker import TypeChecker
from midas.checker.diagnostic import Diagnostic
from midas.parser.python import PythonParser
from midas.resolver.resolver import Resolver
from midas.checker.types import Type
from tests.base import Tester
from tests.serializer.python import PythonAstJsonSerializer
@dataclass
class CaseResult:
    """Outcome of a single checker test case, dumpable as JSON."""

    # Reported diagnostics, one dict per diagnostic.
    diagnostics: list[dict] = field(default_factory=list)
    # Recorded judgments, in the order they were appended.
    judgments: list = field(default_factory=list)

    def dumps(self) -> str:
        """Return this result as a JSON document indented by two spaces."""
        as_mapping = asdict(self)
        return json.dumps(as_mapping, indent=2)
@@ -33,15 +33,16 @@ class CheckerTester(Tester):
if not path.is_file():
raise TypeError(f"Test '{path}' is not a file")
source: str = path.read_text()
tree: ast.Module = ast.parse(source, filename=path)
parser = PythonParser()
stmts: list[p.Stmt] = parser.parse_module(tree)
resolver = Resolver()
resolver.resolve(*stmts)
result: CaseResult = CaseResult()
checker = Checker(resolver.locals, file_path=path)
diagnostics: list[Diagnostic] = checker.check(stmts)
checker = TypeChecker()
types_path: Path = path.with_suffix(".midas")
if types_path.exists():
checker.import_midas(types_path)
checker.type_check(path)
diagnostics: list[Diagnostic] = checker.diagnostics
for diagnostic in diagnostics:
result.diagnostics.append(
{
@@ -60,6 +61,21 @@ class CheckerTester(Tester):
}
)
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
serializer = PythonAstJsonSerializer()
for expr, type in judgements:
loc = expr.location
result.judgments.append(
{
"location": {
"from": f"L{loc.lineno}:{loc.col_offset}",
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
},
"expr": expr.accept(serializer),
"type": asdict(type),
}
)
return result

View File

@@ -2,79 +2,76 @@ from typing import Optional, Sequence
from midas.ast.midas import (
BinaryExpr,
ComplexTypeStmt,
ComplexType,
ConstraintType,
Expr,
ExtendStmt,
ExtensionType,
FunctionType,
GenericType,
GetExpr,
GroupingExpr,
LiteralExpr,
LogicalExpr,
OpStmt,
MemberStmt,
NamedType,
PredicateStmt,
PropertyStmt,
SimpleTypeExpr,
SimpleTypeStmt,
Stmt,
TemplateExpr,
TypeExpr,
Type,
TypeParam,
TypeStmt,
UnaryExpr,
VariableExpr,
WildcardExpr,
)
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
class MidasAstJsonSerializer(
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
):
"""An AST serializer which produces a JSON-compatible structure"""
def serialize(self, stmts: list[Stmt]) -> list[dict]:
    """Visit every statement in order and collect the resulting nodes."""
    nodes: list[dict] = []
    for stmt in stmts:
        nodes.append(stmt.accept(self))
    return nodes
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
def _serialize_optional(
self, element: Optional[Stmt | Expr | Type]
) -> Optional[dict]:
if element is None:
return None
return element.accept(self)
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
return [element.accept(self) for element in elements]
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
return {
"_type": "SimpleTypeStmt",
"_type": "TypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"base": stmt.base.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
"params": [self._serialize_type_param(param) for param in stmt.params],
"type": stmt.type.accept(self),
}
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
def _serialize_type_param(self, param: TypeParam) -> dict:
return {
"_type": "ComplexTypeStmt",
"name": stmt.name.lexeme,
"template": self._serialize_optional(stmt.template),
"properties": self._serialize_list(stmt.properties),
"name": param.name.lexeme,
"bound": self._serialize_optional(param.bound),
}
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
return {
"_type": "PropertyStmt",
"_type": "MemberStmt",
"kind": stmt.kind.name,
"name": stmt.name.lexeme,
"type": stmt.type.accept(self),
"constraint": self._serialize_optional(stmt.constraint),
}
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
return {
"_type": "ExtendStmt",
"type": stmt.type.accept(self),
"operations": self._serialize_list(stmt.operations),
}
def visit_op_stmt(self, stmt: OpStmt) -> dict:
return {
"_type": "OpStmt",
"name": stmt.name.lexeme,
"operand": stmt.operand.accept(self),
"result": stmt.result.accept(self),
"params": [self._serialize_type_param(param) for param in stmt.params],
"members": self._serialize_list(stmt.members),
}
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
@@ -86,13 +83,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
"condition": stmt.condition.accept(self),
}
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
return {
"_type": "SimpleTypeExpr",
"name": expr.name.lexeme,
"optional": expr.optional,
}
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
return {
"_type": "LogicalExpr",
@@ -144,16 +134,51 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
    """Return the node for a wildcard expression: only the `_type` tag."""
    node: dict = {}
    node["_type"] = "WildcardExpr"
    return node
def visit_template_expr(self, expr: TemplateExpr) -> dict:
def visit_named_type(self, type: NamedType) -> dict:
return {
"_type": "TemplateExpr",
"type": expr.type.accept(self),
"_type": "NamedType",
"name": type.name.lexeme,
}
def visit_type_expr(self, expr: TypeExpr) -> dict:
def visit_generic_type(self, type: GenericType) -> dict:
return {
"_type": "TypeExpr",
"name": expr.name.lexeme,
"template": self._serialize_optional(expr.template),
"optional": expr.optional,
"_type": "GenericType",
"type": type.type.accept(self),
"args": self._serialize_list(type.args),
}
def visit_constraint_type(self, type: ConstraintType) -> dict:
    """Serialize a constraint type: its inner type, then its constraint."""
    inner = type.type.accept(self)
    constraint = type.constraint.accept(self)
    return {
        "_type": "ConstraintType",
        "type": inner,
        "constraint": constraint,
    }
def visit_complex_type(self, type: ComplexType) -> dict:
    """Serialize a complex type as its tag plus its serialized members."""
    node: dict = {"_type": "ComplexType"}
    node["members"] = self._serialize_list(type.members)
    return node
def visit_function_type(self, type: FunctionType) -> dict:
    """Serialize a function type.

    The three argument groups are emitted in the order `pos_args`, `args`,
    `kw_args`, each argument going through `_serialize_func_arg`; the return
    type is visited last.
    """
    node: dict = {"_type": "FunctionType"}
    for group in ("pos_args", "args", "kw_args"):
        node[group] = [
            self._serialize_func_arg(arg) for arg in getattr(type, group)
        ]
    node["returns"] = type.returns.accept(self)
    return node
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
    """Serialize one function argument: name, visited type, required flag."""
    node: dict = {"name": arg.name}
    node["type"] = arg.type.accept(self)
    node["required"] = arg.required
    return node
def visit_extension_type(self, type: ExtensionType) -> dict:
    """Serialize an extension type: its base, then its extension."""
    base = type.base.accept(self)
    extension = type.extension.accept(self)
    return {
        "_type": "ExtensionType",
        "base": base,
        "extension": extension,
    }

View File

@@ -15,12 +15,16 @@ from midas.ast.python import (
FrameType,
Function,
GetExpr,
IfStmt,
ListExpr,
LiteralExpr,
LogicalExpr,
MidasType,
ReturnStmt,
SetExpr,
SliceExpr,
Stmt,
SubscriptExpr,
TernaryExpr,
TypeAssign,
UnaryExpr,
VariableExpr,
@@ -164,6 +168,14 @@ class PythonAstJsonSerializer(
"value": self._serialize_optional(stmt.value),
}
def visit_if_stmt(self, stmt: IfStmt) -> dict:
    """Serialize an `if` statement: its test, then `body`, then `orelse`."""
    node: dict = {"_type": "IfStmt"}
    node["test"] = stmt.test.accept(self)
    node["body"] = self._serialize_list(stmt.body)
    node["orelse"] = self._serialize_list(stmt.orelse)
    return node
def visit_binary_expr(self, expr: BinaryExpr) -> dict:
return {
"_type": "BinaryExpr",
@@ -222,17 +234,38 @@ class PythonAstJsonSerializer(
"right": expr.right.accept(self),
}
def visit_set_expr(self, expr: SetExpr) -> dict:
    """Serialize a set expression: target object, member name, new value."""
    target = expr.object.accept(self)
    member = expr.name
    assigned = expr.value.accept(self)
    return {
        "_type": "SetExpr",
        "object": target,
        "name": member,
        "value": assigned,
    }
def visit_cast_expr(self, expr: CastExpr) -> dict:
    """Serialize a cast: the target type first, then the casted expression."""
    node: dict = {"_type": "CastExpr"}
    node["type"] = expr.type.accept(self)
    node["expr"] = expr.expr.accept(self)
    return node
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
    """Serialize a ternary expression: test, then `if_true`, then `if_false`."""
    condition = expr.test.accept(self)
    when_true = expr.if_true.accept(self)
    when_false = expr.if_false.accept(self)
    return {
        "_type": "TernaryExpr",
        "test": condition,
        "if_true": when_true,
        "if_false": when_false,
    }
def visit_list_expr(self, expr: ListExpr) -> dict:
    """Serialize a list expression with its items visited in order."""
    serialized_items: list = []
    for item in expr.items:
        serialized_items.append(item.accept(self))
    return {"_type": "ListExpr", "items": serialized_items}
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
    """Serialize a subscript: the subscripted object, then the index."""
    node: dict = {"_type": "SubscriptExpr"}
    node["object"] = expr.object.accept(self)
    node["index"] = expr.index.accept(self)
    return node
def visit_slice_expr(self, expr: SliceExpr) -> dict:
    """Serialize a slice; each bound goes through `_serialize_optional`."""
    node: dict = {"_type": "SliceExpr"}
    node["lower"] = self._serialize_optional(expr.lower)
    node["upper"] = self._serialize_optional(expr.upper)
    node["step"] = self._serialize_optional(expr.step)
    return node