Compare commits
91 Commits
35ceda99aa
...
feat/gener
| Author | SHA1 | Date | |
|---|---|---|---|
|
04c0d683de
|
|||
|
ac620f318b
|
|||
|
947e9f0149
|
|||
|
c1ca254b51
|
|||
|
0bb862a1db
|
|||
|
54919a3565
|
|||
|
a5f0140013
|
|||
|
b0af01d906
|
|||
|
6048ee020f
|
|||
|
f815faa2f8
|
|||
|
f7d5d36d44
|
|||
|
503f2b6a0a
|
|||
|
778117664f
|
|||
|
afe3eefbbf
|
|||
|
96495e9f79
|
|||
|
77263139f6
|
|||
|
4f5967a151
|
|||
|
2a714a1021
|
|||
|
dafe0b471a
|
|||
|
a1f2937e16
|
|||
|
2063d94dce
|
|||
|
22fc8010d8
|
|||
|
aff1097d91
|
|||
|
12d034fd1e
|
|||
|
200709cca6
|
|||
|
700284296c
|
|||
|
0b53259b90
|
|||
|
0461a4184c
|
|||
|
01d6e41893
|
|||
|
80e611e49c
|
|||
|
c00915966f
|
|||
|
beaa4d95d8
|
|||
|
bfa0bb3ee0
|
|||
|
31158df2a9
|
|||
|
c6ead886ec
|
|||
|
9de03bf2b5
|
|||
|
a26b9293be
|
|||
|
efa5454776
|
|||
|
b8bb8190c4
|
|||
|
a4f5db7ece
|
|||
|
fc67f01f34
|
|||
|
0a748a36a3
|
|||
|
89fdd1b47e
|
|||
|
0cde53ac6e
|
|||
|
f3ec3606c2
|
|||
|
67ec029529
|
|||
|
e2aef7a811
|
|||
|
86ba4e658a
|
|||
|
7eccf59558
|
|||
|
9dd7801d2d
|
|||
|
154cb8b314
|
|||
|
c64ab434b5
|
|||
|
25e6410546
|
|||
|
8a22acc17c
|
|||
|
e0179bc442
|
|||
|
e665d03533
|
|||
|
b8cb2b4273
|
|||
|
d278dc5f5b
|
|||
|
59e73f0fd9
|
|||
|
3e0dc60283
|
|||
|
c24eb5125e
|
|||
|
25bd895dde
|
|||
|
bccd75317e
|
|||
|
f0e3f7574f
|
|||
|
5d44081847
|
|||
|
2a2bb0aec7
|
|||
|
67c40a3909
|
|||
|
1c30188122
|
|||
|
82a0f13242
|
|||
| 288d15a9bc | |||
|
504703d0f7
|
|||
|
e48895d0af
|
|||
| 13d32d0d27 | |||
| 19b9fdd623 | |||
|
ddcaebb51a
|
|||
|
f182312cd2
|
|||
|
73b21789d5
|
|||
|
5d7c724bc8
|
|||
|
74b297c89c
|
|||
|
822a74acce
|
|||
|
9a934fabfd
|
|||
|
828ec9a3fa
|
|||
|
63a43d79dd
|
|||
|
029caf4526
|
|||
|
1c5c418f1c
|
|||
|
a4139d4652
|
|||
|
2fd2071d40
|
|||
|
97b1ee8ab8
|
|||
|
dee479def5
|
|||
|
c8536e20d2
|
|||
|
d70137775f
|
79
README.md
79
README.md
@@ -5,3 +5,82 @@
|
|||||||
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
|
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
|
||||||
|
|
||||||
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Python 3.11+
|
||||||
|
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
1. Clone the repository
|
||||||
|
```shell
|
||||||
|
git clone https://git.kb28.ch/HEL/midas.git
|
||||||
|
```
|
||||||
|
2. Go in the project directory
|
||||||
|
```shell
|
||||||
|
cd midas
|
||||||
|
```
|
||||||
|
3. Install the CLI as a user-wide tool
|
||||||
|
```shell
|
||||||
|
uv tool install .
|
||||||
|
```
|
||||||
|
4. You can now run the `midas` command from anywhere
|
||||||
|
```shell
|
||||||
|
midas --help
|
||||||
|
```
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
### Compiling
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas compile -t types.midas source.py
|
||||||
|
```
|
||||||
|
|
||||||
|
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
||||||
|
|
||||||
|
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
|
||||||
|
|
||||||
|
### Highlighting
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas utils highlight source.py
|
||||||
|
# or
|
||||||
|
midas utils highlight types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
||||||
|
|
||||||
|
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
||||||
|
|
||||||
|
### Dumping the AST
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas utils dump-ast source.py
|
||||||
|
# or
|
||||||
|
midas utils dump-ast types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
||||||
|
|
||||||
|
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
||||||
|
|
||||||
|
## Tests
|
||||||
|
|
||||||
|
Several snapshot tests are available to assert the good behaviour of the parsers and type checker. They can be run as follows:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
uv run -m tests.midas run -a
|
||||||
|
uv run -m tests.python run -a
|
||||||
|
uv run -m tests.checker run -a
|
||||||
|
```
|
||||||
|
|
||||||
|
**Available subcommands:**
|
||||||
|
- Run all tests: `run -a`
|
||||||
|
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
|
||||||
|
- Update all tests: `update -a`
|
||||||
|
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`
|
||||||
|
|||||||
@@ -2,10 +2,6 @@
|
|||||||
# ruff: disable[F821]
|
# ruff: disable[F821]
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# Prototype of custom type import to use valid Python syntax
|
|
||||||
import midas
|
|
||||||
midas.using("02_custom_types.midas")
|
|
||||||
|
|
||||||
# A data-frame using a custom type
|
# A data-frame using a custom type
|
||||||
df: Frame[
|
df: Frame[
|
||||||
location: GeoLocation
|
location: GeoLocation
|
||||||
|
|||||||
@@ -9,3 +9,5 @@ d = True
|
|||||||
e = d + d
|
e = d + d
|
||||||
|
|
||||||
f: float = a
|
f: float = a
|
||||||
|
|
||||||
|
f = -f
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
type Meter(float)
|
type Meter = float
|
||||||
type Second(float)
|
type Second = float
|
||||||
type MeterPerSecond(float)
|
type MeterPerSecond = float
|
||||||
|
|
||||||
extend Meter {
|
extend Meter {
|
||||||
op __add__(Meter) -> Meter
|
def __add__: fn(Meter, /) -> Meter
|
||||||
op __sub__(Meter) -> Meter
|
def __sub__: fn(Meter, /) -> Meter
|
||||||
op __truediv__(Second) -> MeterPerSecond
|
def __truediv__: fn(Second, /) -> MeterPerSecond
|
||||||
}
|
}
|
||||||
|
|
||||||
extend Second {
|
extend Second {
|
||||||
op __add__(Second) -> Second
|
def __add__: fn(Second, /) -> Second
|
||||||
op __sub__(Second) -> Second
|
def __sub__: fn(Second, /) -> Second
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# type: ignore
|
# type: ignore
|
||||||
# ruff: disable [F821]
|
# ruff: disable [F821]
|
||||||
|
|
||||||
midas.using("02_simple_types.midas")
|
|
||||||
|
|
||||||
distance: Meter = cast(Meter, 123.45)
|
distance: Meter = cast(Meter, 123.45)
|
||||||
time: Second = cast(Second, 6.7)
|
time: Second = cast(Second, 6.7)
|
||||||
speed = distance / time
|
speed = distance / time
|
||||||
|
|||||||
@@ -4,13 +4,20 @@ def minimum(x: int, y: int):
|
|||||||
else:
|
else:
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
a = 15
|
a = 15
|
||||||
b = 72
|
b = 72
|
||||||
c = minimum(a, b)
|
c = minimum(a, b)
|
||||||
|
|
||||||
|
|
||||||
def factorial(n: int) -> int:
|
def factorial(n: int) -> int:
|
||||||
if n <= 1:
|
if n <= 1:
|
||||||
return 1
|
return 1
|
||||||
return n * factorial(n - 1)
|
return n * factorial(n - 1)
|
||||||
|
|
||||||
category = "Category 1" if a < 10 else "Category 2"
|
|
||||||
|
category = "Category 1" if a < 10 else "Category 2"
|
||||||
|
|
||||||
|
|
||||||
|
def foo() -> None:
|
||||||
|
pass
|
||||||
|
|||||||
21
examples/01_simple_type_checking/04_complex_types.midas
Normal file
21
examples/01_simple_type_checking/04_complex_types.midas
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
type Meter = float
|
||||||
|
|
||||||
|
extend Meter {
|
||||||
|
def __add__: fn(Meter, /) -> Meter
|
||||||
|
def __sub__: fn(Meter, /) -> Meter
|
||||||
|
}
|
||||||
|
|
||||||
|
type Coordinate = object
|
||||||
|
|
||||||
|
extend Coordinate {
|
||||||
|
prop x: Meter
|
||||||
|
prop y: Meter
|
||||||
|
}
|
||||||
|
|
||||||
|
type Difference[T <: float] = T
|
||||||
|
type MeterDifference = Difference[Meter]
|
||||||
|
|
||||||
|
type CompDiff[T <: float] = {
|
||||||
|
prop d1: Difference[T]
|
||||||
|
prop d2: Difference[T]
|
||||||
|
}
|
||||||
35
examples/01_simple_type_checking/04_complex_types.py
Normal file
35
examples/01_simple_type_checking/04_complex_types.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# type: ignore
|
||||||
|
# ruff: disable [F821]
|
||||||
|
|
||||||
|
p1: Coordinate
|
||||||
|
p2: Coordinate
|
||||||
|
|
||||||
|
diff_x = p2.x - p1.x
|
||||||
|
diff_y = p2.y - p1.y
|
||||||
|
|
||||||
|
dist = diff_x + diff_y
|
||||||
|
|
||||||
|
p2.x += cast(Meter, 1)
|
||||||
|
p2.y = True # invalid, wrong type
|
||||||
|
p2.z = 3 # invalid, no property 'z' on Coordinate
|
||||||
|
p2.x.a = 3 # invalid, no properties on Meter
|
||||||
|
|
||||||
|
foo: list[float] = []
|
||||||
|
|
||||||
|
append = foo.append
|
||||||
|
|
||||||
|
foo.append("") # invalid, must be float
|
||||||
|
foo.append(2)
|
||||||
|
append(True) # invalid, must be float
|
||||||
|
append(2)
|
||||||
|
|
||||||
|
bar: list[list[Meter]]
|
||||||
|
|
||||||
|
bar.append([p2.x])
|
||||||
|
|
||||||
|
foo2 = foo + foo
|
||||||
|
|
||||||
|
a = foo[0]
|
||||||
|
b = bar[0][1]
|
||||||
|
c = bar[0][1][2] # invalid, not method __getitem__ on Meter
|
||||||
|
c = bar[""] # invalid, wrong index type
|
||||||
15
gen/gen.py
15
gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
{preamble}
|
||||||
{sections}
|
{sections}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
|
|||||||
re.MULTILINE | re.DOTALL,
|
re.MULTILINE | re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PREAMBLE_REGEX = re.compile(
|
||||||
|
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||||
|
re.MULTILINE | re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def snake_case(text: str) -> str:
|
def snake_case(text: str) -> str:
|
||||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||||
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||||
|
print(f" Generating {full_name}")
|
||||||
visitor_methods: list[str] = []
|
visitor_methods: list[str] = []
|
||||||
classes: list[str] = []
|
classes: list[str] = []
|
||||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||||
for cls in definitions:
|
for cls in definitions:
|
||||||
cls = cls.strip("\n")
|
cls = cls.strip("\n")
|
||||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||||
print(f"Processing {name}")
|
print(f" Processing {name}")
|
||||||
visitor_methods.append(make_visitor_method(name, param))
|
visitor_methods.append(make_visitor_method(name, param))
|
||||||
classes.append(make_class(name, cls, base))
|
classes.append(make_class(name, cls, base))
|
||||||
|
|
||||||
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def generate(definitions_path: Path, out_path: Path):
|
def generate(definitions_path: Path, out_path: Path):
|
||||||
|
print(f"Processing generating {out_path} from {definitions_path}")
|
||||||
root_dir: Path = Path(__file__).parent.parent
|
root_dir: Path = Path(__file__).parent.parent
|
||||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||||
src: str = definitions_path.read_text()
|
src: str = definitions_path.read_text()
|
||||||
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
|
|||||||
if m := IMPORTS_REGEX.search(src):
|
if m := IMPORTS_REGEX.search(src):
|
||||||
imports = m.group("body").strip("\n")
|
imports = m.group("body").strip("\n")
|
||||||
|
|
||||||
|
preamble: str = ""
|
||||||
|
if m := PREAMBLE_REGEX.search(src):
|
||||||
|
preamble = m.group("body")
|
||||||
|
|
||||||
for section_m in SECTION_REGEX.finditer(src):
|
for section_m in SECTION_REGEX.finditer(src):
|
||||||
full_name: str = section_m.group("name")
|
full_name: str = section_m.group("name")
|
||||||
base: str = section_m.group("base")
|
base: str = section_m.group("base")
|
||||||
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
|
|||||||
gen_path=Path(__file__).relative_to(root_dir),
|
gen_path=Path(__file__).relative_to(root_dir),
|
||||||
),
|
),
|
||||||
imports=imports,
|
imports=imports,
|
||||||
|
preamble=preamble,
|
||||||
sections="\n\n\n".join(sections),
|
sections="\n\n\n".join(sections),
|
||||||
)
|
)
|
||||||
out_path.write_text(result)
|
out_path.write_text(result)
|
||||||
|
|||||||
64
gen/midas.py
64
gen/midas.py
@@ -4,6 +4,7 @@
|
|||||||
###> Imports
|
###> Imports
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import Any, Generic, Optional, TypeVar
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -12,33 +13,39 @@ from midas.lexer.token import Token
|
|||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
|
###> Preamble
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeParam:
|
||||||
|
location: Location
|
||||||
|
name: Token
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
|
||||||
|
class MemberKind(Enum):
|
||||||
|
PROPERTY = auto()
|
||||||
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
###<
|
||||||
|
|
||||||
|
|
||||||
###> Stmt | Statements
|
###> Stmt | Statements
|
||||||
class TypeStmt:
|
class TypeStmt:
|
||||||
name: Token
|
name: Token
|
||||||
params: list[Param]
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
|
||||||
class Param:
|
|
||||||
location: Location
|
|
||||||
name: Token
|
|
||||||
bound: Optional[Type]
|
|
||||||
|
|
||||||
|
class MemberStmt:
|
||||||
class PropertyStmt:
|
|
||||||
name: Token
|
name: Token
|
||||||
type: Type
|
type: Type
|
||||||
|
kind: MemberKind
|
||||||
|
|
||||||
|
|
||||||
class ExtendStmt:
|
class ExtendStmt:
|
||||||
type: Type
|
|
||||||
operations: list[OpStmt]
|
|
||||||
|
|
||||||
|
|
||||||
class OpStmt:
|
|
||||||
name: Token
|
name: Token
|
||||||
operand: Type
|
params: list[TypeParam]
|
||||||
result: Type
|
members: list[MemberStmt]
|
||||||
|
|
||||||
|
|
||||||
class PredicateStmt:
|
class PredicateStmt:
|
||||||
@@ -103,7 +110,7 @@ class NamedType:
|
|||||||
|
|
||||||
class GenericType:
|
class GenericType:
|
||||||
type: Type
|
type: Type
|
||||||
params: list[Type]
|
args: list[Type]
|
||||||
|
|
||||||
|
|
||||||
class ConstraintType:
|
class ConstraintType:
|
||||||
@@ -111,12 +118,27 @@ class ConstraintType:
|
|||||||
constraint: Expr
|
constraint: Expr
|
||||||
|
|
||||||
|
|
||||||
class UnionType:
|
|
||||||
types: list[Type]
|
|
||||||
|
|
||||||
|
|
||||||
class ComplexType:
|
class ComplexType:
|
||||||
properties: list[PropertyStmt]
|
members: list[MemberStmt]
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionType:
|
||||||
|
base: Type
|
||||||
|
extension: ComplexType
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionType:
|
||||||
|
pos_args: list[Argument]
|
||||||
|
args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Optional[Token]
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -128,12 +128,6 @@ class LogicalExpr:
|
|||||||
right: Expr
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
class SetExpr:
|
|
||||||
object: Expr
|
|
||||||
name: str
|
|
||||||
value: Expr
|
|
||||||
|
|
||||||
|
|
||||||
class CastExpr:
|
class CastExpr:
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
@@ -145,4 +139,13 @@ class TernaryExpr:
|
|||||||
if_false: Expr
|
if_false: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class ListExpr:
|
||||||
|
items: list[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptExpr:
|
||||||
|
object: Expr
|
||||||
|
index: Expr
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import Any, Generic, Optional, TypeVar
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -14,6 +15,18 @@ from midas.lexer.token import Token
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeParam:
|
||||||
|
location: Location
|
||||||
|
name: Token
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
|
||||||
|
class MemberKind(Enum):
|
||||||
|
PROPERTY = auto()
|
||||||
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Statements #
|
# Statements #
|
||||||
##############
|
##############
|
||||||
@@ -31,14 +44,11 @@ class Stmt(ABC):
|
|||||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
||||||
|
|
||||||
@@ -46,47 +56,33 @@ class Stmt(ABC):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TypeStmt(Stmt):
|
class TypeStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
params: list[Param]
|
params: list[TypeParam]
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
|
||||||
class Param:
|
|
||||||
location: Location
|
|
||||||
name: Token
|
|
||||||
bound: Optional[Type]
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_type_stmt(self)
|
return visitor.visit_type_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PropertyStmt(Stmt):
|
class MemberStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
type: Type
|
type: Type
|
||||||
|
kind: MemberKind
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_property_stmt(self)
|
return visitor.visit_member_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ExtendStmt(Stmt):
|
class ExtendStmt(Stmt):
|
||||||
type: Type
|
name: Token
|
||||||
operations: list[OpStmt]
|
params: list[TypeParam]
|
||||||
|
members: list[MemberStmt]
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_extend_stmt(self)
|
return visitor.visit_extend_stmt(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class OpStmt(Stmt):
|
|
||||||
name: Token
|
|
||||||
operand: Type
|
|
||||||
result: Type
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_op_stmt(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PredicateStmt(Stmt):
|
class PredicateStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
@@ -229,10 +225,13 @@ class Type(ABC):
|
|||||||
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_union_type(self, type: UnionType) -> T: ...
|
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
def visit_extension_type(self, type: ExtensionType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -246,7 +245,7 @@ class NamedType(Type):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenericType(Type):
|
class GenericType(Type):
|
||||||
type: Type
|
type: Type
|
||||||
params: list[Type]
|
args: list[Type]
|
||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_generic_type(self)
|
return visitor.visit_generic_type(self)
|
||||||
@@ -261,17 +260,36 @@ class ConstraintType(Type):
|
|||||||
return visitor.visit_constraint_type(self)
|
return visitor.visit_constraint_type(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class UnionType(Type):
|
|
||||||
types: list[Type]
|
|
||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_union_type(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ComplexType(Type):
|
class ComplexType(Type):
|
||||||
properties: list[PropertyStmt]
|
members: list[MemberStmt]
|
||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_complex_type(self)
|
return visitor.visit_complex_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExtensionType(Type):
|
||||||
|
base: Type
|
||||||
|
extension: ComplexType
|
||||||
|
|
||||||
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_extension_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FunctionType(Type):
|
||||||
|
pos_args: list[Argument]
|
||||||
|
args: list[Argument]
|
||||||
|
kw_args: list[Argument]
|
||||||
|
returns: Type
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Argument:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Optional[Token]
|
||||||
|
type: Type
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_function_type(self)
|
||||||
|
|||||||
@@ -100,20 +100,21 @@ class MidasAstPrinter(
|
|||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(stmt.params) - 1:
|
if i == len(stmt.params) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
self._print_type_stmt_param(param)
|
self._print_type_param(param)
|
||||||
self._write_line("type", last=True)
|
self._write_line("type", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
|
|
||||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||||
self._write_line("Param")
|
self._write_line("Param")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f'name: "{param.name.lexeme}"')
|
self._write_line(f'name: "{param.name.lexeme}"')
|
||||||
self._write_optional_child("bound", param.bound, last=True)
|
self._write_optional_child("bound", param.bound, last=True)
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||||
self._write_line("PropertyStmt")
|
self._write_line("MemberStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
|
self._write_line(f"kind: {stmt.kind.name}")
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
self._write_line("type", last=True)
|
self._write_line("type", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
@@ -122,29 +123,28 @@ class MidasAstPrinter(
|
|||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
self._write_line("ExtendStmt")
|
self._write_line("ExtendStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line("type")
|
self._write_line("params")
|
||||||
with self._child_level(single=True):
|
|
||||||
stmt.type.accept(self)
|
|
||||||
self._write_line("operations", last=True)
|
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
for i, op in enumerate(stmt.operations):
|
for i, param in enumerate(stmt.params):
|
||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(stmt.operations) - 1:
|
if i == len(stmt.params) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
op.accept(self)
|
self._print_type_param(param)
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
|
||||||
self._write_line("OpStmt")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
|
self._write_line("params")
|
||||||
self._write_line("operand")
|
with self._child_level():
|
||||||
with self._child_level(single=True):
|
for i, param in enumerate(stmt.params):
|
||||||
stmt.operand.accept(self)
|
self._idx = i
|
||||||
|
if i == len(stmt.params) - 1:
|
||||||
self._write_line("result", last=True)
|
self._mark_last()
|
||||||
with self._child_level(single=True):
|
self._print_type_param(param)
|
||||||
stmt.result.accept(self)
|
self._write_line("members", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, member in enumerate(stmt.members):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.members) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
member.accept(self)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
self._write_line("PredicateStmt")
|
self._write_line("PredicateStmt")
|
||||||
@@ -234,11 +234,11 @@ class MidasAstPrinter(
|
|||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
type.type.accept(self)
|
type.type.accept(self)
|
||||||
self._write_line("params", last=True)
|
self._write_line("args", last=True)
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
for i, param in enumerate(type.params):
|
for i, param in enumerate(type.args):
|
||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(type.params) - 1:
|
if i == len(type.args) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
param.accept(self)
|
param.accept(self)
|
||||||
|
|
||||||
@@ -252,27 +252,69 @@ class MidasAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
type.constraint.accept(self)
|
type.constraint.accept(self)
|
||||||
|
|
||||||
def visit_union_type(self, type: m.UnionType) -> None:
|
|
||||||
self._write_line("UnionType")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("types", last=True)
|
|
||||||
with self._child_level():
|
|
||||||
for i, type_ in enumerate(type.types):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.types) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
type_.accept(self)
|
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||||
self._write_line("ComplexType")
|
self._write_line("ComplexType")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line("properties", last=True)
|
self._write_line("members", last=True)
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
for i, prop in enumerate(type.properties):
|
for i, member in enumerate(type.members):
|
||||||
self._idx = i
|
self._idx = i
|
||||||
if i == len(type.properties) - 1:
|
if i == len(type.members) - 1:
|
||||||
self._mark_last()
|
self._mark_last()
|
||||||
prop.accept(self)
|
member.accept(self)
|
||||||
|
|
||||||
|
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||||
|
self._write_line("ExtensionType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("base")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
type.base.accept(self)
|
||||||
|
self._write_line("extension", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
type.extension.accept(self)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
|
self._write_line("FunctionType")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("pos_args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.pos_args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.pos_args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("kw_args")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(type.kw_args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.kw_args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("returns", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||||
|
self._write_line("Argument")
|
||||||
|
with self._child_level():
|
||||||
|
name: str = "None"
|
||||||
|
if arg.name is not None:
|
||||||
|
name = f'"{arg.name.lexeme}"'
|
||||||
|
self._write_line(f"name: {name}")
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.type.accept(self)
|
||||||
|
self._write_line(f"required: {arg.required}", last=True)
|
||||||
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||||
@@ -283,45 +325,46 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
def indented(self, text: str) -> str:
|
def indented(self, text: str) -> str:
|
||||||
return " " * (self.level * self.indent) + text
|
return " " * (self.level * self.indent) + text
|
||||||
|
|
||||||
def print(self, expr: m.Expr | m.Stmt):
|
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
|
||||||
self.level = 0
|
self.level = 0
|
||||||
return expr.accept(self)
|
return expr.accept(self)
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||||
template: str = ""
|
template: str = ""
|
||||||
if len(stmt.params) != 0:
|
if len(stmt.params) != 0:
|
||||||
params: list[str] = [
|
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||||
self._print_type_template_param(param) for param in stmt.params
|
|
||||||
]
|
|
||||||
template = f"[{', '.join(params)}]"
|
template = f"[{', '.join(params)}]"
|
||||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||||
return self.indented(res)
|
return self.indented(res)
|
||||||
|
|
||||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||||
res: str = param.name.lexeme
|
res: str = param.name.lexeme
|
||||||
if param.bound is not None:
|
if param.bound is not None:
|
||||||
res += "<:" + param.bound.accept(self)
|
res += "<:" + param.bound.accept(self)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||||
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
keyword: str = {
|
||||||
|
m.MemberKind.PROPERTY: "prop",
|
||||||
|
m.MemberKind.METHOD: "def",
|
||||||
|
}.get(stmt.kind, "")
|
||||||
|
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||||
return self.indented(res)
|
return self.indented(res)
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||||
res: str = self.indented(f"extend {stmt.type.accept(self)}")
|
template: str = ""
|
||||||
|
if len(stmt.params) != 0:
|
||||||
|
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||||
|
template = f"[{', '.join(params)}]"
|
||||||
|
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
|
||||||
res += " {\n"
|
res += " {\n"
|
||||||
self.level += 1
|
self.level += 1
|
||||||
for op in stmt.operations:
|
for member in stmt.members:
|
||||||
res += op.accept(self)
|
res += member.accept(self) + "\n"
|
||||||
self.level -= 1
|
self.level -= 1
|
||||||
res += "\n" + self.indented("}")
|
res += self.indented("}")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
|
||||||
operand: str = stmt.operand.accept(self)
|
|
||||||
result: str = stmt.result.accept(self)
|
|
||||||
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
|
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
subject: str = stmt.subject.lexeme
|
subject: str = stmt.subject.lexeme
|
||||||
@@ -369,9 +412,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||||
res: str = type.type.accept(self)
|
res: str = type.type.accept(self)
|
||||||
if len(type.params) != 0:
|
if len(type.args) != 0:
|
||||||
params: list[str] = [param.accept(self) for param in type.params]
|
args: list[str] = [param.accept(self) for param in type.args]
|
||||||
res += f"[{', '.join(params)}]"
|
res += f"[{', '.join(args)}]"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||||
@@ -379,20 +422,44 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res += " where " + type.constraint.accept(self)
|
res += " where " + type.constraint.accept(self)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def visit_union_type(self, type: m.UnionType) -> str:
|
|
||||||
types: list[str] = [type_.accept(self) for type_ in type.types]
|
|
||||||
return " | ".join(types)
|
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||||
res: str = "{\n"
|
res: str = "{\n"
|
||||||
self.level += 1
|
self.level += 1
|
||||||
for prop in type.properties:
|
for member in type.members:
|
||||||
res += prop.accept(self)
|
res += member.accept(self)
|
||||||
res += "\n"
|
res += "\n"
|
||||||
self.level -= 1
|
self.level -= 1
|
||||||
res += self.indented("}")
|
res += self.indented("}")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def visit_extension_type(self, type: m.ExtensionType) -> str:
|
||||||
|
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||||
|
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||||
|
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
||||||
|
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
||||||
|
args: list[str] = pos_args
|
||||||
|
|
||||||
|
if len(pos_args) != 0:
|
||||||
|
args.append("/")
|
||||||
|
args += mixed_args
|
||||||
|
if len(kw_args) != 0:
|
||||||
|
args.append("*")
|
||||||
|
args += kw_args
|
||||||
|
|
||||||
|
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||||
|
|
||||||
|
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||||
|
res: str = ""
|
||||||
|
if arg.name is not None:
|
||||||
|
res += arg.name.lexeme
|
||||||
|
res += ": "
|
||||||
|
res += arg.type.accept(self)
|
||||||
|
if not arg.required:
|
||||||
|
res += "?"
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class PythonAstPrinter(
|
class PythonAstPrinter(
|
||||||
AstPrinter,
|
AstPrinter,
|
||||||
@@ -597,7 +664,7 @@ class PythonAstPrinter(
|
|||||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||||
self._write_line("LiteralExpr")
|
self._write_line("LiteralExpr")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
self._write_line(f"value: {expr.value}")
|
self._write_line(f"value: {expr.value!r}")
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||||
self._write_line("VariableExpr")
|
self._write_line("VariableExpr")
|
||||||
@@ -617,17 +684,6 @@ class PythonAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.right.accept(self)
|
expr.right.accept(self)
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
|
||||||
self._write_line("SetExpr")
|
|
||||||
with self._child_level():
|
|
||||||
self._write_line("object")
|
|
||||||
with self._child_level(single=True):
|
|
||||||
expr.object.accept(self)
|
|
||||||
self._write_line(f"name: {expr.name}")
|
|
||||||
self._write_line("value", last=True)
|
|
||||||
with self._child_level(single=True):
|
|
||||||
expr.value.accept(self)
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||||
self._write_line("CastExpr")
|
self._write_line("CastExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -652,3 +708,24 @@ class PythonAstPrinter(
|
|||||||
self._write_line("if_false", last=True)
|
self._write_line("if_false", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.if_false.accept(self)
|
expr.if_false.accept(self)
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
|
self._write_line("ListExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("items", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, item in enumerate(expr.items):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.items) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
item.accept(self)
|
||||||
|
|
||||||
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
|
self._write_line("SubscriptExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("object")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.object.accept(self)
|
||||||
|
self._write_line("index", last=True)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.index.accept(self)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from midas.ast.location import Location
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Type annotations #
|
# Type annotations #
|
||||||
####################
|
####################
|
||||||
@@ -214,15 +215,18 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_list_expr(self, expr: ListExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BinaryExpr(Expr):
|
class BinaryExpr(Expr):
|
||||||
@@ -298,16 +302,6 @@ class LogicalExpr(Expr):
|
|||||||
return visitor.visit_logical_expr(self)
|
return visitor.visit_logical_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class SetExpr(Expr):
|
|
||||||
object: Expr
|
|
||||||
name: str
|
|
||||||
value: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
|
||||||
return visitor.visit_set_expr(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class CastExpr(Expr):
|
class CastExpr(Expr):
|
||||||
type: MidasType
|
type: MidasType
|
||||||
@@ -325,3 +319,20 @@ class TernaryExpr(Expr):
|
|||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_ternary_expr(self)
|
return visitor.visit_ternary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ListExpr(Expr):
|
||||||
|
items: list[Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_list_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SubscriptExpr(Expr):
|
||||||
|
object: Expr
|
||||||
|
index: Expr
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_subscript_expr(self)
|
||||||
|
|||||||
152
midas/checker/builtins.midas
Normal file
152
midas/checker/builtins.midas
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
extend float {
|
||||||
|
def hex: fn() -> str
|
||||||
|
def is_integer: fn() -> bool
|
||||||
|
prop real: float
|
||||||
|
prop imag: float
|
||||||
|
def conjugate: fn() -> float
|
||||||
|
def __add__: fn(value: float, /) -> float
|
||||||
|
def __sub__: fn(value: float, /) -> float
|
||||||
|
def __mul__: fn(value: float, /) -> float
|
||||||
|
def __floordiv__: fn(value: float, /) -> float
|
||||||
|
def __truediv__: fn(value: float, /) -> float
|
||||||
|
def __mod__: fn(value: float, /) -> float
|
||||||
|
// def __divmod__: fn(value: float, /) -> tuple[float, float]
|
||||||
|
|
||||||
|
def __pow__: fn(value: int, /) -> float
|
||||||
|
// positive __value -> float; negative __value -> complex
|
||||||
|
// return type must be Any as `float | complex` causes too many false-positive errors
|
||||||
|
def __pow__: fn(value: float, /) -> Any
|
||||||
|
def __radd__: fn(value: float, /) -> float
|
||||||
|
def __rsub__: fn(value: float, /) -> float
|
||||||
|
def __rmul__: fn(value: float, /) -> float
|
||||||
|
def __rfloordiv__: fn(value: float, /) -> float
|
||||||
|
def __rtruediv__: fn(value: float, /) -> float
|
||||||
|
def __rmod__: fn(value: float, /) -> float
|
||||||
|
// def __rdivmod__: fn(value: float, /) -> tuple[float, float]
|
||||||
|
// def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float
|
||||||
|
// def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex
|
||||||
|
// Returning `complex` for the general case gives too many false-positive errors.
|
||||||
|
// def __rpow__: fn(value: float, mod: None = None, /) -> Any
|
||||||
|
// def __getnewargs__: fn() -> tuple[float]
|
||||||
|
def __trunc__: fn() -> int
|
||||||
|
def __ceil__: fn() -> int
|
||||||
|
def __floor__: fn() -> int
|
||||||
|
def __round__: fn(ndigits: None?, /) -> int
|
||||||
|
def __round__: fn(ndigits: int, /) -> float
|
||||||
|
def __eq__: fn(value: object, /) -> bool
|
||||||
|
def __ne__: fn(value: object, /) -> bool
|
||||||
|
def __lt__: fn(value: float, /) -> bool
|
||||||
|
def __le__: fn(value: float, /) -> bool
|
||||||
|
def __gt__: fn(value: float, /) -> bool
|
||||||
|
def __ge__: fn(value: float, /) -> bool
|
||||||
|
def __neg__: fn() -> float
|
||||||
|
def __pos__: fn() -> float
|
||||||
|
def __int__: fn() -> int
|
||||||
|
def __float__: fn() -> float
|
||||||
|
def __abs__: fn() -> float
|
||||||
|
def __hash__: fn() -> int
|
||||||
|
def __bool__: fn() -> bool
|
||||||
|
def __format__: fn(format_spec: str, /) -> str
|
||||||
|
}
|
||||||
|
|
||||||
|
extend int {
|
||||||
|
prop real: int
|
||||||
|
prop imag: int
|
||||||
|
prop numerator: int
|
||||||
|
prop denominator: int
|
||||||
|
def conjugate: fn() -> int
|
||||||
|
def bit_length: fn() -> int
|
||||||
|
def bit_count: fn() -> int
|
||||||
|
// def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes
|
||||||
|
|
||||||
|
def __add__: fn(value: int, /) -> int
|
||||||
|
def __sub__: fn(value: int, /) -> int
|
||||||
|
def __mul__: fn(value: int, /) -> int
|
||||||
|
def __floordiv__: fn(value: int, /) -> int
|
||||||
|
def __truediv__: fn(value: int, /) -> float
|
||||||
|
def __mod__: fn(value: int, /) -> int
|
||||||
|
// def __divmod__: fn(value: int, /) -> tuple[int, int]
|
||||||
|
def __radd__: fn(value: int, /) -> int
|
||||||
|
def __rsub__: fn(value: int, /) -> int
|
||||||
|
def __rmul__: fn(value: int, /) -> int
|
||||||
|
def __rfloordiv__: fn(value: int, /) -> int
|
||||||
|
def __rtruediv__: fn(value: int, /) -> float
|
||||||
|
def __rmod__: fn(value: int, /) -> int
|
||||||
|
// def __rdivmod__: fn(value: int, /) -> tuple[int, int]
|
||||||
|
def __pow__: fn(value: int, /) -> int
|
||||||
|
// def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int
|
||||||
|
// def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float
|
||||||
|
// positive __value -> int; negative __value -> float
|
||||||
|
// return type must be Any as `int | float` causes too many false-positive errors
|
||||||
|
// def __pow__: fn(value: int, mod: None = None, /) -> Any
|
||||||
|
// def __pow__: fn(value: int, mod: int, /) -> int
|
||||||
|
def __rpow__: fn(value: int, /) -> Any
|
||||||
|
def __and__: fn(value: int, /) -> int
|
||||||
|
def __or__: fn(value: int, /) -> int
|
||||||
|
def __xor__: fn(value: int, /) -> int
|
||||||
|
def __lshift__: fn(value: int, /) -> int
|
||||||
|
def __rshift__: fn(value: int, /) -> int
|
||||||
|
def __rand__: fn(value: int, /) -> int
|
||||||
|
def __ror__: fn(value: int, /) -> int
|
||||||
|
def __rxor__: fn(value: int, /) -> int
|
||||||
|
def __rlshift__: fn(value: int, /) -> int
|
||||||
|
def __rrshift__: fn(value: int, /) -> int
|
||||||
|
def __neg__: fn() -> int
|
||||||
|
def __pos__: fn() -> int
|
||||||
|
def __invert__: fn() -> int
|
||||||
|
def __trunc__: fn() -> int
|
||||||
|
def __ceil__: fn() -> int
|
||||||
|
def __floor__: fn() -> int
|
||||||
|
def __round__: fn(ndigits: None?, /) -> int
|
||||||
|
def __round__: fn(ndigits: int, /) -> int
|
||||||
|
|
||||||
|
// def __getnewargs__: fn() -> tuple[int]
|
||||||
|
def __eq__: fn(value: object, /) -> bool
|
||||||
|
def __ne__: fn(value: object, /) -> bool
|
||||||
|
def __lt__: fn(value: int, /) -> bool
|
||||||
|
def __le__: fn(value: int, /) -> bool
|
||||||
|
def __gt__: fn(value: int, /) -> bool
|
||||||
|
def __ge__: fn(value: int, /) -> bool
|
||||||
|
def __float__: fn() -> float
|
||||||
|
def __int__: fn() -> int
|
||||||
|
def __abs__: fn() -> int
|
||||||
|
def __hash__: fn() -> int
|
||||||
|
def __bool__: fn() -> bool
|
||||||
|
def __index__: fn() -> int
|
||||||
|
def __format__: fn(format_spec: str, /) -> str
|
||||||
|
}
|
||||||
|
|
||||||
|
extend list[T] {
|
||||||
|
def copy: fn () -> list[T]
|
||||||
|
def append: fn (object: T, /) -> None
|
||||||
|
def extend: fn (iterable: list[T], /) -> None
|
||||||
|
def pop: fn (index: int?, /) -> T
|
||||||
|
def index: fn (value: T, start: int?, stop: int?, /) -> int
|
||||||
|
def count: fn (value: T, /) -> int
|
||||||
|
def insert: fn (index: int, object: T, /) -> None
|
||||||
|
def remove: fn (value: T, /) -> None
|
||||||
|
def sort: fn (*, reverse: bool?) -> None
|
||||||
|
def __len__: fn () -> int
|
||||||
|
// def __iter__: fn () -> Iterator[T]
|
||||||
|
def __getitem__: fn (i: int, /) -> T
|
||||||
|
//__getitem__: fn (s: slice, /) -> list[T]
|
||||||
|
def __setitem__: fn (key: int, value: T, /) -> None
|
||||||
|
//__setitem__: fn (key: slice, value: list[T], /) -> None
|
||||||
|
def __delitem__: fn (key: int, /) -> None
|
||||||
|
// def __delitem__: fn (key: slice, /) -> None
|
||||||
|
// def __add__: fn[S <: T] (value: list[S], /) -> list[T]
|
||||||
|
def __add__: fn (value: list[T], /) -> list[T]
|
||||||
|
def __iadd__: fn (value: list[T], /) -> list[T]
|
||||||
|
def __mul__: fn (value: int, /) -> list[T]
|
||||||
|
def __rmul__: fn (value: int, /) -> list[T]
|
||||||
|
def __imul__: fn (value: int, /) -> list[T]
|
||||||
|
def __contains__: fn (key: object, /) -> bool
|
||||||
|
// def __reversed__: fn (self) -> Iterator[_T]
|
||||||
|
def __gt__: fn (value: list[T], /) -> bool
|
||||||
|
def __ge__: fn (value: list[T], /) -> bool
|
||||||
|
def __lt__: fn (value: list[T], /) -> bool
|
||||||
|
def __le__: fn (value: list[T], /) -> bool
|
||||||
|
def __eq__: fn (value: object, /) -> bool
|
||||||
|
|
||||||
|
prop __doc__: str
|
||||||
|
}
|
||||||
40
midas/checker/builtins.py
Normal file
40
midas/checker/builtins.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from midas.checker.types import (
|
||||||
|
BaseType,
|
||||||
|
GenericType,
|
||||||
|
TopType,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
|
||||||
|
|
||||||
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
|
"float": {"int"},
|
||||||
|
"int": {"bool"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def define_builtins(reg: TypesRegistry):
|
||||||
|
"""Define builtin types and operations"""
|
||||||
|
any = reg.define_type("Any", TopType())
|
||||||
|
unit = reg.define_type("None", UnitType())
|
||||||
|
object = reg.define_type("object", BaseType(name="object"))
|
||||||
|
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||||
|
int = reg.define_type("int", BaseType(name="int"))
|
||||||
|
float = reg.define_type("float", BaseType(name="float"))
|
||||||
|
str = reg.define_type("str", BaseType(name="str"))
|
||||||
|
|
||||||
|
list = reg.define_type(
|
||||||
|
"list",
|
||||||
|
GenericType(
|
||||||
|
name="list",
|
||||||
|
params=[TypeVar(name="T", bound=None)],
|
||||||
|
body=BaseType(name="list"),
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -1,549 +1,35 @@
|
|||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
from midas.checker.diagnostic import Diagnostic
|
||||||
import midas.ast.python as p
|
from midas.checker.midas import MidasTyper
|
||||||
from midas.ast.location import Location
|
from midas.checker.python import PythonTyper
|
||||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.reporter import Reporter
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
|
||||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
|
||||||
from midas.lexer.midas import MidasLexer
|
|
||||||
from midas.lexer.token import Token
|
|
||||||
from midas.parser.midas import MidasParser
|
|
||||||
from midas.resolver.midas import MidasResolver
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
class TypeChecker:
|
||||||
pass
|
def __init__(self):
|
||||||
|
self.types: TypesRegistry = TypesRegistry()
|
||||||
|
self.reporter: Reporter = Reporter()
|
||||||
|
|
||||||
|
self.midas_typer = MidasTyper(self.types, self.reporter)
|
||||||
|
self.python_typer = PythonTyper(self.types, self.reporter)
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
def import_midas(self, path: Path):
|
||||||
class MappedArgument:
|
source: str = path.read_text()
|
||||||
expr: p.Expr
|
return self.import_midas_source(source, path=str(path))
|
||||||
type: Type
|
|
||||||
argument: Function.Argument
|
|
||||||
|
|
||||||
|
def import_midas_source(self, source: str, path: Optional[str] = None):
|
||||||
|
self.midas_typer.process(source, path)
|
||||||
|
|
||||||
class Checker(
|
def type_check(self, path: Path):
|
||||||
p.Stmt.Visitor[None],
|
source: str = path.read_text()
|
||||||
p.Expr.Visitor[Type],
|
return self.type_check_source(source, path=str(path))
|
||||||
p.MidasType.Visitor[Type],
|
|
||||||
):
|
|
||||||
"""A type checker which can use custom type definitions"""
|
|
||||||
|
|
||||||
def __init__(self, locals: dict[p.Expr, int], file_path: Path):
|
def type_check_source(self, source: str, path: Optional[str] = None):
|
||||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
self.python_typer.process(source, path)
|
||||||
self.file_path: Path = file_path
|
|
||||||
self.ctx: MidasResolver = MidasResolver()
|
|
||||||
self.global_env: Environment = Environment()
|
|
||||||
self.env: Environment = self.global_env
|
|
||||||
self.locals: dict[p.Expr, int] = locals
|
|
||||||
self.diagnostics: list[Diagnostic] = []
|
|
||||||
|
|
||||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
@property
|
||||||
self.diagnostics.append(
|
def diagnostics(self) -> list[Diagnostic]:
|
||||||
Diagnostic(
|
return self.reporter.diagnostics
|
||||||
file_path=self.file_path,
|
|
||||||
location=location,
|
|
||||||
type=type,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def error(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.ERROR,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def warning(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.WARNING,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def info(self, location: Location, message: str):
|
|
||||||
self.diagnostic(
|
|
||||||
type=DiagnosticType.INFO,
|
|
||||||
location=location,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
|
|
||||||
def evaluate(self, expr: p.Expr) -> Type:
|
|
||||||
"""Evaluate the type of an expression
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr (p.Expr): the expression to evaluate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the type of the given expression
|
|
||||||
"""
|
|
||||||
return expr.accept(self)
|
|
||||||
|
|
||||||
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
|
||||||
"""Evaluate a sequence of statements
|
|
||||||
|
|
||||||
Args:
|
|
||||||
block (list[p.Stmt]): the statements to evaluate
|
|
||||||
env (Environment): the environment in which to evaluate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: whether a return statement is present in the block
|
|
||||||
"""
|
|
||||||
previous_env: Environment = self.env
|
|
||||||
self.env = env
|
|
||||||
returned: bool = False
|
|
||||||
for i, stmt in enumerate(block):
|
|
||||||
try:
|
|
||||||
stmt.accept(self)
|
|
||||||
except ReturnException:
|
|
||||||
returned = True
|
|
||||||
if i < len(block) - 1:
|
|
||||||
self.warning(block[i + 1].location, "Unreachable statement")
|
|
||||||
break
|
|
||||||
self.env = previous_env
|
|
||||||
return returned
|
|
||||||
|
|
||||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
|
||||||
"""Type check a sequence of statements and returns diagnostics
|
|
||||||
|
|
||||||
Args:
|
|
||||||
statements (list[p.Stmt]): the statements to evaluate and check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
|
||||||
"""
|
|
||||||
self.diagnostics = []
|
|
||||||
for stmt in statements:
|
|
||||||
stmt.accept(self)
|
|
||||||
|
|
||||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
|
||||||
return self.diagnostics
|
|
||||||
|
|
||||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
|
||||||
"""Look up a variable in the environment it was declared
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the variable
|
|
||||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type]: the type of the variable, or None if it was not found
|
|
||||||
"""
|
|
||||||
distance: Optional[int] = self.locals.get(expr)
|
|
||||||
if distance is not None:
|
|
||||||
return self.env.get_at(distance, name)
|
|
||||||
return self.global_env.get(name)
|
|
||||||
|
|
||||||
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
|
|
||||||
"""Parse a Midas import statement
|
|
||||||
|
|
||||||
The statement should be written as `midas.using("path/to/types.midas")`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr (p.CallExpr): the import call expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Path]: the path to the imported file, or None if the expression is malformed
|
|
||||||
"""
|
|
||||||
match expr:
|
|
||||||
case p.CallExpr(
|
|
||||||
callee=p.GetExpr(
|
|
||||||
object=p.VariableExpr(name="midas"),
|
|
||||||
name="using",
|
|
||||||
),
|
|
||||||
arguments=[
|
|
||||||
p.LiteralExpr(value=path),
|
|
||||||
],
|
|
||||||
):
|
|
||||||
return Path(path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def import_midas(self, path: Path) -> None:
|
|
||||||
"""Import Midas definitions from a path
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (Path): the import path
|
|
||||||
"""
|
|
||||||
self.logger.debug(f"Importing type definitions from {path}")
|
|
||||||
path = (self.file_path.parent / path).resolve()
|
|
||||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
|
||||||
tokens: list[Token] = lexer.process()
|
|
||||||
parser: MidasParser = MidasParser(tokens)
|
|
||||||
stmts: list[m.Stmt] = parser.parse()
|
|
||||||
self.ctx.resolve(stmts)
|
|
||||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
|
||||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
|
||||||
|
|
||||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
|
||||||
self.evaluate(stmt.expr)
|
|
||||||
|
|
||||||
def visit_function(self, stmt: p.Function) -> None:
|
|
||||||
env: Environment = Environment(self.env)
|
|
||||||
pos_args: list[Function.Argument] = []
|
|
||||||
args: list[Function.Argument] = []
|
|
||||||
kw_args: list[Function.Argument] = []
|
|
||||||
|
|
||||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
|
||||||
if arg.type is not None:
|
|
||||||
return arg.type.accept(self)
|
|
||||||
if arg.default is not None:
|
|
||||||
return arg.default.accept(self)
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
for arg in stmt.posonlyargs:
|
|
||||||
pos_args.append(
|
|
||||||
Function.Argument(
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for arg in stmt.args:
|
|
||||||
args.append(
|
|
||||||
Function.Argument(
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for arg in stmt.kwonlyargs:
|
|
||||||
kw_args.append(
|
|
||||||
Function.Argument(
|
|
||||||
name=arg.name,
|
|
||||||
type=eval_arg_type(arg),
|
|
||||||
required=arg.default is None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for arg in pos_args + args + kw_args:
|
|
||||||
env.define(arg.name, arg.type)
|
|
||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
|
||||||
if stmt.returns is not None:
|
|
||||||
returns_hint = stmt.returns.accept(self)
|
|
||||||
# Early define to handle simple fully-typed recursion
|
|
||||||
inside_function: Function = Function(
|
|
||||||
name=stmt.name,
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=returns_hint,
|
|
||||||
)
|
|
||||||
self.env.define(stmt.name, inside_function)
|
|
||||||
|
|
||||||
returned: bool = self.evaluate_block(stmt.body, env)
|
|
||||||
inferred_return: Type = UnknownType()
|
|
||||||
if not returned:
|
|
||||||
env.return_types.append(UnitType())
|
|
||||||
return_types: set[Type] = set(env.return_types)
|
|
||||||
if len(return_types) == 1:
|
|
||||||
inferred_return = list(return_types)[0]
|
|
||||||
elif len(return_types) > 1:
|
|
||||||
self.error(
|
|
||||||
stmt.location,
|
|
||||||
f"Mixed return types: {env.return_types}",
|
|
||||||
)
|
|
||||||
|
|
||||||
returns: Type = UnknownType()
|
|
||||||
if returns_hint is not None:
|
|
||||||
assert stmt.returns is not None
|
|
||||||
returns = returns_hint
|
|
||||||
if returns != inferred_return:
|
|
||||||
self.error(
|
|
||||||
stmt.returns.location,
|
|
||||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
returns = inferred_return
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
function: Function = Function(
|
|
||||||
name=stmt.name,
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=returns,
|
|
||||||
)
|
|
||||||
self.env.define(stmt.name, function)
|
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
|
||||||
# TODO check not yet defined locally
|
|
||||||
type: Type = stmt.type.accept(self)
|
|
||||||
self.env.define(stmt.name, type)
|
|
||||||
|
|
||||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
|
||||||
value: Type = self.evaluate(stmt.value)
|
|
||||||
for target in stmt.targets:
|
|
||||||
if not isinstance(target, p.VariableExpr):
|
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
|
||||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
|
||||||
continue
|
|
||||||
name: str = target.name
|
|
||||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
|
||||||
|
|
||||||
if var_type is None:
|
|
||||||
self.env.define(name, value)
|
|
||||||
else:
|
|
||||||
# TODO: implement real comparison method
|
|
||||||
if var_type != value:
|
|
||||||
self.error(
|
|
||||||
stmt.location,
|
|
||||||
f"Cannot assign {value} to {name} of type {var_type}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
|
||||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
|
||||||
self.env.return_types.append(type)
|
|
||||||
raise ReturnException()
|
|
||||||
|
|
||||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
|
||||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
|
||||||
# For example:
|
|
||||||
# if (m := 1 + 1) < 2:
|
|
||||||
# ...
|
|
||||||
# print(m) # <- m is still defined
|
|
||||||
test_type: Type = stmt.test.accept(self)
|
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
|
||||||
if test_type != self.ctx.get_type("bool"):
|
|
||||||
self.error(
|
|
||||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
env: Environment = Environment(self.env)
|
|
||||||
body_returned: bool = self.evaluate_block(stmt.body, env)
|
|
||||||
else_returned: bool = self.evaluate_block(stmt.orelse, env)
|
|
||||||
self.env.return_types.extend(env.return_types)
|
|
||||||
if body_returned and else_returned:
|
|
||||||
raise ReturnException()
|
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
|
||||||
if method is None:
|
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
|
||||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
|
||||||
return UnknownType()
|
|
||||||
left: Type = self.evaluate(expr.left)
|
|
||||||
right: Type = self.evaluate(expr.right)
|
|
||||||
|
|
||||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
|
||||||
if result is None:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
|
||||||
if method is None:
|
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
|
||||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
|
||||||
return UnknownType()
|
|
||||||
left: Type = self.evaluate(expr.left)
|
|
||||||
right: Type = self.evaluate(expr.right)
|
|
||||||
|
|
||||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
|
||||||
if result is None:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
|
||||||
if path := self.parse_midas_import(expr):
|
|
||||||
self.import_midas(path)
|
|
||||||
return UnknownType()
|
|
||||||
callee: Type = self.evaluate(expr.callee)
|
|
||||||
if not isinstance(callee, Function):
|
|
||||||
self.error(expr.callee.location, "Callee is not a function")
|
|
||||||
return UnknownType()
|
|
||||||
function: Function = callee
|
|
||||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
|
||||||
for arg in mapped:
|
|
||||||
if arg.type != arg.argument.type:
|
|
||||||
self.error(
|
|
||||||
arg.expr.location,
|
|
||||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
|
||||||
)
|
|
||||||
return function.returns
|
|
||||||
|
|
||||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
|
||||||
match expr.value:
|
|
||||||
case bool(): # Must be before int
|
|
||||||
return self.ctx.get_type("bool")
|
|
||||||
case int():
|
|
||||||
return self.ctx.get_type("int")
|
|
||||||
case float():
|
|
||||||
return self.ctx.get_type("float")
|
|
||||||
case str():
|
|
||||||
return self.ctx.get_type("str")
|
|
||||||
case _:
|
|
||||||
self.warning(expr.location, f"Unknown literal {expr}")
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
|
||||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
|
||||||
return expr.type.accept(self)
|
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
|
||||||
test_type: Type = expr.test.accept(self)
|
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
|
||||||
if test_type != self.ctx.get_type("bool"):
|
|
||||||
self.error(
|
|
||||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
true_type: Type = expr.if_true.accept(self)
|
|
||||||
false_type: Type = expr.if_false.accept(self)
|
|
||||||
if true_type != false_type:
|
|
||||||
self.error(
|
|
||||||
expr.location,
|
|
||||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
|
||||||
)
|
|
||||||
return UnknownType()
|
|
||||||
return true_type
|
|
||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
|
||||||
return self.ctx.get_type(node.base)
|
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
|
||||||
|
|
||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
|
||||||
|
|
||||||
def map_call_arguments(
|
|
||||||
self, function: Function, call: p.CallExpr
|
|
||||||
) -> list[MappedArgument]:
|
|
||||||
"""Map call arguments to function parameters as defined in its signature
|
|
||||||
|
|
||||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
|
||||||
with the arguments passed at the call site
|
|
||||||
|
|
||||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function (Function): the function definition
|
|
||||||
call (p.CallExpr): the call expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[MappedArgument]: the list of mapped arguments
|
|
||||||
"""
|
|
||||||
positional: list[tuple[p.Expr, Type]] = [
|
|
||||||
(arg, self.evaluate(arg)) for arg in call.arguments
|
|
||||||
]
|
|
||||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
|
||||||
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
|
|
||||||
}
|
|
||||||
set_args: set[str] = set()
|
|
||||||
|
|
||||||
required_positional: list[str] = [
|
|
||||||
arg.name for arg in function.pos_args + function.args if arg.required
|
|
||||||
]
|
|
||||||
required_keyword: list[str] = [
|
|
||||||
arg.name for arg in function.kw_args if arg.required
|
|
||||||
]
|
|
||||||
|
|
||||||
mapped: list[MappedArgument] = []
|
|
||||||
|
|
||||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
|
||||||
mixed_params: list[Function.Argument] = list(function.args)
|
|
||||||
kw_params: dict[str, Function.Argument] = {
|
|
||||||
arg.name: arg for arg in function.kw_args
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
|
||||||
for arg in positional:
|
|
||||||
param: Function.Argument
|
|
||||||
if len(pos_params) != 0:
|
|
||||||
param = pos_params.pop(0)
|
|
||||||
elif len(mixed_params) != 0:
|
|
||||||
param = mixed_params.pop(0)
|
|
||||||
else:
|
|
||||||
self.error(arg[0].location, "Too many positional arguments")
|
|
||||||
break
|
|
||||||
name: str = param.name
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
|
||||||
for name, arg in keywords.items():
|
|
||||||
param: Function.Argument
|
|
||||||
if name not in kw_params:
|
|
||||||
if name in set_args:
|
|
||||||
self.error(
|
|
||||||
arg[0].location, f"Multiple values for argument '{name}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
|
||||||
continue
|
|
||||||
param = kw_params.pop(name)
|
|
||||||
if name in required_positional:
|
|
||||||
required_positional.remove(name)
|
|
||||||
if name in required_keyword:
|
|
||||||
required_keyword.remove(name)
|
|
||||||
set_args.add(name)
|
|
||||||
mapped.append(
|
|
||||||
MappedArgument(
|
|
||||||
expr=arg[0],
|
|
||||||
type=arg[1],
|
|
||||||
argument=param,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def join_args(args: list[str]) -> str:
|
|
||||||
args = list(map(lambda a: f"'{a}'", args))
|
|
||||||
if len(args) == 0:
|
|
||||||
return ""
|
|
||||||
if len(args) == 1:
|
|
||||||
return args[0]
|
|
||||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
|
||||||
|
|
||||||
if len(required_positional) != 0:
|
|
||||||
plural: str = "" if len(required_positional) == 1 else "s"
|
|
||||||
args: str = join_args(required_positional)
|
|
||||||
self.error(
|
|
||||||
call.location,
|
|
||||||
f"Missing required positional argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(required_keyword) != 0:
|
|
||||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
|
||||||
args: str = join_args(required_keyword)
|
|
||||||
self.error(
|
|
||||||
call.location,
|
|
||||||
f"Missing required keyword argument{plural}: {args}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return mapped
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
@@ -14,12 +13,13 @@ class DiagnosticType(StrEnum):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Diagnostic:
|
class Diagnostic:
|
||||||
file_path: Path
|
file_path: Optional[str]
|
||||||
location: Location
|
location: Location
|
||||||
type: DiagnosticType
|
type: DiagnosticType
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
def __str__(self) -> str:
|
@property
|
||||||
|
def location_str(self) -> str:
|
||||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||||
end_loc: Optional[str] = ""
|
end_loc: Optional[str] = ""
|
||||||
if (
|
if (
|
||||||
@@ -27,7 +27,16 @@ class Diagnostic:
|
|||||||
and self.location.end_col_offset is not None
|
and self.location.end_col_offset is not None
|
||||||
):
|
):
|
||||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||||
loc: str = (
|
|
||||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
loc: str = ""
|
||||||
)
|
if self.file_path is not None:
|
||||||
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
loc += f" in {self.file_path}"
|
||||||
|
if end_loc is None:
|
||||||
|
loc += f" at {start_loc}"
|
||||||
|
else:
|
||||||
|
loc += f" from {start_loc} to {end_loc}"
|
||||||
|
|
||||||
|
return f"{self.type}{loc}"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.location_str}: {self.message}"
|
||||||
|
|||||||
206
midas/checker/midas.py
Normal file
206
midas/checker/midas.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.builtins import define_builtins
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
ComplexType,
|
||||||
|
ExtensionType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
from midas.lexer.midas import MidasLexer
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
from midas.parser.midas import MidasParser
|
||||||
|
|
||||||
|
|
||||||
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||||
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
|
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||||
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
self._current_name: Optional[str] = None
|
||||||
|
|
||||||
|
define_builtins(self.types)
|
||||||
|
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||||
|
self.process(builtins_path.read_text(), str(builtins_path))
|
||||||
|
|
||||||
|
def process(self, source: str, path: Optional[str]):
|
||||||
|
self.reporter = self.reporter.for_file(path)
|
||||||
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser: MidasParser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
for error in parser.errors:
|
||||||
|
self.reporter.error(error.token.get_location(), error.message)
|
||||||
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def get_type(self, name: str) -> Type:
|
||||||
|
"""Get a type from its name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: if the type is not defined
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type
|
||||||
|
"""
|
||||||
|
if name in self._local_variables:
|
||||||
|
return self._local_variables[name]
|
||||||
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
|
"""Process a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stmts (list[m.Stmt]): the statements
|
||||||
|
"""
|
||||||
|
for stmt in stmts:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
|
name: str = stmt.name.lexeme
|
||||||
|
self._current_name = name
|
||||||
|
params: list[TypeVar] = self._resolve_type_params(stmt.params)
|
||||||
|
|
||||||
|
type: Type = stmt.type.accept(self)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = GenericType(name=name, params=params, body=type)
|
||||||
|
else:
|
||||||
|
type = AliasType(name=name, type=type)
|
||||||
|
self.types.define_type(name, type)
|
||||||
|
self._local_variables.clear()
|
||||||
|
self._current_name = None
|
||||||
|
|
||||||
|
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
||||||
|
|
||||||
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
|
self._resolve_type_params(stmt.params)
|
||||||
|
base_name: str = stmt.name.lexeme
|
||||||
|
try:
|
||||||
|
_ = self.get_type(base_name)
|
||||||
|
except NameError:
|
||||||
|
self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'")
|
||||||
|
|
||||||
|
for member in stmt.members:
|
||||||
|
member_type: Type = member.type.accept(self)
|
||||||
|
self.types.define_member(
|
||||||
|
base_name,
|
||||||
|
member.name.lexeme,
|
||||||
|
member_type,
|
||||||
|
member.kind == m.MemberKind.METHOD,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
|
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||||
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||||
|
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
||||||
|
|
||||||
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
|
name: str = type.name.lexeme
|
||||||
|
try:
|
||||||
|
return self.get_type(name)
|
||||||
|
except NameError:
|
||||||
|
msg: str = f"Undefined type {name}"
|
||||||
|
if self._current_name == name:
|
||||||
|
msg += ". Recursive types are not supported, use an extend block"
|
||||||
|
self.reporter.error(type.name.get_location(), msg)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||||
|
type_: Type = type.type.accept(self)
|
||||||
|
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||||
|
try:
|
||||||
|
return self.types.apply_generic(type_, args)
|
||||||
|
except Exception as e:
|
||||||
|
self.reporter.error(type.location, f"Cannot apply generic type: {e}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
|
type_: Type = type.type.accept(self)
|
||||||
|
type.constraint.accept(self)
|
||||||
|
# TODO
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||||
|
return ComplexType(
|
||||||
|
members={
|
||||||
|
member.name.lexeme: member.type.accept(self) for member in type.members
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_extension_type(self, type: m.ExtensionType) -> Type:
|
||||||
|
return ExtensionType(
|
||||||
|
base=type.base.accept(self),
|
||||||
|
extension=self.visit_complex_type(type.extension),
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||||
|
n_pos_args: int = len(type.pos_args)
|
||||||
|
n_args: int = len(type.args)
|
||||||
|
|
||||||
|
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||||
|
return Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||||
|
type=arg.type.accept(self),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Function(
|
||||||
|
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
||||||
|
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
||||||
|
kw_args=[
|
||||||
|
process_arg(arg, i + n_pos_args + n_args)
|
||||||
|
for i, arg in enumerate(type.kw_args)
|
||||||
|
],
|
||||||
|
returns=type.returns.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
|
vars: list[TypeVar] = []
|
||||||
|
for param in params:
|
||||||
|
name: str = param.name.lexeme
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
if param.bound is not None:
|
||||||
|
bound = param.bound.accept(self)
|
||||||
|
var = TypeVar(name=name, bound=bound)
|
||||||
|
self._local_variables[name] = var
|
||||||
|
vars.append(var)
|
||||||
|
return vars
|
||||||
@@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
|||||||
# ast.In: "__in__",
|
# ast.In: "__in__",
|
||||||
# ast.NotIn: "__notin__",
|
# ast.NotIn: "__notin__",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||||
|
ast.Invert: "__invert__",
|
||||||
|
# ast.Not: "",
|
||||||
|
ast.UAdd: "__pos__",
|
||||||
|
ast.USub: "__neg__",
|
||||||
|
}
|
||||||
|
|||||||
705
midas/checker/python.py
Normal file
705
midas/checker/python.py
Normal file
@@ -0,0 +1,705 @@
|
|||||||
|
import ast
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
|
from midas.checker.resolver import Resolver
|
||||||
|
from midas.checker.types import (
|
||||||
|
Function,
|
||||||
|
Type,
|
||||||
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
|
)
|
||||||
|
from midas.parser.python import PythonParser
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: p.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
class PythonTyper(
|
||||||
|
p.Stmt.Visitor[None],
|
||||||
|
p.Expr.Visitor[Type],
|
||||||
|
p.MidasType.Visitor[Type],
|
||||||
|
):
|
||||||
|
"""A type checker which can use custom type definitions"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
types: TypesRegistry,
|
||||||
|
reporter: Reporter,
|
||||||
|
):
|
||||||
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.global_env: Environment = Environment()
|
||||||
|
self.env: Environment = self.global_env
|
||||||
|
self.locals: dict[p.Expr, int] = {}
|
||||||
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
|
|
||||||
|
def process(self, source: str, path: Optional[str]):
|
||||||
|
self.reporter = self.reporter.for_file(path)
|
||||||
|
|
||||||
|
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||||
|
parser = PythonParser()
|
||||||
|
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||||
|
resolver = Resolver()
|
||||||
|
resolver.resolve(*stmts)
|
||||||
|
|
||||||
|
self.env = self.global_env
|
||||||
|
self.locals = resolver.locals
|
||||||
|
self.judgements = []
|
||||||
|
|
||||||
|
self.check(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: p.Expr) -> Type:
|
||||||
|
"""Evaluate the type of an expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr (p.Expr): the expression to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type of the given expression
|
||||||
|
"""
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
self.judgements.append((expr, type))
|
||||||
|
return type
|
||||||
|
|
||||||
|
def resolve_type_expr(self, expr: p.MidasType) -> Type:
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def process_stmt(self, stmt: p.Stmt) -> None:
|
||||||
|
stmt.accept(self)
|
||||||
|
|
||||||
|
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||||
|
"""Evaluate a sequence of statements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block (list[p.Stmt]): the statements to evaluate
|
||||||
|
env (Environment): the environment in which to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether a return statement is present in the block
|
||||||
|
"""
|
||||||
|
previous_env: Environment = self.env
|
||||||
|
self.env = env
|
||||||
|
returned: bool = False
|
||||||
|
for i, stmt in enumerate(block):
|
||||||
|
try:
|
||||||
|
self.process_stmt(stmt)
|
||||||
|
except ReturnException:
|
||||||
|
returned = True
|
||||||
|
if i < len(block) - 1:
|
||||||
|
self.reporter.warning(
|
||||||
|
block[i + 1].location, "Unreachable statement"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
self.env = previous_env
|
||||||
|
return returned
|
||||||
|
|
||||||
|
def check(self, statements: list[p.Stmt]) -> None:
|
||||||
|
"""Type check a sequence of statements and returns diagnostics
|
||||||
|
|
||||||
|
Args:
|
||||||
|
statements (list[p.Stmt]): the statements to evaluate and check
|
||||||
|
"""
|
||||||
|
for stmt in statements:
|
||||||
|
self.process_stmt(stmt)
|
||||||
|
|
||||||
|
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||||
|
|
||||||
|
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||||
|
"""Look up a variable in the environment it was declared
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the variable
|
||||||
|
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Type]: the type of the variable, or None if it was not found
|
||||||
|
"""
|
||||||
|
distance: Optional[int] = self.locals.get(expr)
|
||||||
|
if distance is not None:
|
||||||
|
return self.env.get_at(distance, name)
|
||||||
|
return self.global_env.get(name)
|
||||||
|
|
||||||
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
|
return self.types.is_subtype(type1, type2)
|
||||||
|
|
||||||
|
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||||
|
self.type_of(stmt.expr)
|
||||||
|
|
||||||
|
def visit_function(self, stmt: p.Function) -> None:
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
pos_args: list[Function.Argument] = []
|
||||||
|
args: list[Function.Argument] = []
|
||||||
|
kw_args: list[Function.Argument] = []
|
||||||
|
|
||||||
|
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||||
|
if arg.type is not None:
|
||||||
|
return self.resolve_type_expr(arg.type)
|
||||||
|
if arg.default is not None:
|
||||||
|
return self.type_of(arg.default)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
pos: int = 0
|
||||||
|
for arg in stmt.posonlyargs:
|
||||||
|
pos_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
for arg in stmt.args:
|
||||||
|
args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
for arg in stmt.kwonlyargs:
|
||||||
|
kw_args.append(
|
||||||
|
Function.Argument(
|
||||||
|
pos=pos, # not relevant
|
||||||
|
name=arg.name,
|
||||||
|
type=eval_arg_type(arg),
|
||||||
|
required=arg.default is None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
for arg in pos_args + args + kw_args:
|
||||||
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
|
returns_hint: Optional[Type] = None
|
||||||
|
if stmt.returns is not None:
|
||||||
|
returns_hint = self.resolve_type_expr(stmt.returns)
|
||||||
|
# Early define to handle simple fully-typed recursion
|
||||||
|
inside_function: Function = Function(
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns_hint,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, inside_function)
|
||||||
|
|
||||||
|
returned: bool = self.process_block(stmt.body, env)
|
||||||
|
inferred_return: Type = UnknownType()
|
||||||
|
if not returned:
|
||||||
|
env.return_types.append(UnitType())
|
||||||
|
return_types: list[Type] = self.types.reduce_types(env.return_types)
|
||||||
|
if len(return_types) == 1:
|
||||||
|
inferred_return = return_types[0]
|
||||||
|
elif len(return_types) > 1:
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.location,
|
||||||
|
f"Mixed return types: {return_types}",
|
||||||
|
)
|
||||||
|
|
||||||
|
returns: Type = UnknownType()
|
||||||
|
if returns_hint is not None:
|
||||||
|
assert stmt.returns is not None
|
||||||
|
returns = returns_hint
|
||||||
|
if returns != inferred_return:
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.returns.location,
|
||||||
|
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
returns = inferred_return
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
function: Function = Function(
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
)
|
||||||
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
|
# TODO check not yet defined locally
|
||||||
|
type: Type = self.resolve_type_expr(stmt.type)
|
||||||
|
self.env.define(stmt.name, type)
|
||||||
|
|
||||||
|
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||||
|
value_type: Type = self.type_of(stmt.value)
|
||||||
|
for target in stmt.targets:
|
||||||
|
self._assign(stmt.location, target, value_type)
|
||||||
|
|
||||||
|
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
||||||
|
match target:
|
||||||
|
case p.VariableExpr():
|
||||||
|
self._assign_var(location, target, value_type)
|
||||||
|
|
||||||
|
case p.GetExpr(object=object, name=name):
|
||||||
|
self._assign_attr(location, object, name, value_type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
if not isinstance(target, p.VariableExpr):
|
||||||
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
|
self.reporter.warning(
|
||||||
|
target.location, f"Unsupported assignment to {target}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
||||||
|
name: str = target.name
|
||||||
|
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||||
|
|
||||||
|
if var_type is None:
|
||||||
|
self.env.define(name, value_type)
|
||||||
|
else:
|
||||||
|
# S <: T
|
||||||
|
# Γ, x: T v: S
|
||||||
|
# x = v
|
||||||
|
if not self.is_subtype(value_type, var_type):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_attr(
|
||||||
|
self, location: Location, object: p.Expr, name: str, value_type: Type
|
||||||
|
):
|
||||||
|
object_type: Type = self.type_of(object)
|
||||||
|
member: Optional[Type] = self.types.lookup_member(object_type, name)
|
||||||
|
if member is None:
|
||||||
|
self.reporter.error(location, f"Unknown member '{name}' of {object_type}")
|
||||||
|
return
|
||||||
|
self.logger.debug(f"Member '{name}' of {object_type} has type {member}")
|
||||||
|
if not self.is_subtype(value_type, member):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
|
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
||||||
|
self.env.return_types.append(type)
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||||
|
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||||
|
# For example:
|
||||||
|
# if (m := 1 + 1) < 2:
|
||||||
|
# ...
|
||||||
|
# print(m) # <- m is still defined
|
||||||
|
test_type: Type = self.type_of(stmt.test)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.types.get_type("bool"):
|
||||||
|
self.reporter.error(
|
||||||
|
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
env: Environment = Environment(self.env)
|
||||||
|
body_returned: bool = self.process_block(stmt.body, env)
|
||||||
|
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||||
|
self.env.return_types.extend(env.return_types)
|
||||||
|
if body_returned and else_returned:
|
||||||
|
raise ReturnException()
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
|
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
|
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
|
def _visit_binary_expr(
|
||||||
|
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
|
||||||
|
) -> Type:
|
||||||
|
left: Type = self.type_of(left_expr)
|
||||||
|
right: Type = self.type_of(right_expr)
|
||||||
|
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
match operation:
|
||||||
|
case Function() as function:
|
||||||
|
if not self._check_arity(function, 1, 0, 0):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Wrong definition of binary operation. Expected function with 1 positional-only parameters, got {function}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
rhs: Function.Argument = function.pos_args[0]
|
||||||
|
if not self.is_subtype(right, rhs.type):
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Wrong type for right-hand side, expected {rhs.type}, got {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(location, f"Unsupported operation {operation}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||||
|
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
operand: Type = self.type_of(expr.right)
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} for {operand}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
match operation:
|
||||||
|
case Function() as function:
|
||||||
|
if not self._check_arity(function, 0, 0, 0):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong definition of unary operation. Expected function with 0 parameters, got {function}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operation {operation}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
|
callee: Type = self.type_of(expr.callee)
|
||||||
|
if not isinstance(callee, Function):
|
||||||
|
self.reporter.error(expr.callee.location, "Callee is not a function")
|
||||||
|
return UnknownType()
|
||||||
|
function: Function = callee
|
||||||
|
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||||
|
for arg in mapped:
|
||||||
|
if not self.is_subtype(arg.type, arg.argument.type):
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||||
|
object: Type = self.type_of(expr.object)
|
||||||
|
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||||
|
if member is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
self.logger.debug(f"Member '{expr.name}' of {object} has type {member}")
|
||||||
|
return member
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||||
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||||
|
type: Optional[Type] = self.look_up_variable(expr.name, expr)
|
||||||
|
if type is None:
|
||||||
|
self.logger.debug(f"Unknown variable {expr.name} in {self.env.flat_dict()}")
|
||||||
|
self.reporter.warning(expr.location, "Unknown variable")
|
||||||
|
return type or UnknownType()
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||||
|
left: Type = self.type_of(expr.left)
|
||||||
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
|
if self.is_subtype(left, right):
|
||||||
|
return right
|
||||||
|
if self.is_subtype(right, left):
|
||||||
|
return left
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Incompatible operand types, {left=} and {right=}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
|
return self.resolve_type_expr(expr.type)
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
|
test_type: Type = self.type_of(expr.test)
|
||||||
|
|
||||||
|
# TODO Allow subtypes or any type
|
||||||
|
if test_type != self.types.get_type("bool"):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
true_type: Type = self.type_of(expr.if_true)
|
||||||
|
false_type: Type = self.type_of(expr.if_false)
|
||||||
|
if self.is_subtype(true_type, false_type):
|
||||||
|
return false_type
|
||||||
|
if self.is_subtype(false_type, true_type):
|
||||||
|
return true_type
|
||||||
|
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> Type:
|
||||||
|
list_type: Type = self.types.get_type("list")
|
||||||
|
item_types: list[Type] = [self.type_of(item) for item in expr.items]
|
||||||
|
item_types = self.types.reduce_types(item_types)
|
||||||
|
|
||||||
|
if len(item_types) == 0:
|
||||||
|
return list_type
|
||||||
|
|
||||||
|
if len(item_types) == 1:
|
||||||
|
item_type: Type = item_types[0]
|
||||||
|
return self.types.apply_generic(list_type, [item_type])
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Heterogeneous list items: {item_types}",
|
||||||
|
)
|
||||||
|
return self.types.apply_generic(list_type, [UnknownType()])
|
||||||
|
|
||||||
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
||||||
|
object: Type = self.type_of(expr.object)
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined method __getitem__ on {object}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
index: Type = self.type_of(expr.index)
|
||||||
|
|
||||||
|
match operation:
|
||||||
|
case Function() as function:
|
||||||
|
if not self._check_arity(function, 1, 0, 0):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong definition of __getitem__. Expected function with 1 positional-only parameters, got {function}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
index_arg: Function.Argument = function.pos_args[0]
|
||||||
|
if not self.is_subtype(index, index_arg.type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong index type, expected {index_arg.type}, got {index}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operation {operation}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||||
|
base: Type
|
||||||
|
try:
|
||||||
|
base = self.types.get_type(node.base)
|
||||||
|
except NameError:
|
||||||
|
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
if node.param is not None:
|
||||||
|
param: Type = self.resolve_type_expr(node.param)
|
||||||
|
return self.types.apply_generic(base, [param])
|
||||||
|
return base
|
||||||
|
|
||||||
|
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
|
||||||
|
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_frame_column(self, node: p.FrameColumn) -> Type:
|
||||||
|
self.reporter.warning(node.location, "FrameColumn not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def visit_frame_type(self, node: p.FrameType) -> Type:
|
||||||
|
self.reporter.warning(node.location, "FrameType not yet supported")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self, function: Function, call: p.CallExpr
|
||||||
|
) -> list[MappedArgument]:
|
||||||
|
"""Map call arguments to function parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
call (p.CallExpr): the call expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[MappedArgument]: the list of mapped arguments
|
||||||
|
"""
|
||||||
|
positional: list[tuple[p.Expr, Type]] = [
|
||||||
|
(arg, self.type_of(arg)) for arg in call.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||||
|
name: (arg, self.type_of(arg)) for name, arg in call.keywords.items()
|
||||||
|
}
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
self.reporter.error(arg[0].location, "Too many positional arguments")
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
self.reporter.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
self.reporter.error(
|
||||||
|
call.location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapped
|
||||||
|
|
||||||
|
def _check_arity(
|
||||||
|
self,
|
||||||
|
function: Function,
|
||||||
|
n_pos: Optional[int] = None,
|
||||||
|
n_mixed: Optional[int] = None,
|
||||||
|
n_keyword: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
if n_pos is not None and len(function.pos_args) != n_pos:
|
||||||
|
return False
|
||||||
|
if n_mixed is not None and len(function.args) != n_mixed:
|
||||||
|
return False
|
||||||
|
if n_keyword is not None and len(function.kw_args) != n_keyword:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
347
midas/checker/registry.py
Normal file
347
midas/checker/registry.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ComplexType,
|
||||||
|
ExtensionType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnknownType,
|
||||||
|
substitute_typevars,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TypesRegistry:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||||
|
self._types: dict[str, Type] = {}
|
||||||
|
self._members: dict[str, dict[str, Type]] = {}
|
||||||
|
|
||||||
|
def get_type(self, name: str) -> Type:
|
||||||
|
"""Get a type from its name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: if the type is not defined
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the type
|
||||||
|
"""
|
||||||
|
if name in self._types:
|
||||||
|
return self._types[name]
|
||||||
|
raise NameError(f"Undefined type {name}")
|
||||||
|
|
||||||
|
def define_type(self, name: str, type: Type) -> Type:
|
||||||
|
"""Define a type in the registry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name of the type
|
||||||
|
type (Type): the type to define
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if a type is already defined with that name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the defined type
|
||||||
|
"""
|
||||||
|
if name in self._types:
|
||||||
|
raise ValueError(f"Type {name} already defined")
|
||||||
|
self._types[name] = type
|
||||||
|
return type
|
||||||
|
|
||||||
|
def define_member(
|
||||||
|
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
||||||
|
):
|
||||||
|
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
||||||
|
if member_name in members:
|
||||||
|
if not is_method:
|
||||||
|
self.logger.error(
|
||||||
|
f"Member '{member_name}' already defined for type {type_name}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
current: Type = members[member_name]
|
||||||
|
combined: Type
|
||||||
|
match current:
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||||
|
case _:
|
||||||
|
combined = OverloadedFunction(overloads=[current, member_type])
|
||||||
|
members[member_name] = combined
|
||||||
|
|
||||||
|
else:
|
||||||
|
members[member_name] = member_type
|
||||||
|
|
||||||
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
|
|
||||||
|
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type1 (Type): the potential subtype
|
||||||
|
type2 (Type): the potential supertype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether `type1` is a subtype of `type2`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if type1 == type2:
|
||||||
|
return True
|
||||||
|
|
||||||
|
match (type1, type2):
|
||||||
|
case (_, TopType()):
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (AliasType(type=base1), _):
|
||||||
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
|
case (BaseType(name=name1), BaseType(name=name2)):
|
||||||
|
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||||
|
|
||||||
|
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||||
|
for k, t in props2.items():
|
||||||
|
if k not in props1:
|
||||||
|
return False
|
||||||
|
if not self.is_subtype(props1[k], t):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (Function(), Function()):
|
||||||
|
return self.is_func_subtype(type1, type2)
|
||||||
|
|
||||||
|
case (TypeVar(bound=bound), _):
|
||||||
|
if bound is None:
|
||||||
|
return False
|
||||||
|
return self.is_subtype(bound, type2)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
# TODO: verify the logic in here
|
||||||
|
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||||
|
"""Check whether a function is a subtype of another
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func1 (Function): the potential function subtype
|
||||||
|
func2 (Function): the potential function supertype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether `func1` is a subtype of `func2`
|
||||||
|
"""
|
||||||
|
if not self.is_subtype(func1.returns, func2.returns):
|
||||||
|
return False
|
||||||
|
|
||||||
|
pos1: list[Function.Argument] = func1.pos_args
|
||||||
|
mixed1: list[Function.Argument] = func1.args
|
||||||
|
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||||
|
pos2: list[Function.Argument] = func2.pos_args
|
||||||
|
mixed2: list[Function.Argument] = func2.args
|
||||||
|
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||||
|
|
||||||
|
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||||
|
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||||
|
|
||||||
|
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||||
|
if not self.is_subtype(sub.type, sup.type):
|
||||||
|
return False
|
||||||
|
if not sup.required and sub.required:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
for arg1 in pos1:
|
||||||
|
arg2: Function.Argument
|
||||||
|
if arg1.pos < len(pos2):
|
||||||
|
arg2 = pos2[arg1.pos]
|
||||||
|
elif arg1.pos in mixed_by_pos:
|
||||||
|
arg2 = mixed_by_pos[arg1.pos]
|
||||||
|
elif not arg1.required:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not is_arg_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for name, arg1 in kw1.items():
|
||||||
|
arg2: Function.Argument
|
||||||
|
if name in kw2:
|
||||||
|
arg2 = kw2[name]
|
||||||
|
elif name in mixed_by_name:
|
||||||
|
arg2 = mixed_by_name[name]
|
||||||
|
elif not arg1.required:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if not is_arg_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for arg1 in mixed1:
|
||||||
|
pos_arg2: Optional[Function.Argument] = None
|
||||||
|
kw_arg2: Optional[Function.Argument] = None
|
||||||
|
if arg1.name in kw2:
|
||||||
|
kw_arg2 = kw2[arg1.name]
|
||||||
|
elif arg1.name in mixed_by_name:
|
||||||
|
kw_arg2 = mixed_by_name[arg1.name]
|
||||||
|
if arg1.pos < len(pos2):
|
||||||
|
pos_arg2 = pos2[arg1.pos]
|
||||||
|
elif arg1.pos in mixed_by_pos:
|
||||||
|
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||||
|
|
||||||
|
# No match in func2 and arg is required
|
||||||
|
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Matching keyword argument
|
||||||
|
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Matching positional argument
|
||||||
|
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||||
|
mixed_names: set[str] = {a.name for a in mixed1}
|
||||||
|
for arg2 in pos2:
|
||||||
|
if not arg2.required:
|
||||||
|
continue
|
||||||
|
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for name, arg2 in kw2.items():
|
||||||
|
if not arg2.required:
|
||||||
|
continue
|
||||||
|
if name not in kw1 and name not in mixed_names:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for arg2 in mixed2:
|
||||||
|
if arg2.required:
|
||||||
|
continue
|
||||||
|
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||||
|
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||||
|
if not pos_match or not kw_match:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||||
|
match type:
|
||||||
|
case AliasType(name=name, type=base):
|
||||||
|
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||||
|
|
||||||
|
case GenericType(name=name, params=type_vars, body=body):
|
||||||
|
n_args: int = len(args)
|
||||||
|
n_type_vars: int = len(type_vars)
|
||||||
|
if n_args < n_type_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
||||||
|
)
|
||||||
|
if n_args > n_type_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
||||||
|
)
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for arg, type_var in zip(args, type_vars):
|
||||||
|
if type_var.bound is not None and not self.is_subtype(
|
||||||
|
arg, type_var.bound
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
||||||
|
)
|
||||||
|
substitutions[type_var.name] = arg
|
||||||
|
return AppliedType(
|
||||||
|
name=name,
|
||||||
|
args=args,
|
||||||
|
body=substitute_typevars(body, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"{type} is not a generic type")
|
||||||
|
|
||||||
|
def reduce_types(self, types: list[Type]) -> list[Type]:
|
||||||
|
"""Reduce a list of types to remove subtypes and only keep the highest types
|
||||||
|
|
||||||
|
Args:
|
||||||
|
types (list[Type]): the types to reduce
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Type]: the reduced list of types
|
||||||
|
"""
|
||||||
|
|
||||||
|
reduced: bool = True
|
||||||
|
keep: list[int] = list(range(len(types)))
|
||||||
|
while reduced:
|
||||||
|
reduced = False
|
||||||
|
for i, i1 in enumerate(keep):
|
||||||
|
type1: Type = types[i1]
|
||||||
|
for i2 in keep[i + 1 :]:
|
||||||
|
type2 = types[i2]
|
||||||
|
if self.is_subtype(type1, type2):
|
||||||
|
keep.remove(i1)
|
||||||
|
elif self.is_subtype(type2, type1):
|
||||||
|
keep.remove(i2)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
reduced = True
|
||||||
|
break
|
||||||
|
return [types[i] for i in keep]
|
||||||
|
|
||||||
|
def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
|
||||||
|
match type:
|
||||||
|
case BaseType(name=name):
|
||||||
|
if name in self._members:
|
||||||
|
if member_name in self._members[name]:
|
||||||
|
return self._members[name][member_name]
|
||||||
|
return None
|
||||||
|
|
||||||
|
case AliasType(name=name, type=base):
|
||||||
|
if name in self._members:
|
||||||
|
if member_name in self._members[name]:
|
||||||
|
return self._members[name][member_name]
|
||||||
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
|
case AppliedType(name=name, body=body, args=args):
|
||||||
|
generic: Type = self.get_type(name)
|
||||||
|
|
||||||
|
if not isinstance(generic, GenericType):
|
||||||
|
raise ValueError("AppliedType not derived from a GenericType")
|
||||||
|
|
||||||
|
substitutions = {
|
||||||
|
type_var.name: arg for arg, type_var in zip(args, generic.params)
|
||||||
|
}
|
||||||
|
if name in self._members:
|
||||||
|
if member_name in self._members[name]:
|
||||||
|
member_type: Type = self._members[name][member_name]
|
||||||
|
return substitute_typevars(member_type, substitutions)
|
||||||
|
|
||||||
|
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||||
|
if member_type2 is not None:
|
||||||
|
member_type2 = substitute_typevars(member_type2, substitutions)
|
||||||
|
return member_type2
|
||||||
|
|
||||||
|
case ComplexType(members=members):
|
||||||
|
if member_name in members:
|
||||||
|
return members[member_name]
|
||||||
|
self.logger.debug(f"No member '{member_name}' in {type}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
||||||
|
if member_name in members:
|
||||||
|
return members[member_name]
|
||||||
|
self.logger.debug(
|
||||||
|
f"No member '{member_name}' on {type}, looking up in base"
|
||||||
|
)
|
||||||
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.logger.debug(f"Can't get member on {type}")
|
||||||
|
return None
|
||||||
63
midas/checker/reporter.py
Normal file
63
midas/checker/reporter.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
|
|
||||||
|
|
||||||
|
class Reporter:
|
||||||
|
def __init__(self):
|
||||||
|
self.diagnostics: list[Diagnostic] = []
|
||||||
|
|
||||||
|
def report(
|
||||||
|
self,
|
||||||
|
path: Optional[str],
|
||||||
|
type: DiagnosticType,
|
||||||
|
location: Location,
|
||||||
|
message: str,
|
||||||
|
):
|
||||||
|
self.diagnostics.append(
|
||||||
|
Diagnostic(
|
||||||
|
file_path=path,
|
||||||
|
location=location,
|
||||||
|
type=type,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||||
|
return FileReporter(self, path)
|
||||||
|
|
||||||
|
|
||||||
|
class FileReporter:
|
||||||
|
def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
|
||||||
|
self.base_reporter: Reporter = base_reporter
|
||||||
|
self.path: Optional[str] = path
|
||||||
|
|
||||||
|
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||||
|
return FileReporter(self.base_reporter, path)
|
||||||
|
|
||||||
|
def report(self, type: DiagnosticType, location: Location, message: str):
|
||||||
|
self.base_reporter.report(self.path, type, location, message)
|
||||||
|
|
||||||
|
def error(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.ERROR,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def warning(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.WARNING,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
def info(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.INFO,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
self.scopes: list[dict[str, bool]] = []
|
self.scopes: list[dict[str, bool]] = [{}]
|
||||||
|
|
||||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||||
"""Resolve the given statements or expressions"""
|
"""Resolve the given statements or expressions"""
|
||||||
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.locals[expr] = i
|
self.locals[expr] = i
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def is_defined(self, name: str) -> bool:
|
||||||
|
for scope in self.scopes:
|
||||||
|
if name in scope:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def resolve_function(self, function: p.Function) -> None:
|
def resolve_function(self, function: p.Function) -> None:
|
||||||
"""Resolve a function definition
|
"""Resolve a function definition
|
||||||
|
|
||||||
@@ -112,8 +118,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
match target:
|
match target:
|
||||||
case p.VariableExpr(name=name):
|
case p.VariableExpr(name=name):
|
||||||
self.resolve_local(target, name)
|
if not self.is_defined(name):
|
||||||
# TODO: declare if not found
|
self.declare(name)
|
||||||
|
self.define(name)
|
||||||
|
target.accept(self)
|
||||||
|
|
||||||
|
case p.GetExpr():
|
||||||
|
target.accept(self)
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported assignment to {target}")
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
@@ -174,10 +185,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.resolve(expr.left)
|
self.resolve(expr.left)
|
||||||
self.resolve(expr.right)
|
self.resolve(expr.right)
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
|
||||||
self.resolve(expr.value)
|
|
||||||
self.resolve(expr.object)
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||||
self.resolve(expr.expr)
|
self.resolve(expr.expr)
|
||||||
|
|
||||||
@@ -185,3 +192,11 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
self.resolve(expr.test)
|
self.resolve(expr.test)
|
||||||
self.resolve(expr.if_true)
|
self.resolve(expr.if_true)
|
||||||
self.resolve(expr.if_false)
|
self.resolve(expr.if_false)
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
self.resolve(item)
|
||||||
|
|
||||||
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
|
self.resolve(expr.object)
|
||||||
|
self.resolve(expr.index)
|
||||||
@@ -1,54 +1,225 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TopType:
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class BaseType:
|
class BaseType:
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class AliasType:
|
class AliasType:
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class UnknownType:
|
class UnknownType:
|
||||||
pass
|
def __str__(self) -> str:
|
||||||
|
return "<Unknown>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class UnitType:
|
class UnitType:
|
||||||
pass
|
def __str__(self) -> str:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Function:
|
class Function:
|
||||||
name: str
|
|
||||||
pos_args: list[Argument]
|
pos_args: list[Argument]
|
||||||
args: list[Argument]
|
args: list[Argument]
|
||||||
kw_args: list[Argument]
|
kw_args: list[Argument]
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
args: list[str] = []
|
||||||
|
if len(self.pos_args) != 0:
|
||||||
|
args += list(map(str, self.pos_args))
|
||||||
|
if len(self.args) + len(self.kw_args) != 0:
|
||||||
|
args.append("/")
|
||||||
|
|
||||||
|
if len(self.args) != 0:
|
||||||
|
args += list(map(str, self.args))
|
||||||
|
|
||||||
|
if len(self.kw_args) != 0:
|
||||||
|
if len(args) != 0:
|
||||||
|
args.append("*")
|
||||||
|
args += list(map(str, self.kw_args))
|
||||||
|
|
||||||
|
return f"({', '.join(args)}) -> {self.returns}"
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class Argument:
|
class Argument:
|
||||||
|
pos: int
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
required: bool
|
required: bool
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
opt: str = "" if self.required else "?"
|
||||||
|
return f"{self.name}: {self.type}{opt}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class OverloadedFunction:
|
||||||
|
overloads: list[Type]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return "<overloaded function>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class ComplexType:
|
class ComplexType:
|
||||||
properties: dict[str, Type]
|
members: dict[str, Type]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
|
||||||
|
return f"{{{', '.join(props)}}}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class UnionType:
|
class ExtensionType:
|
||||||
alternatives: list[Type]
|
base: Type
|
||||||
|
extension: ComplexType
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.base} & {self.extension}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypeVar:
|
||||||
|
name: str
|
||||||
|
bound: Optional[Type]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.bound is not None:
|
||||||
|
return f"{self.name} <: {self.bound}"
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class GenericType:
|
||||||
|
name: str
|
||||||
|
params: list[TypeVar]
|
||||||
|
body: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.name}[{', '.join(map(str, self.params))}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class AppliedType:
|
||||||
|
name: str
|
||||||
|
args: list[Type]
|
||||||
|
body: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||||
|
|
||||||
|
|
||||||
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
|
def sub_argument(arg: Function.Argument):
|
||||||
|
return Function.Argument(
|
||||||
|
pos=arg.pos,
|
||||||
|
name=arg.name,
|
||||||
|
type=substitute_typevars(arg.type, substitutions),
|
||||||
|
required=arg.required,
|
||||||
|
)
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case BaseType(name=name) if name in substitutions:
|
||||||
|
return substitutions[name]
|
||||||
|
|
||||||
|
case BaseType():
|
||||||
|
return type
|
||||||
|
|
||||||
|
case AliasType(name=name, type=type2):
|
||||||
|
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
||||||
|
|
||||||
|
case Function(
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=returns,
|
||||||
|
):
|
||||||
|
return Function(
|
||||||
|
pos_args=list(map(sub_argument, pos_args)),
|
||||||
|
args=list(map(sub_argument, args)),
|
||||||
|
kw_args=list(map(sub_argument, kw_args)),
|
||||||
|
returns=substitute_typevars(returns, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ComplexType(members=members):
|
||||||
|
members2: dict[str, Type] = {
|
||||||
|
name: substitute_typevars(prop, substitutions)
|
||||||
|
for name, prop in members.items()
|
||||||
|
}
|
||||||
|
return ComplexType(members=members2)
|
||||||
|
|
||||||
|
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
||||||
|
return ExtensionType(
|
||||||
|
base=substitute_typevars(base, substitutions),
|
||||||
|
extension=ComplexType(
|
||||||
|
members={
|
||||||
|
name: substitute_typevars(prop, substitutions)
|
||||||
|
for name, prop in members.items()
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case AppliedType(name=name, args=args, body=body):
|
||||||
|
return AppliedType(
|
||||||
|
name=name,
|
||||||
|
args=[substitute_typevars(arg, substitutions) for arg in args],
|
||||||
|
body=substitute_typevars(body, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case TypeVar(name=name):
|
||||||
|
if name in substitutions:
|
||||||
|
return substitutions[name]
|
||||||
|
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||||
|
|
||||||
|
case UnknownType() | UnitType():
|
||||||
|
return type
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
|
|
||||||
|
def unfold_type(type: Type) -> Type:
|
||||||
|
match type:
|
||||||
|
case AliasType(type=ref_type):
|
||||||
|
return unfold_type(ref_type)
|
||||||
|
case _:
|
||||||
|
return type
|
||||||
|
|
||||||
|
|
||||||
Type = (
|
Type = (
|
||||||
BaseType | AliasType | UnknownType | UnitType | Function | ComplexType | UnionType
|
TopType
|
||||||
|
| BaseType
|
||||||
|
| AliasType
|
||||||
|
| UnknownType
|
||||||
|
| UnitType
|
||||||
|
| Function
|
||||||
|
| OverloadedFunction
|
||||||
|
| ComplexType
|
||||||
|
| ExtensionType
|
||||||
|
| TypeVar
|
||||||
|
| GenericType
|
||||||
|
| AppliedType
|
||||||
)
|
)
|
||||||
|
|||||||
41
midas/cli/ansi.py
Normal file
41
midas/cli/ansi.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
class Ansi:
|
||||||
|
CTRL = "\x1b["
|
||||||
|
RESET = CTRL + "0m"
|
||||||
|
BOLD = CTRL + "1m"
|
||||||
|
DIM = CTRL + "2m"
|
||||||
|
ITALIC = CTRL + "3m"
|
||||||
|
UNDERLINE = CTRL + "4m"
|
||||||
|
|
||||||
|
BLACK = 0
|
||||||
|
RED = 1
|
||||||
|
GREEN = 2
|
||||||
|
YELLOW = 3
|
||||||
|
BLUE = 4
|
||||||
|
MAGENTA = 5
|
||||||
|
CYAN = 6
|
||||||
|
WHITE = 7
|
||||||
|
|
||||||
|
BRIGHT_BLACK = 60
|
||||||
|
BRIGHT_RED = 61
|
||||||
|
BRIGHT_GREEN = 62
|
||||||
|
BRIGHT_YELLOW = 63
|
||||||
|
BRIGHT_BLUE = 64
|
||||||
|
BRIGHT_MAGENTA = 65
|
||||||
|
BRIGHT_CYAN = 66
|
||||||
|
BRIGHT_WHITE = 67
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def FG(cls, col: int) -> str:
|
||||||
|
return f"{cls.CTRL}{30 + col}m"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def BG(cls, col: int) -> str:
|
||||||
|
return f"{cls.CTRL}{40 + col}m"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def FG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||||
|
return f"{cls.CTRL}38;2;{r};{g};{b}m"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def BG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||||
|
return f"{cls.CTRL}48;2;{r};{g};{b}m"
|
||||||
@@ -53,5 +53,6 @@ span {
|
|||||||
|
|
||||||
&.keyword {
|
&.keyword {
|
||||||
color: rgb(211, 72, 9);
|
color: rgb(211, 72, 9);
|
||||||
|
pointer-events: none;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||||
|
|
||||||
@@ -8,6 +9,7 @@ import midas.ast.midas as m
|
|||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic
|
||||||
|
from midas.lexer.token import Token
|
||||||
|
|
||||||
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||||
|
|
||||||
@@ -22,6 +24,15 @@ class Locatable(Protocol):
|
|||||||
def location(self) -> Optional[Location]: ...
|
def location(self) -> Optional[Location]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LocatableToken:
|
||||||
|
token: Token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def location(self) -> Location:
|
||||||
|
return self.token.get_location()
|
||||||
|
|
||||||
|
|
||||||
class Highlighter(ABC):
|
class Highlighter(ABC):
|
||||||
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||||
EXTRA_CSS_PATH: Optional[Path] = None
|
EXTRA_CSS_PATH: Optional[Path] = None
|
||||||
@@ -199,61 +210,47 @@ class PythonHighlighter(
|
|||||||
|
|
||||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||||
|
|
||||||
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
item.accept(self)
|
||||||
|
|
||||||
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
|
expr.object.accept(self)
|
||||||
|
expr.index.accept(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MidasHighlighter(
|
||||||
|
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||||
|
):
|
||||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
||||||
|
|
||||||
def highlight(self, node: Highlightable[MidasHighlighter]):
|
def highlight(self, node: Highlightable[MidasHighlighter]):
|
||||||
node.accept(self)
|
node.accept(self)
|
||||||
|
|
||||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
self.wrap(stmt, "simple-type")
|
self.wrap(stmt, "type-stmt")
|
||||||
if stmt.template is not None:
|
self.wrap(LocatableToken(stmt.name), "type-name")
|
||||||
stmt.template.accept(self)
|
stmt.type.accept(self)
|
||||||
stmt.base.accept(self)
|
|
||||||
if stmt.constraint is not None:
|
def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
|
||||||
self.wrap(stmt.constraint, "constraint")
|
self.wrap(stmt, "member")
|
||||||
stmt.constraint.accept(self)
|
|
||||||
|
|
||||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
|
|
||||||
self.wrap(stmt, "complex-type")
|
|
||||||
if stmt.template is not None:
|
|
||||||
stmt.template.accept(self)
|
|
||||||
for prop in stmt.properties:
|
|
||||||
prop.accept(self)
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
|
|
||||||
self.wrap(stmt, "property")
|
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
if stmt.constraint is not None:
|
|
||||||
self.wrap(stmt.constraint, "constraint")
|
|
||||||
stmt.constraint.accept(self)
|
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||||
self.wrap(stmt, "extend")
|
self.wrap(stmt, "extend")
|
||||||
stmt.type.accept(self)
|
for member in stmt.members:
|
||||||
for op in stmt.operations:
|
member.accept(self)
|
||||||
op.accept(self)
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
|
||||||
self.wrap(stmt, "op")
|
|
||||||
stmt.operand.accept(self)
|
|
||||||
stmt.result.accept(self)
|
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
self.wrap(stmt, "predicate")
|
self.wrap(stmt, "predicate")
|
||||||
|
self.wrap(LocatableToken(stmt.name), "predicate-name")
|
||||||
stmt.type.accept(self)
|
stmt.type.accept(self)
|
||||||
stmt.condition.accept(self)
|
stmt.condition.accept(self)
|
||||||
|
|
||||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
|
|
||||||
self.wrap(expr, "simple-type-expr")
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||||
self.wrap(expr, "logical-expr")
|
self.wrap(expr, "logical-expr")
|
||||||
expr.left.accept(self)
|
expr.left.accept(self)
|
||||||
@@ -282,14 +279,35 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
|||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||||
|
|
||||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
def visit_named_type(self, type: m.NamedType) -> None:
|
||||||
self.wrap(expr, "template")
|
self.wrap(type, "named-type")
|
||||||
expr.type.accept(self)
|
|
||||||
|
|
||||||
def visit_type_expr(self, expr: m.TypeExpr) -> None:
|
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||||
self.wrap(expr, "type")
|
self.wrap(type, "generic-type")
|
||||||
if expr.template is not None:
|
type.type.accept(self)
|
||||||
expr.template.accept(self)
|
for arg in type.args:
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
|
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||||
|
self.wrap(type, "constraint-type")
|
||||||
|
type.type.accept(self)
|
||||||
|
type.constraint.accept(self)
|
||||||
|
|
||||||
|
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||||
|
self.wrap(type, "complex-type")
|
||||||
|
for member in type.members:
|
||||||
|
member.accept(self)
|
||||||
|
|
||||||
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
|
self.wrap(type, "function")
|
||||||
|
for arg in type.pos_args + type.args + type.kw_args:
|
||||||
|
arg.type.accept(self)
|
||||||
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||||
|
self.wrap(type, "extension")
|
||||||
|
type.base.accept(self)
|
||||||
|
type.extension.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class DiagnosticsHighlighter(Highlighter):
|
class DiagnosticsHighlighter(Highlighter):
|
||||||
|
|||||||
@@ -5,12 +5,11 @@ span {
|
|||||||
font-style: italic;
|
font-style: italic;
|
||||||
}
|
}
|
||||||
|
|
||||||
&.simple-type {
|
&.named-type,
|
||||||
--col: 108, 233, 108;
|
&.generic-type,
|
||||||
}
|
&.constraint-type,
|
||||||
|
|
||||||
&.complex-type {
|
&.complex-type {
|
||||||
--col: 233, 206, 108;
|
--col: 150, 150, 150;
|
||||||
}
|
}
|
||||||
|
|
||||||
&.constraint {
|
&.constraint {
|
||||||
@@ -33,10 +32,6 @@ span {
|
|||||||
--col: 193, 108, 233;
|
--col: 193, 108, 233;
|
||||||
}
|
}
|
||||||
|
|
||||||
&.simple-type-expr {
|
|
||||||
--col: 150, 150, 150;
|
|
||||||
}
|
|
||||||
|
|
||||||
&.logical-expr,
|
&.logical-expr,
|
||||||
&.binary-expr,
|
&.binary-expr,
|
||||||
&.unary-expr,
|
&.unary-expr,
|
||||||
@@ -48,7 +43,9 @@ span {
|
|||||||
--col: 163, 117, 71;
|
--col: 163, 117, 71;
|
||||||
}
|
}
|
||||||
|
|
||||||
&.type {
|
&.type-name,
|
||||||
|
&.op-name,
|
||||||
|
&.predicate-name {
|
||||||
--col: 200, 200, 200;
|
--col: 200, 200, 200;
|
||||||
font-weight: bold;
|
font-weight: bold;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import ast
|
import ast
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, TextIO, get_args
|
from typing import Optional, TextIO, get_args
|
||||||
|
|
||||||
@@ -10,13 +9,15 @@ import click
|
|||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||||
from midas.checker.types import Type
|
from midas.checker.types import Type
|
||||||
|
from midas.cli.ansi import Ansi
|
||||||
from midas.cli.highlighter import (
|
from midas.cli.highlighter import (
|
||||||
DiagnosticsHighlighter,
|
DiagnosticsHighlighter,
|
||||||
Highlighter,
|
Highlighter,
|
||||||
|
LocatableToken,
|
||||||
MidasHighlighter,
|
MidasHighlighter,
|
||||||
PythonHighlighter,
|
PythonHighlighter,
|
||||||
)
|
)
|
||||||
@@ -24,41 +25,126 @@ from midas.lexer.midas import MidasLexer
|
|||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import Token, TokenType
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
from midas.resolver.resolver import Resolver
|
|
||||||
from midas.utils import UniversalJSONDumper
|
from midas.utils import UniversalJSONDumper
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
def midas():
|
def midas():
|
||||||
click.echo("Welcome to Midas!")
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||||
|
"""Pretty-print a diagnostic, showing some context if possible
|
||||||
|
|
||||||
|
If the diagnostic concerns a specific part of one line, the line is shown
|
||||||
|
with the affected part highlighted. The message is clearly printed under the
|
||||||
|
line with an underline further indicating the target expression.
|
||||||
|
|
||||||
|
If multiple lines are concerned, no context is shown, only the
|
||||||
|
diagnostic type, location and message
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lines (list[str]): source code lines
|
||||||
|
diagnostic (Diagnostic): the diagnostic to print
|
||||||
|
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loc: Location = diagnostic.location
|
||||||
|
if loc.lineno != loc.end_lineno:
|
||||||
|
print(diagnostic)
|
||||||
|
return
|
||||||
|
|
||||||
|
start_offset: int = loc.col_offset
|
||||||
|
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||||
|
|
||||||
|
line: str = lines[loc.lineno - 1]
|
||||||
|
before: str = line[:start_offset]
|
||||||
|
after: str = line[end_offset:]
|
||||||
|
|
||||||
|
color: int = {
|
||||||
|
DiagnosticType.ERROR: Ansi.RED,
|
||||||
|
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||||
|
DiagnosticType.INFO: Ansi.CYAN,
|
||||||
|
}.get(diagnostic.type, Ansi.WHITE)
|
||||||
|
|
||||||
|
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||||
|
cursor: str = (
|
||||||
|
" " * start_offset
|
||||||
|
+ Ansi.FG(color)
|
||||||
|
+ "~" * (end_offset - start_offset)
|
||||||
|
+ "> "
|
||||||
|
+ diagnostic.message
|
||||||
|
+ Ansi.RESET
|
||||||
|
)
|
||||||
|
|
||||||
|
indent_str: str = " " * indent
|
||||||
|
print(diagnostic.location_str + ":")
|
||||||
|
print(indent_str + before + subject + after)
|
||||||
|
print(indent_str + cursor)
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
@midas.command()
|
@midas.command()
|
||||||
@click.option("-l", "--highlight", type=click.File("w"))
|
@click.option("-l", "--highlight", type=click.File("w"))
|
||||||
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
|
@click.option("-v", "--verbose", is_flag=True)
|
||||||
|
@click.option("-j", "--show-judgements", is_flag=True)
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
def compile(highlight: Optional[TextIO], file: TextIO):
|
def compile(
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
highlight: Optional[TextIO],
|
||||||
|
types: tuple[TextIO],
|
||||||
|
verbose: bool,
|
||||||
|
show_judgements: bool,
|
||||||
|
file: TextIO,
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
||||||
source: str = file.read()
|
source: str = file.read()
|
||||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
source_path: Path = Path(file.name).resolve()
|
||||||
parser = PythonParser()
|
|
||||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
|
||||||
resolver = Resolver()
|
|
||||||
resolver.resolve(*stmts)
|
|
||||||
checker = Checker(resolver.locals, file_path=Path(file.name).resolve())
|
|
||||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
|
||||||
for diagnostic in diagnostics:
|
|
||||||
print(diagnostic)
|
|
||||||
|
|
||||||
print(
|
checker = TypeChecker()
|
||||||
json.dumps(
|
for types_file in types:
|
||||||
UniversalJSONDumper.dump(
|
checker.import_midas(Path(types_file.name).resolve())
|
||||||
checker.global_env,
|
|
||||||
[("Environment", "_children")],
|
checker.type_check_source(source, str(source_path))
|
||||||
lambda obj: isinstance(obj, get_args(Type)),
|
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
|
||||||
),
|
lines: list[str] = source.split("\n")
|
||||||
indent=4,
|
files: dict[Optional[str], list[str]] = {None: []}
|
||||||
|
|
||||||
|
if show_judgements:
|
||||||
|
for expr, type in checker.python_typer.judgements:
|
||||||
|
print(f"Judged that {expr} at {expr.location} is of type {type}")
|
||||||
|
diagnostics.append(
|
||||||
|
Diagnostic(
|
||||||
|
file_path=str(source_path),
|
||||||
|
location=expr.location,
|
||||||
|
type=DiagnosticType.INFO,
|
||||||
|
message=f"Type: {type}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for diagnostic in diagnostics:
|
||||||
|
filename: Optional[str] = diagnostic.file_path
|
||||||
|
if filename is not None and filename not in files:
|
||||||
|
path: Path = Path(filename)
|
||||||
|
if path.exists() and path.is_file():
|
||||||
|
files[filename] = path.read_text().split("\n")
|
||||||
|
else:
|
||||||
|
files[filename] = []
|
||||||
|
|
||||||
|
lines: list[str] = files[filename]
|
||||||
|
print_diagnostic(lines, diagnostic)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
json.dumps(
|
||||||
|
UniversalJSONDumper.dump(
|
||||||
|
checker.python_typer.global_env,
|
||||||
|
[("Environment", "_children")],
|
||||||
|
lambda obj: isinstance(obj, get_args(Type)),
|
||||||
|
),
|
||||||
|
indent=4,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if highlight is not None:
|
if highlight is not None:
|
||||||
highlighter = DiagnosticsHighlighter(source)
|
highlighter = DiagnosticsHighlighter(source)
|
||||||
highlighter.highlight(diagnostics)
|
highlighter.highlight(diagnostics)
|
||||||
@@ -142,14 +228,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
|
|||||||
for err in parser.errors:
|
for err in parser.errors:
|
||||||
print(err.get_report())
|
print(err.get_report())
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class LocatableToken:
|
|
||||||
token: Token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def location(self) -> Location:
|
|
||||||
return self.token.get_location()
|
|
||||||
|
|
||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
highlighter.highlight(stmt)
|
highlighter.highlight(stmt)
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
@@ -176,5 +254,21 @@ def highlight(output: TextIO, file: TextIO):
|
|||||||
highlighter.dump(output)
|
highlighter.dump(output)
|
||||||
|
|
||||||
|
|
||||||
|
@midas.command()
|
||||||
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
def format(output: TextIO, file: TextIO):
|
||||||
|
source: str = file.read()
|
||||||
|
printer = MidasPrinter()
|
||||||
|
lexer = MidasLexer(source, file=file.name)
|
||||||
|
tokens: list[Token] = lexer.process()
|
||||||
|
parser = MidasParser(tokens)
|
||||||
|
stmts: list[m.Stmt] = parser.parse()
|
||||||
|
for err in parser.errors:
|
||||||
|
print(err.get_report())
|
||||||
|
for stmt in stmts:
|
||||||
|
output.write(printer.print(stmt) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
midas()
|
midas()
|
||||||
|
|||||||
@@ -18,8 +18,6 @@ class MidasLexer(Lexer):
|
|||||||
self.add_token(TokenType.LEFT_BRACE)
|
self.add_token(TokenType.LEFT_BRACE)
|
||||||
case "}":
|
case "}":
|
||||||
self.add_token(TokenType.RIGHT_BRACE)
|
self.add_token(TokenType.RIGHT_BRACE)
|
||||||
case "|":
|
|
||||||
self.add_token(TokenType.PIPE)
|
|
||||||
case "<":
|
case "<":
|
||||||
self.add_token(
|
self.add_token(
|
||||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
||||||
@@ -52,12 +50,14 @@ class MidasLexer(Lexer):
|
|||||||
# self.add_token(TokenType.PLUS)
|
# self.add_token(TokenType.PLUS)
|
||||||
case "-":
|
case "-":
|
||||||
self.add_token(TokenType.MINUS)
|
self.add_token(TokenType.MINUS)
|
||||||
# case "*":
|
case "*":
|
||||||
# self.add_token(TokenType.STAR)
|
self.add_token(TokenType.STAR)
|
||||||
case "/" if self.match("/"):
|
case "/" if self.match("/"):
|
||||||
self.scan_comment()
|
self.scan_comment()
|
||||||
case "/" if self.match("*"):
|
case "/" if self.match("*"):
|
||||||
self.scan_comment_multiline()
|
self.scan_comment_multiline()
|
||||||
|
case "/":
|
||||||
|
self.add_token(TokenType.SLASH)
|
||||||
case "\n":
|
case "\n":
|
||||||
self.add_token(TokenType.NEWLINE)
|
self.add_token(TokenType.NEWLINE)
|
||||||
case " " | "\r" | "\t":
|
case " " | "\r" | "\t":
|
||||||
|
|||||||
@@ -23,13 +23,12 @@ class TokenType(Enum):
|
|||||||
AND = auto()
|
AND = auto()
|
||||||
QMARK = auto()
|
QMARK = auto()
|
||||||
DOT = auto()
|
DOT = auto()
|
||||||
PIPE = auto()
|
|
||||||
|
|
||||||
# Operators
|
# Operators
|
||||||
# PLUS = auto()
|
# PLUS = auto()
|
||||||
MINUS = auto()
|
MINUS = auto()
|
||||||
# STAR = auto()
|
STAR = auto()
|
||||||
# SLASH = auto()
|
SLASH = auto()
|
||||||
GREATER = auto()
|
GREATER = auto()
|
||||||
GREATER_EQUAL = auto()
|
GREATER_EQUAL = auto()
|
||||||
LESS = auto()
|
LESS = auto()
|
||||||
@@ -47,10 +46,12 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
# Keywords
|
# Keywords
|
||||||
TYPE = auto()
|
TYPE = auto()
|
||||||
OP = auto()
|
|
||||||
PREDICATE = auto()
|
PREDICATE = auto()
|
||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
WHERE = auto()
|
WHERE = auto()
|
||||||
|
PROP = auto()
|
||||||
|
DEF = auto()
|
||||||
|
FUNC = auto()
|
||||||
|
|
||||||
# Misc
|
# Misc
|
||||||
COMMENT = auto()
|
COMMENT = auto()
|
||||||
@@ -61,13 +62,15 @@ class TokenType(Enum):
|
|||||||
|
|
||||||
KEYWORDS: dict[str, TokenType] = {
|
KEYWORDS: dict[str, TokenType] = {
|
||||||
"type": TokenType.TYPE,
|
"type": TokenType.TYPE,
|
||||||
"op": TokenType.OP,
|
|
||||||
"predicate": TokenType.PREDICATE,
|
"predicate": TokenType.PREDICATE,
|
||||||
"extend": TokenType.EXTEND,
|
"extend": TokenType.EXTEND,
|
||||||
"where": TokenType.WHERE,
|
"where": TokenType.WHERE,
|
||||||
"true": TokenType.TRUE,
|
"true": TokenType.TRUE,
|
||||||
"false": TokenType.FALSE,
|
"false": TokenType.FALSE,
|
||||||
"none": TokenType.NONE,
|
"none": TokenType.NONE,
|
||||||
|
"prop": TokenType.PROP,
|
||||||
|
"def": TokenType.DEF,
|
||||||
|
"fn": TokenType.FUNC,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,24 +7,26 @@ from midas.ast.midas import (
|
|||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
|
ExtensionType,
|
||||||
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
GroupingExpr,
|
GroupingExpr,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
|
MemberKind,
|
||||||
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
OpStmt,
|
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
PropertyStmt,
|
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
|
TypeParam,
|
||||||
TypeStmt,
|
TypeStmt,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
UnionType,
|
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
WildcardExpr,
|
WildcardExpr,
|
||||||
)
|
)
|
||||||
from midas.lexer.token import Token, TokenType
|
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||||
from midas.parser.base import Parser
|
from midas.parser.base import Parser
|
||||||
from midas.parser.errors import ParsingError
|
from midas.parser.errors import ParsingError
|
||||||
|
|
||||||
@@ -34,9 +36,10 @@ class MidasParser(Parser):
|
|||||||
|
|
||||||
SYNC_BOUNDARY: set[TokenType] = {
|
SYNC_BOUNDARY: set[TokenType] = {
|
||||||
TokenType.TYPE,
|
TokenType.TYPE,
|
||||||
TokenType.OP,
|
|
||||||
TokenType.EXTEND,
|
TokenType.EXTEND,
|
||||||
TokenType.PREDICATE,
|
TokenType.PREDICATE,
|
||||||
|
TokenType.PROP,
|
||||||
|
TokenType.FUNC,
|
||||||
}
|
}
|
||||||
|
|
||||||
def parse(self) -> list[Stmt]:
|
def parse(self) -> list[Stmt]:
|
||||||
@@ -108,10 +111,8 @@ class MidasParser(Parser):
|
|||||||
TypeStmt: the parsed type declaration statement
|
TypeStmt: the parsed type declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
params: list[TypeStmt.Param] = []
|
params: list[TypeParam] = self.type_params()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
|
||||||
params = self.type_stmt_params()
|
|
||||||
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||||
|
|
||||||
@@ -124,24 +125,27 @@ class MidasParser(Parser):
|
|||||||
type=type,
|
type=type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
def type_params(self) -> list[TypeParam]:
|
||||||
"""Parse a generic template expression
|
"""Parse a list of type parameters
|
||||||
|
|
||||||
A template is written `[TypeExpr]`
|
Type parameters are a comma-separated list of type variables wrapped in brackets.
|
||||||
|
Each type variable is either a simple variable, or a bounded variable written `S <: T`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TemplateExpr: the parsed template expression
|
list[TypeParam]: the list of type parameters, if any, or an empty list
|
||||||
"""
|
"""
|
||||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
if not self.match(TokenType.LEFT_BRACKET):
|
||||||
params: list[TypeStmt.Param] = []
|
return []
|
||||||
|
|
||||||
|
params: list[TypeParam] = []
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
name: Token = self.consume_identifier("Expected type variable")
|
||||||
bound: Optional[Type] = None
|
bound: Optional[Type] = None
|
||||||
if self.match(TokenType.LESS):
|
if self.match(TokenType.LESS):
|
||||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||||
bound = self.type_expr()
|
bound = self.type_expr()
|
||||||
params.append(
|
params.append(
|
||||||
TypeStmt.Param(
|
TypeParam(
|
||||||
location=name.location_to(self.previous()),
|
location=name.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
bound=bound,
|
bound=bound,
|
||||||
@@ -149,7 +153,7 @@ class MidasParser(Parser):
|
|||||||
)
|
)
|
||||||
if not self.match(TokenType.COMMA):
|
if not self.match(TokenType.COMMA):
|
||||||
break
|
break
|
||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def type_expr(self) -> Type:
|
def type_expr(self) -> Type:
|
||||||
@@ -161,18 +165,19 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
TypeExpr: the parsed type expression
|
TypeExpr: the parsed type expression
|
||||||
"""
|
"""
|
||||||
return self.union_type()
|
base: Type
|
||||||
|
if self.match(TokenType.FUNC):
|
||||||
def union_type(self) -> Type:
|
base = self.function()
|
||||||
types: list[Type] = [self.constraint_type()]
|
else:
|
||||||
while self.match(TokenType.PIPE):
|
base = self.constraint_type()
|
||||||
types.append(self.constraint_type())
|
if self.match(TokenType.AND):
|
||||||
if len(types) == 1:
|
extension: ComplexType = self.complex_type()
|
||||||
return types[0]
|
return ExtensionType(
|
||||||
return UnionType(
|
location=Location.span(base.location, extension.location),
|
||||||
location=Location.span(types[0].location, types[-1].location),
|
base=base,
|
||||||
types=types,
|
extension=extension,
|
||||||
)
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
def constraint_type(self) -> Type:
|
def constraint_type(self) -> Type:
|
||||||
type: Type = self.base_type()
|
type: Type = self.base_type()
|
||||||
@@ -199,55 +204,57 @@ class MidasParser(Parser):
|
|||||||
def generic_type(self) -> Type:
|
def generic_type(self) -> Type:
|
||||||
type: Type = self.named_type()
|
type: Type = self.named_type()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
params: list[Type] = self.type_params()
|
args: list[Type] = self.type_args()
|
||||||
return GenericType(
|
return GenericType(
|
||||||
location=Location.span(type.location, self.previous().get_location()),
|
location=Location.span(type.location, self.previous().get_location()),
|
||||||
type=type,
|
type=type,
|
||||||
params=params,
|
args=args,
|
||||||
)
|
)
|
||||||
return type
|
return type
|
||||||
|
|
||||||
def type_params(self) -> list[Type]:
|
def type_args(self) -> list[Type]:
|
||||||
params: list[Type] = []
|
args: list[Type] = []
|
||||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||||
params.append(self.type_expr())
|
args.append(self.type_expr())
|
||||||
if not self.match(TokenType.COMMA):
|
if not self.match(TokenType.COMMA):
|
||||||
break
|
break
|
||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||||
return params
|
return args
|
||||||
|
|
||||||
def named_type(self) -> Type:
|
def named_type(self) -> Type:
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
return NamedType(
|
return NamedType(
|
||||||
location=name.get_location(),
|
location=name.get_location(),
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def complex_type(self) -> Type:
|
def complex_type(self) -> ComplexType:
|
||||||
"""Parse a type definition body
|
"""Parse a type definition body
|
||||||
|
|
||||||
A type definition body is a set of whitespace-separated
|
A type definition body is a set of whitespace-separated
|
||||||
property statements enclosed in curly braces
|
property statements enclosed in curly braces
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[PropertyStmt]: the parsed type properties
|
ComplexType: the parsed complex type
|
||||||
"""
|
"""
|
||||||
left: Token = self.consume(
|
left: Token = self.consume(
|
||||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||||
)
|
)
|
||||||
properties: list[PropertyStmt] = []
|
members: list[MemberStmt] = []
|
||||||
|
# TODO: add keyword to differentiate properties and methods,
|
||||||
|
# and allow multiple methods with the same name but not properties
|
||||||
names: set[str] = set()
|
names: set[str] = set()
|
||||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||||
prop: PropertyStmt = self.property_stmt()
|
member: MemberStmt = self.member_stmt()
|
||||||
if prop.name.lexeme in names:
|
# if member.name.lexeme in names:
|
||||||
raise self.error(prop.name, "Duplicate property")
|
# raise self.error(member.name, "Duplicate property")
|
||||||
names.add(prop.name.lexeme)
|
# names.add(member.name.lexeme)
|
||||||
properties.append(prop)
|
members.append(member)
|
||||||
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||||
return ComplexType(
|
return ComplexType(
|
||||||
location=left.location_to(right),
|
location=left.location_to(right),
|
||||||
properties=properties,
|
members=members,
|
||||||
)
|
)
|
||||||
|
|
||||||
def constraint(self) -> Expr:
|
def constraint(self) -> Expr:
|
||||||
@@ -334,9 +341,7 @@ class MidasParser(Parser):
|
|||||||
"""
|
"""
|
||||||
expr: Expr = self.primary()
|
expr: Expr = self.primary()
|
||||||
while self.match(TokenType.DOT):
|
while self.match(TokenType.DOT):
|
||||||
name: Token = self.consume(
|
name: Token = self.consume_identifier("Expected property name after '.'")
|
||||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
|
||||||
)
|
|
||||||
location: Location = Location.span(expr.location, name.get_location())
|
location: Location = Location.span(expr.location, name.get_location())
|
||||||
expr = GetExpr(location=location, expr=expr, name=name)
|
expr = GetExpr(location=location, expr=expr, name=name)
|
||||||
return expr
|
return expr
|
||||||
@@ -360,7 +365,7 @@ class MidasParser(Parser):
|
|||||||
if self.match(TokenType.NUMBER):
|
if self.match(TokenType.NUMBER):
|
||||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
if self.match(TokenType.IDENTIFIER):
|
if self.match_identifier():
|
||||||
return VariableExpr(location=token.get_location(), name=token)
|
return VariableExpr(location=token.get_location(), name=token)
|
||||||
|
|
||||||
if self.match(TokenType.UNDERSCORE):
|
if self.match(TokenType.UNDERSCORE):
|
||||||
@@ -373,64 +378,70 @@ class MidasParser(Parser):
|
|||||||
|
|
||||||
raise self.error(self.peek(), "Expected expression")
|
raise self.error(self.peek(), "Expected expression")
|
||||||
|
|
||||||
def property_stmt(self) -> PropertyStmt:
|
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||||
"""Parse a property statement
|
if not self.match_identifier():
|
||||||
|
raise self.error(self.peek(), message)
|
||||||
|
return self.previous()
|
||||||
|
|
||||||
A type property statement is written `name: Type` or `name: Type where Condition`
|
def match_identifier(self) -> bool:
|
||||||
|
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||||
|
|
||||||
|
def check_identifier(self) -> bool:
|
||||||
|
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||||
|
if self.check(tt):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def member_stmt(self) -> MemberStmt:
|
||||||
|
"""Parse a member statement
|
||||||
|
|
||||||
|
A type member statement is written `prop name: Type` or `def name: Type`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PropertyStmt: the parsed property statement
|
MemberStmt: the parsed member statement
|
||||||
"""
|
"""
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
kind: MemberKind
|
||||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
if self.match(TokenType.PROP):
|
||||||
|
kind = MemberKind.PROPERTY
|
||||||
|
elif self.match(TokenType.DEF):
|
||||||
|
kind = MemberKind.METHOD
|
||||||
|
else:
|
||||||
|
raise self.error(self.peek(), "Expected 'prop' or 'def'")
|
||||||
|
|
||||||
|
name: Token = self.consume_identifier("Expected member name")
|
||||||
|
self.consume(TokenType.COLON, "Expected ':' after member name")
|
||||||
|
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
return PropertyStmt(
|
return MemberStmt(
|
||||||
location=name.location_to(self.previous()),
|
location=name.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
type=type,
|
type=type,
|
||||||
|
kind=kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
def extend_declaration(self) -> ExtendStmt:
|
def extend_declaration(self) -> ExtendStmt:
|
||||||
"""Parse an extension definition
|
"""Parse an extension definition
|
||||||
|
|
||||||
An extension is written `extend Type { operations }`
|
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ExtendStmt: the parsed extension statement
|
ExtendStmt: the parsed extension statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
type: Type = self.type_expr()
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
|
params: list[TypeParam] = self.type_params()
|
||||||
|
|
||||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||||
operations: list[OpStmt] = []
|
members: list[MemberStmt] = []
|
||||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||||
operations.append(self.op_declaration())
|
members.append(self.member_stmt())
|
||||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||||
location: Location = keyword.location_to(self.previous())
|
location: Location = keyword.location_to(self.previous())
|
||||||
return ExtendStmt(location=location, type=type, operations=operations)
|
return ExtendStmt(
|
||||||
|
location=location,
|
||||||
def op_declaration(self) -> OpStmt:
|
|
||||||
"""Parse an operation definition
|
|
||||||
|
|
||||||
An operation is written `op name(Type) -> Type`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OpStmt: the parsed operation statement
|
|
||||||
"""
|
|
||||||
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
|
|
||||||
|
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
|
||||||
operand: Type = self.type_expr()
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
|
||||||
|
|
||||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
|
||||||
result: Type = self.type_expr()
|
|
||||||
|
|
||||||
return OpStmt(
|
|
||||||
location=keyword.location_to(self.previous()),
|
|
||||||
name=name,
|
name=name,
|
||||||
operand=operand,
|
params=params,
|
||||||
result=result,
|
members=members,
|
||||||
)
|
)
|
||||||
|
|
||||||
def predicate_declaration(self) -> PredicateStmt:
|
def predicate_declaration(self) -> PredicateStmt:
|
||||||
@@ -442,9 +453,9 @@ class MidasParser(Parser):
|
|||||||
PredicateStmt: the parsed predicate declaration statement
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
name: Token = self.consume_identifier("Expected predicate name")
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
subject: Token = self.consume_identifier("Expected subject name")
|
||||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||||
type: Type = self.type_expr()
|
type: Type = self.type_expr()
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||||
@@ -457,3 +468,72 @@ class MidasParser(Parser):
|
|||||||
type=type,
|
type=type,
|
||||||
condition=condition,
|
condition=condition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def function(self) -> FunctionType:
|
||||||
|
l_paren: Token = self.consume(
|
||||||
|
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||||
|
)
|
||||||
|
pos_args: list[FunctionType.Argument] = []
|
||||||
|
args: list[FunctionType.Argument] = []
|
||||||
|
kw_args: list[FunctionType.Argument] = []
|
||||||
|
|
||||||
|
args_first_tokens: list[Token] = []
|
||||||
|
|
||||||
|
section: int = 0
|
||||||
|
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||||
|
match section:
|
||||||
|
case 0 if self.match(TokenType.SLASH):
|
||||||
|
pos_args = args
|
||||||
|
args = []
|
||||||
|
args_first_tokens = []
|
||||||
|
section = 1
|
||||||
|
case 0 | 1 if self.match(TokenType.STAR):
|
||||||
|
section = 2
|
||||||
|
case _:
|
||||||
|
# Record first token of mixed argument for errors if unnamed
|
||||||
|
if section != 2:
|
||||||
|
args_first_tokens.append(self.peek())
|
||||||
|
|
||||||
|
name: Optional[Token] = None
|
||||||
|
if section == 2:
|
||||||
|
name = self.consume_identifier("Expected keyword argument name")
|
||||||
|
self.consume(
|
||||||
|
TokenType.COLON, "Expected ':' after argument name"
|
||||||
|
)
|
||||||
|
elif self.check_identifier() and self.check_next(TokenType.COLON):
|
||||||
|
name = self.advance()
|
||||||
|
self.advance()
|
||||||
|
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
optional: bool = self.match(TokenType.QMARK)
|
||||||
|
arg = FunctionType.Argument(
|
||||||
|
location=None,
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
required=not optional,
|
||||||
|
)
|
||||||
|
if section == 2:
|
||||||
|
kw_args.append(arg)
|
||||||
|
else:
|
||||||
|
args.append(arg)
|
||||||
|
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
|
||||||
|
for arg, token in zip(args, args_first_tokens):
|
||||||
|
if arg.name is None:
|
||||||
|
# Not raised because we can keep parsing
|
||||||
|
self.error(token, "Unnamed mixed argument")
|
||||||
|
|
||||||
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: Type = self.type_expr()
|
||||||
|
|
||||||
|
return FunctionType(
|
||||||
|
location=l_paren.location_to(self.previous()),
|
||||||
|
pos_args=pos_args,
|
||||||
|
args=args,
|
||||||
|
kw_args=kw_args,
|
||||||
|
returns=result,
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ from midas.ast.python import (
|
|||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
IfStmt,
|
IfStmt,
|
||||||
|
ListExpr,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
ReturnStmt,
|
ReturnStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
|
SubscriptExpr,
|
||||||
TernaryExpr,
|
TernaryExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
@@ -87,6 +89,9 @@ class PythonParser:
|
|||||||
case ast.If():
|
case ast.If():
|
||||||
return self.parse_if(node)
|
return self.parse_if(node)
|
||||||
|
|
||||||
|
case ast.Pass():
|
||||||
|
return None
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||||
return None
|
return None
|
||||||
@@ -311,6 +316,13 @@ class PythonParser:
|
|||||||
constraint=right_expr,
|
constraint=right_expr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ast.Constant(value=None):
|
||||||
|
return BaseType(
|
||||||
|
location=loc,
|
||||||
|
base="None",
|
||||||
|
param=None,
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise UnsupportedSyntaxError(type_expr)
|
raise UnsupportedSyntaxError(type_expr)
|
||||||
|
|
||||||
@@ -406,6 +418,19 @@ class PythonParser:
|
|||||||
case ast.Name(id=name):
|
case ast.Name(id=name):
|
||||||
return VariableExpr(location=location, name=name)
|
return VariableExpr(location=location, name=name)
|
||||||
|
|
||||||
|
case ast.List(elts=items):
|
||||||
|
return ListExpr(
|
||||||
|
location=location,
|
||||||
|
items=[self.parse_expr(item) for item in items],
|
||||||
|
)
|
||||||
|
|
||||||
|
case ast.Subscript(value=value, slice=index):
|
||||||
|
return SubscriptExpr(
|
||||||
|
location=location,
|
||||||
|
object=self.parse_expr(value),
|
||||||
|
index=self.parse_expr(index),
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise UnsupportedSyntaxError(node)
|
raise UnsupportedSyntaxError(node)
|
||||||
|
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from midas.checker.types import BaseType, Type, UnitType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from midas.resolver.midas import MidasResolver
|
|
||||||
|
|
||||||
|
|
||||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
|
||||||
ctx.define_operation(
|
|
||||||
left=t1,
|
|
||||||
operator=operator,
|
|
||||||
right=t2,
|
|
||||||
result=t3,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
|
||||||
ctx.define_operation(
|
|
||||||
left=type,
|
|
||||||
operator=op,
|
|
||||||
right=type,
|
|
||||||
result=type,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def define_builtins(ctx: MidasResolver):
|
|
||||||
"""Define builtin types and operations"""
|
|
||||||
unit = ctx.define_type("None", UnitType())
|
|
||||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
|
||||||
int = ctx.define_type("int", BaseType(name="int"))
|
|
||||||
float = ctx.define_type("float", BaseType(name="float"))
|
|
||||||
str = ctx.define_type("str", BaseType(name="str"))
|
|
||||||
|
|
||||||
basic_op(ctx, int, "__add__") # int + int = int
|
|
||||||
basic_op(ctx, int, "__sub__") # int - int = int
|
|
||||||
basic_op(ctx, int, "__mul__") # int * int = int
|
|
||||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
|
||||||
basic_op(ctx, int, "__mod__") # int % int = int
|
|
||||||
basic_op(ctx, int, "__and__") # int & int = int
|
|
||||||
basic_op(ctx, int, "__or__") # int | int = int
|
|
||||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
|
||||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
|
||||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
|
||||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
|
||||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
|
||||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
|
||||||
basic_op(ctx, float, "__add__") # float + float = float
|
|
||||||
basic_op(ctx, float, "__sub__") # float - float = float
|
|
||||||
basic_op(ctx, float, "__mul__") # float * float = float
|
|
||||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
|
||||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
|
||||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
|
||||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
|
||||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
|
||||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
|
||||||
basic_op(ctx, str, "__add__") # str + str = str
|
|
||||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
|
||||||
|
|
||||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
|
||||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
|
||||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
|
||||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
|
||||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
|
||||||
|
|
||||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
|
||||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
|
||||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
|
||||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
|
||||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
|
||||||
@@ -1,166 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import midas.ast.midas as m
|
|
||||||
from midas.checker.types import (
|
|
||||||
Type,
|
|
||||||
UnionType,
|
|
||||||
UnknownType,
|
|
||||||
)
|
|
||||||
from midas.resolver.builtin import define_builtins
|
|
||||||
|
|
||||||
|
|
||||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._types: dict[str, Type] = {}
|
|
||||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
|
||||||
|
|
||||||
define_builtins(self)
|
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
|
||||||
"""Get a type from its name
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the type
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NameError: if the type is not defined
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the type
|
|
||||||
"""
|
|
||||||
type: Optional[Type] = self._types.get(name)
|
|
||||||
if type is None:
|
|
||||||
raise NameError(f"Undefined type {name}")
|
|
||||||
return type
|
|
||||||
|
|
||||||
def get_operation_result(
|
|
||||||
self, left: Type, operator: str, right: Type
|
|
||||||
) -> Optional[Type]:
|
|
||||||
"""Get the resulting type of an operation
|
|
||||||
|
|
||||||
Args:
|
|
||||||
left (Type): the type of the left operand
|
|
||||||
operator (str): the operation name
|
|
||||||
right (Type): the type of the right operand
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Type]: the result type, or None if no matching operation was found
|
|
||||||
"""
|
|
||||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
|
||||||
result: Optional[Type] = self._operations.get(operation)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def define_type(self, name: str, type: Type) -> Type:
|
|
||||||
"""Define a type in the registry
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): the name of the type
|
|
||||||
type (Type): the type to define
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if a type is already defined with that name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Type: the defined type
|
|
||||||
"""
|
|
||||||
if name in self._types:
|
|
||||||
raise ValueError(f"Type {name} already defined")
|
|
||||||
self._types[name] = type
|
|
||||||
return type
|
|
||||||
|
|
||||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
|
||||||
"""Define an operation in the registry
|
|
||||||
|
|
||||||
Args:
|
|
||||||
left (Type): the type of the left operand
|
|
||||||
operator (str): the operation name
|
|
||||||
right (Type): the type of the right operand
|
|
||||||
result (Type): the result type
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if an operation is already defined with these operands and name
|
|
||||||
"""
|
|
||||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
|
||||||
if operation in self._operations:
|
|
||||||
raise ValueError(
|
|
||||||
f"Operation {operator} already defined between {left} and {right}"
|
|
||||||
)
|
|
||||||
self._operations[operation] = result
|
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
|
||||||
"""Process a sequence of statements
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stmts (list[m.Stmt]): the statements
|
|
||||||
"""
|
|
||||||
for stmt in stmts:
|
|
||||||
stmt.accept(self)
|
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
|
||||||
type: Type = stmt.type.accept(self)
|
|
||||||
for param in stmt.params:
|
|
||||||
if param.bound is not None:
|
|
||||||
param.bound.accept(self)
|
|
||||||
self.define_type(stmt.name.lexeme, type)
|
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
|
||||||
base: Type = stmt.type.accept(self)
|
|
||||||
for op in stmt.operations:
|
|
||||||
right: Type = op.operand.accept(self)
|
|
||||||
result: Type = op.result.accept(self)
|
|
||||||
self.define_operation(
|
|
||||||
left=base,
|
|
||||||
operator=op.name.lexeme,
|
|
||||||
right=right,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
|
||||||
return expr.expr.accept(self)
|
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
|
||||||
return self.get_type(type.name.lexeme)
|
|
||||||
|
|
||||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
|
||||||
type_: Type = type.type.accept(self)
|
|
||||||
params: list[Type] = [param.accept(self) for param in type.params]
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
|
||||||
type_: Type = type.type.accept(self)
|
|
||||||
type.constraint.accept(self)
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
|
|
||||||
def visit_union_type(self, type: m.UnionType) -> Type:
|
|
||||||
types: list[Type] = [type_.accept(self) for type_ in type.types]
|
|
||||||
return UnionType(alternatives=types)
|
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
|
||||||
for prop in type.properties:
|
|
||||||
prop.accept(self)
|
|
||||||
# TODO
|
|
||||||
return UnknownType()
|
|
||||||
@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
|
|||||||
Equality ::= Comparison (EqualityOp Comparison)*
|
Equality ::= Comparison (EqualityOp Comparison)*
|
||||||
Constraint ::= Equality ("&" Equality)*
|
Constraint ::= Equality ("&" Equality)*
|
||||||
|
|
||||||
SimpleType ::= Identifier "?"?
|
TemplateParam ::= Identifier ("<:" Type)?
|
||||||
Template ::= "[" Type "]"
|
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
|
||||||
Type ::= Identifier Template? "?"?
|
|
||||||
|
|
||||||
|
TypeProperty ::= Identifier ":" Type
|
||||||
|
ComplexType ::= "{" TypeProperty* "}"
|
||||||
|
NamedType ::= Identifier
|
||||||
|
TypeParams ::= "[" (Type ("," Type)*)? "]"
|
||||||
|
GenericType ::= NamedType TypeParams?
|
||||||
|
GroupedType ::= "(" Type ")"
|
||||||
|
BaseType ::= GroupedType | ComplexType | GenericType
|
||||||
|
ConstraintType ::= BaseType ("where" Constraint)?
|
||||||
|
Type ::= ConstraintType
|
||||||
|
|
||||||
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
|
|
||||||
ComplexTypeBody ::= "{" TypeProperty* "}"
|
|
||||||
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
||||||
ExtendBody ::= "{" OpDefinition* "}"
|
ExtendBody ::= "{" OpDefinition* "}"
|
||||||
|
|
||||||
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
|
TypeStatement ::= "type" Identifier Template? "=" Type
|
||||||
ExtendStatement ::= "extend" Type ExtendBody
|
ExtendStatement ::= "extend" Type ExtendBody
|
||||||
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
||||||
|
|
||||||
|
|||||||
@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
|
|||||||
{[`constraint` 'equality'*"&"]}
|
{[`constraint` 'equality'*"&"]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let simple-type = ```
|
#let template-param = ```
|
||||||
{[`simple-type` 'identifier' <!, "?">]}
|
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let template = ```
|
#let template = ```
|
||||||
{[`template` "[" 'type' "]"]}
|
{[`template` "[" <!, 'template-param'*","> "]"]}
|
||||||
```
|
|
||||||
|
|
||||||
#let type = ```
|
|
||||||
{[`type` 'identifier' <!, 'template'> <!, "?">]}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#let type-property = ```
|
#let type-property = ```
|
||||||
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
|
{[`type-property` 'identifier' ":" 'type']}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let type-body = ```
|
#let complex-type = ```
|
||||||
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let named-type = ```
|
||||||
|
{[`named-type` 'identifier']}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let type-params = ```
|
||||||
|
{[`type-params` "[" <!, 'type'*","> "]"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let generic-type = ```
|
||||||
|
{[`generic-type` 'named-type' <!, 'type-params'>]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let grouped-type = ```
|
||||||
|
{[`grouped-type` "(" 'type' ")"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let base-type = ```
|
||||||
|
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let constraint-type = ```
|
||||||
|
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
|
||||||
|
```
|
||||||
|
|
||||||
|
#let type = ```
|
||||||
|
{[`type` 'constraint-type']}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let type-statement = ```
|
#let type-statement = ```
|
||||||
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
|
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
|
||||||
```
|
```
|
||||||
|
|
||||||
#let op-definition = ```
|
#let op-definition = ```
|
||||||
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
|
|||||||
comparison: comparison,
|
comparison: comparison,
|
||||||
equality: equality,
|
equality: equality,
|
||||||
constraint: constraint,
|
constraint: constraint,
|
||||||
simple-type: simple-type,
|
template-param: template-param,
|
||||||
template: template,
|
template: template,
|
||||||
type: type,
|
|
||||||
type-property: type-property,
|
type-property: type-property,
|
||||||
type-body: type-body,
|
complex-type: complex-type,
|
||||||
|
named-type: named-type,
|
||||||
|
type-params: type-params,
|
||||||
|
generic-type: generic-type,
|
||||||
|
grouped-type: grouped-type,
|
||||||
|
base-type: base-type,
|
||||||
|
constraint-type: constraint-type,
|
||||||
|
type: type,
|
||||||
type-statement: type-statement,
|
type-statement: type-statement,
|
||||||
op-definition: op-definition,
|
op-definition: op-definition,
|
||||||
extend-statement: extend-statement,
|
extend-statement: extend-statement,
|
||||||
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
|
|||||||
#let inline = (
|
#let inline = (
|
||||||
"grouping",
|
"grouping",
|
||||||
"value",
|
"value",
|
||||||
|
"template-param",
|
||||||
"template",
|
"template",
|
||||||
"simple-type",
|
|
||||||
"type-property",
|
"type-property",
|
||||||
"type-body",
|
"complex-type",
|
||||||
|
"type-params",
|
||||||
|
"named-type",
|
||||||
|
"grouped-type",
|
||||||
|
"generic-type",
|
||||||
|
"base-type",
|
||||||
|
"constraint-type",
|
||||||
"op-definition",
|
"op-definition",
|
||||||
"type-statement",
|
"type-statement",
|
||||||
"extend-statement",
|
"extend-statement",
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class Tester(ABC):
|
|||||||
def _list_tests(self) -> list[Path]: ...
|
def _list_tests(self) -> list[Path]: ...
|
||||||
|
|
||||||
def run_all_tests(self) -> bool:
|
def run_all_tests(self) -> bool:
|
||||||
paths: list[Path] = self._list_tests()
|
paths: list[Path] = sorted(self._list_tests())
|
||||||
return self.run_tests(paths)
|
return self.run_tests(paths)
|
||||||
|
|
||||||
def run_tests(self, tests: list[Path]) -> bool:
|
def run_tests(self, tests: list[Path]) -> bool:
|
||||||
@@ -40,7 +40,7 @@ class Tester(ABC):
|
|||||||
|
|
||||||
print(rule)
|
print(rule)
|
||||||
for i, test in enumerate(tests):
|
for i, test in enumerate(tests):
|
||||||
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||||
success: bool = self._run_test(test)
|
success: bool = self._run_test(test)
|
||||||
if success:
|
if success:
|
||||||
successes += 1
|
successes += 1
|
||||||
@@ -78,7 +78,7 @@ class Tester(ABC):
|
|||||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||||
|
|
||||||
def update_all_tests(self):
|
def update_all_tests(self):
|
||||||
paths: list[Path] = self._list_tests()
|
paths: list[Path] = sorted(self._list_tests())
|
||||||
return self.update_tests(paths)
|
return self.update_tests(paths)
|
||||||
|
|
||||||
def update_tests(self, tests: list[Path]):
|
def update_tests(self, tests: list[Path]):
|
||||||
@@ -141,3 +141,9 @@ class Tester(ABC):
|
|||||||
success = tester.run_tests(args.FILE)
|
success = tester.run_tests(args.FILE)
|
||||||
if not success:
|
if not success:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
case None:
|
||||||
|
print("No subcommand provided. Available subcommands: run, update")
|
||||||
|
sys.exit(1)
|
||||||
|
case _:
|
||||||
|
print(f"Unknown subcommand '{args.subcommand}'")
|
||||||
|
sys.exit(1)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
{
|
{
|
||||||
"diagnostics": []
|
"diagnostics": [],
|
||||||
|
"judgments": []
|
||||||
}
|
}
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
13
|
13
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
"message": "Cannot assign str to variable 'c' of type int"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "Error",
|
"type": "Error",
|
||||||
@@ -26,21 +26,166 @@
|
|||||||
9
|
9
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
"message": "Undefined operation __add__ between bool and bool"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L1:9",
|
||||||
|
"to": "L1:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "Error",
|
|
||||||
"location": {
|
"location": {
|
||||||
"start": [
|
"from": "L2:9",
|
||||||
11,
|
"to": "L2:10"
|
||||||
0
|
|
||||||
],
|
|
||||||
"end": [
|
|
||||||
11,
|
|
||||||
12
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 4
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:4",
|
||||||
|
"to": "L4:5"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:8",
|
||||||
|
"to": "L4:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:4",
|
||||||
|
"to": "L4:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:4",
|
||||||
|
"to": "L6:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "invalid"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:4",
|
||||||
|
"to": "L8:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": true
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:4",
|
||||||
|
"to": "L9:5"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:8",
|
||||||
|
"to": "L9:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:4",
|
||||||
|
"to": "L9:9"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "d"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:11",
|
||||||
|
"to": "L11:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,14 @@
|
|||||||
type Meter(float)
|
type Meter = float
|
||||||
type Second(float)
|
type Second = float
|
||||||
type MeterPerSecond(float)
|
type MeterPerSecond = float
|
||||||
|
|
||||||
extend Meter {
|
extend Meter {
|
||||||
op __add__(Meter) -> Meter
|
def __add__: fn(Meter, /) -> Meter
|
||||||
op __sub__(Meter) -> Meter
|
def __sub__: fn(Meter, /) -> Meter
|
||||||
op __truediv__(Second) -> MeterPerSecond
|
def __truediv__: fn(Second, /) -> MeterPerSecond
|
||||||
}
|
}
|
||||||
|
|
||||||
extend Second {
|
extend Second {
|
||||||
op __add__(Second) -> Second
|
def __add__: fn(Second, /) -> Second
|
||||||
op __sub__(Second) -> Second
|
def __sub__: fn(Second, /) -> Second
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# type: ignore
|
# type: ignore
|
||||||
# ruff: disable [F821]
|
# ruff: disable [F821]
|
||||||
|
|
||||||
midas.using("04_custom_types.midas")
|
|
||||||
|
|
||||||
distance: Meter = cast(Meter, 123.45)
|
distance: Meter = cast(Meter, 123.45)
|
||||||
time: Second = cast(Second, 6.7)
|
time: Second = cast(Second, 6.7)
|
||||||
speed = distance / time
|
speed = distance / time
|
||||||
|
|||||||
@@ -1,3 +1,109 @@
|
|||||||
{
|
{
|
||||||
"diagnostics": []
|
"diagnostics": [],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:18",
|
||||||
|
"to": "L4:37"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CastExpr",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Meter",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 123.45
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Meter",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:15",
|
||||||
|
"to": "L5:32"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CastExpr",
|
||||||
|
"type": {
|
||||||
|
"_type": "BaseType",
|
||||||
|
"base": "Second",
|
||||||
|
"param": null
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 6.7
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Second",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:8",
|
||||||
|
"to": "L6:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "distance"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Meter",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:19",
|
||||||
|
"to": "L6:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "time"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Second",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:8",
|
||||||
|
"to": "L6:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "distance"
|
||||||
|
},
|
||||||
|
"operator": "/",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "time"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "MeterPerSecond",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
@@ -42,5 +42,409 @@
|
|||||||
},
|
},
|
||||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||||
}
|
}
|
||||||
|
],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:11",
|
||||||
|
"to": "L2:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:15",
|
||||||
|
"to": "L2:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:11",
|
||||||
|
"to": "L2:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:7",
|
||||||
|
"to": "L5:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:11",
|
||||||
|
"to": "L5:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:7",
|
||||||
|
"to": "L5:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CompareExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "<",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:15",
|
||||||
|
"to": "L6:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:19",
|
||||||
|
"to": "L6:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:15",
|
||||||
|
"to": "L6:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:15",
|
||||||
|
"to": "L8:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:19",
|
||||||
|
"to": "L8:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:15",
|
||||||
|
"to": "L8:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:7",
|
||||||
|
"to": "L15:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:11",
|
||||||
|
"to": "L15:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 10
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:7",
|
||||||
|
"to": "L15:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CompareExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": ">",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 10
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:15",
|
||||||
|
"to": "L16:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:19",
|
||||||
|
"to": "L16:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 10
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:15",
|
||||||
|
"to": "L16:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 10
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:15",
|
||||||
|
"to": "L18:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L22:7",
|
||||||
|
"to": "L22:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L22:11",
|
||||||
|
"to": "L22:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L22:7",
|
||||||
|
"to": "L22:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CompareExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"operator": "<",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L23:15",
|
||||||
|
"to": "L23:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L23:19",
|
||||||
|
"to": "L23:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L23:15",
|
||||||
|
"to": "L23:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L25:15",
|
||||||
|
"to": "L25:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "oops"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
12
tests/cases/checker/06_subtyping.py
Normal file
12
tests/cases/checker/06_subtyping.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
v1: int = 3
|
||||||
|
v2: float = 4
|
||||||
|
|
||||||
|
|
||||||
|
def maximum(a: float, b: float):
|
||||||
|
if b > a:
|
||||||
|
return b
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
v3 = maximum(v1, v2)
|
||||||
|
v3 = v2 + v1
|
||||||
239
tests/cases/checker/06_subtyping.py.ref.json
Normal file
239
tests/cases/checker/06_subtyping.py.ref.json
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L1:10",
|
||||||
|
"to": "L1:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:12",
|
||||||
|
"to": "L2:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 4
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:7",
|
||||||
|
"to": "L6:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:11",
|
||||||
|
"to": "L6:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:7",
|
||||||
|
"to": "L6:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CompareExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"operator": ">",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L7:15",
|
||||||
|
"to": "L7:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "b"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:11",
|
||||||
|
"to": "L8:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:5",
|
||||||
|
"to": "L11:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "maximum"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "a",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"pos": 1,
|
||||||
|
"name": "b",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:13",
|
||||||
|
"to": "L11:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:17",
|
||||||
|
"to": "L11:19"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:5",
|
||||||
|
"to": "L11:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "maximum"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:5",
|
||||||
|
"to": "L12:7"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:10",
|
||||||
|
"to": "L12:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:5",
|
||||||
|
"to": "L12:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"operator": "+",
|
||||||
|
"right": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
// Simple custom type derived from float
|
// Simple custom type derived from float
|
||||||
type Custom(float)
|
type Custom = float
|
||||||
|
|
||||||
// Simple custom types with constraints
|
// Simple custom types with constraints
|
||||||
type Latitude(float) where (-90 <= _ <= 90)
|
type Latitude = float where (-90 <= _ <= 90)
|
||||||
type Longitude(float) where (-180 <= _ <= 180)
|
type Longitude = float where (-180 <= _ <= 180)
|
||||||
|
|
||||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||||
type Difference[T](T)
|
type Difference[T] = T
|
||||||
|
|
||||||
// Complex custom type, containing two values accessible through properties
|
// Complex custom type, containing two values accessible through properties
|
||||||
type GeoLocation {
|
type GeoLocation = {
|
||||||
lat: Latitude
|
prop lat: Latitude
|
||||||
lon: Longitude
|
prop lon: Longitude
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define operations on our custom type
|
// Define operations on our custom type
|
||||||
@@ -19,23 +19,23 @@ extend GeoLocation {
|
|||||||
// This type is compatible with the `-` operation with another GeoLocation
|
// This type is compatible with the `-` operation with another GeoLocation
|
||||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||||
// in a Difference of GeoLocations
|
// in a Difference of GeoLocations
|
||||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
def __sub__: fn(GeoLocation, /) -> Difference[GeoLocation]
|
||||||
}
|
}
|
||||||
|
|
||||||
// For complex generics, you need to specify how the genericity the properties
|
// For complex generics, you need to specify how the genericity the properties
|
||||||
// are handled
|
// are handled
|
||||||
type Difference[GeoLocation] {
|
type Difference[GeoLocation] = {
|
||||||
lat: Difference[Latitude]
|
prop lat: Difference[Latitude]
|
||||||
lon: Difference[Longitude]
|
prop lon: Difference[Longitude]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simple operation defined on our custom types
|
// Simple operation defined on our custom types
|
||||||
extend Latitude {
|
extend Latitude {
|
||||||
op __sub__(Latitude) -> Difference[Latitude]
|
def __sub__: fn(Latitude, /) -> Difference[Latitude]
|
||||||
}
|
}
|
||||||
|
|
||||||
extend Longitude {
|
extend Longitude {
|
||||||
op __sub__(Longitude) -> Difference[Longitude]
|
def __sub__: fn(Longitude, /) -> Difference[Longitude]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Predefined custom predicates that can be referenced in other definitions
|
// Predefined custom predicates that can be referenced in other definitions
|
||||||
@@ -44,14 +44,14 @@ predicate StrictlyPositive(v: float) = v > 0
|
|||||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||||
|
|
||||||
type Person {
|
type Person = {
|
||||||
name: str
|
prop name: str
|
||||||
|
|
||||||
// Property with an inline constraint
|
// Property with an inline constraint
|
||||||
age: int? where (0 <= _ < 150)
|
prop age: Optional[int where (0 <= _ < 150)]
|
||||||
|
|
||||||
// Property referencing a predicate
|
// Property referencing a predicate
|
||||||
height: float where StrictlyPositive
|
prop height: float where StrictlyPositive
|
||||||
|
|
||||||
home: GeoLocation
|
prop home: GeoLocation
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,6 @@
|
|||||||
# ruff: disable[F821]
|
# ruff: disable[F821]
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import midas
|
|
||||||
|
|
||||||
midas.using("02_custom_types.midas")
|
|
||||||
|
|
||||||
df: Frame[
|
df: Frame[
|
||||||
location: GeoLocation
|
location: GeoLocation
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,26 +1,5 @@
|
|||||||
{
|
{
|
||||||
"stmts": [
|
"stmts": [
|
||||||
{
|
|
||||||
"_type": "ExpressionStmt",
|
|
||||||
"expr": {
|
|
||||||
"_type": "CallExpr",
|
|
||||||
"callee": {
|
|
||||||
"_type": "GetExpr",
|
|
||||||
"object": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "midas"
|
|
||||||
},
|
|
||||||
"name": "using"
|
|
||||||
},
|
|
||||||
"arguments": [
|
|
||||||
{
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "02_custom_types.midas"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"keywords": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"_type": "TypeAssign",
|
"_type": "TypeAssign",
|
||||||
"name": "df",
|
"name": "df",
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
import ast
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.checker.checker import Checker
|
from midas.checker.checker import TypeChecker
|
||||||
from midas.checker.diagnostic import Diagnostic
|
from midas.checker.diagnostic import Diagnostic
|
||||||
from midas.parser.python import PythonParser
|
from midas.checker.types import Type
|
||||||
from midas.resolver.resolver import Resolver
|
|
||||||
from tests.base import Tester
|
from tests.base import Tester
|
||||||
|
from tests.serializer.python import PythonAstJsonSerializer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CaseResult:
|
class CaseResult:
|
||||||
diagnostics: list[dict] = field(default_factory=list)
|
diagnostics: list[dict] = field(default_factory=list)
|
||||||
|
judgments: list = field(default_factory=list)
|
||||||
|
|
||||||
def dumps(self) -> str:
|
def dumps(self) -> str:
|
||||||
return json.dumps(asdict(self), indent=2)
|
return json.dumps(asdict(self), indent=2)
|
||||||
@@ -33,15 +33,16 @@ class CheckerTester(Tester):
|
|||||||
if not path.is_file():
|
if not path.is_file():
|
||||||
raise TypeError(f"Test '{path}' is not a file")
|
raise TypeError(f"Test '{path}' is not a file")
|
||||||
|
|
||||||
source: str = path.read_text()
|
|
||||||
tree: ast.Module = ast.parse(source, filename=path)
|
|
||||||
parser = PythonParser()
|
|
||||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
|
||||||
resolver = Resolver()
|
|
||||||
resolver.resolve(*stmts)
|
|
||||||
result: CaseResult = CaseResult()
|
result: CaseResult = CaseResult()
|
||||||
checker = Checker(resolver.locals, file_path=path)
|
|
||||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
checker = TypeChecker()
|
||||||
|
types_path: Path = path.with_suffix(".midas")
|
||||||
|
if types_path.exists():
|
||||||
|
checker.import_midas(types_path)
|
||||||
|
|
||||||
|
checker.type_check(path)
|
||||||
|
|
||||||
|
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||||
for diagnostic in diagnostics:
|
for diagnostic in diagnostics:
|
||||||
result.diagnostics.append(
|
result.diagnostics.append(
|
||||||
{
|
{
|
||||||
@@ -60,6 +61,21 @@ class CheckerTester(Tester):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
|
||||||
|
serializer = PythonAstJsonSerializer()
|
||||||
|
for expr, type in judgements:
|
||||||
|
loc = expr.location
|
||||||
|
result.judgments.append(
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": f"L{loc.lineno}:{loc.col_offset}",
|
||||||
|
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
|
||||||
|
},
|
||||||
|
"expr": expr.accept(serializer),
|
||||||
|
"type": asdict(type),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,79 +2,76 @@ from typing import Optional, Sequence
|
|||||||
|
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
ComplexTypeStmt,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
|
ExtensionType,
|
||||||
|
FunctionType,
|
||||||
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
GroupingExpr,
|
GroupingExpr,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
OpStmt,
|
MemberStmt,
|
||||||
|
NamedType,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
PropertyStmt,
|
|
||||||
SimpleTypeExpr,
|
|
||||||
SimpleTypeStmt,
|
|
||||||
Stmt,
|
Stmt,
|
||||||
TemplateExpr,
|
Type,
|
||||||
TypeExpr,
|
TypeParam,
|
||||||
|
TypeStmt,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
WildcardExpr,
|
WildcardExpr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
class MidasAstJsonSerializer(
|
||||||
|
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
|
||||||
|
):
|
||||||
"""An AST serializer which produces a JSON-compatible structure"""
|
"""An AST serializer which produces a JSON-compatible structure"""
|
||||||
|
|
||||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||||
return [stmt.accept(self) for stmt in stmts]
|
return [stmt.accept(self) for stmt in stmts]
|
||||||
|
|
||||||
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
|
def _serialize_optional(
|
||||||
|
self, element: Optional[Stmt | Expr | Type]
|
||||||
|
) -> Optional[dict]:
|
||||||
if element is None:
|
if element is None:
|
||||||
return None
|
return None
|
||||||
return element.accept(self)
|
return element.accept(self)
|
||||||
|
|
||||||
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
|
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
|
||||||
return [element.accept(self) for element in elements]
|
return [element.accept(self) for element in elements]
|
||||||
|
|
||||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
|
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "SimpleTypeStmt",
|
"_type": "TypeStmt",
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"template": self._serialize_optional(stmt.template),
|
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||||
"base": stmt.base.accept(self),
|
"type": stmt.type.accept(self),
|
||||||
"constraint": self._serialize_optional(stmt.constraint),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
|
def _serialize_type_param(self, param: TypeParam) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "ComplexTypeStmt",
|
"name": param.name.lexeme,
|
||||||
"name": stmt.name.lexeme,
|
"bound": self._serialize_optional(param.bound),
|
||||||
"template": self._serialize_optional(stmt.template),
|
|
||||||
"properties": self._serialize_list(stmt.properties),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
|
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "PropertyStmt",
|
"_type": "MemberStmt",
|
||||||
|
"kind": stmt.kind.name,
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"type": stmt.type.accept(self),
|
"type": stmt.type.accept(self),
|
||||||
"constraint": self._serialize_optional(stmt.constraint),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "ExtendStmt",
|
"_type": "ExtendStmt",
|
||||||
"type": stmt.type.accept(self),
|
|
||||||
"operations": self._serialize_list(stmt.operations),
|
|
||||||
}
|
|
||||||
|
|
||||||
def visit_op_stmt(self, stmt: OpStmt) -> dict:
|
|
||||||
return {
|
|
||||||
"_type": "OpStmt",
|
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"operand": stmt.operand.accept(self),
|
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||||
"result": stmt.result.accept(self),
|
"members": self._serialize_list(stmt.members),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
|
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
|
||||||
@@ -86,13 +83,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
|||||||
"condition": stmt.condition.accept(self),
|
"condition": stmt.condition.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
|
|
||||||
return {
|
|
||||||
"_type": "SimpleTypeExpr",
|
|
||||||
"name": expr.name.lexeme,
|
|
||||||
"optional": expr.optional,
|
|
||||||
}
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "LogicalExpr",
|
"_type": "LogicalExpr",
|
||||||
@@ -144,16 +134,51 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
|||||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
||||||
return {"_type": "WildcardExpr"}
|
return {"_type": "WildcardExpr"}
|
||||||
|
|
||||||
def visit_template_expr(self, expr: TemplateExpr) -> dict:
|
def visit_named_type(self, type: NamedType) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "TemplateExpr",
|
"_type": "NamedType",
|
||||||
"type": expr.type.accept(self),
|
"name": type.name.lexeme,
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_type_expr(self, expr: TypeExpr) -> dict:
|
def visit_generic_type(self, type: GenericType) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "TypeExpr",
|
"_type": "GenericType",
|
||||||
"name": expr.name.lexeme,
|
"type": type.type.accept(self),
|
||||||
"template": self._serialize_optional(expr.template),
|
"args": self._serialize_list(type.args),
|
||||||
"optional": expr.optional,
|
}
|
||||||
|
|
||||||
|
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ConstraintType",
|
||||||
|
"type": type.type.accept(self),
|
||||||
|
"constraint": type.constraint.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_complex_type(self, type: ComplexType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ComplexType",
|
||||||
|
"members": self._serialize_list(type.members),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_function_type(self, type: FunctionType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FunctionType",
|
||||||
|
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||||
|
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
||||||
|
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||||
|
"returns": type.returns.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||||
|
return {
|
||||||
|
"name": arg.name,
|
||||||
|
"type": arg.type.accept(self),
|
||||||
|
"required": arg.required,
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_extension_type(self, type: ExtensionType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ExtensionType",
|
||||||
|
"base": type.base.accept(self),
|
||||||
|
"extension": type.extension.accept(self),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,12 +16,14 @@ from midas.ast.python import (
|
|||||||
Function,
|
Function,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
IfStmt,
|
IfStmt,
|
||||||
|
ListExpr,
|
||||||
LiteralExpr,
|
LiteralExpr,
|
||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MidasType,
|
MidasType,
|
||||||
ReturnStmt,
|
ReturnStmt,
|
||||||
SetExpr,
|
|
||||||
Stmt,
|
Stmt,
|
||||||
|
SubscriptExpr,
|
||||||
|
TernaryExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -231,17 +233,30 @@ class PythonAstJsonSerializer(
|
|||||||
"right": expr.right.accept(self),
|
"right": expr.right.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_set_expr(self, expr: SetExpr) -> dict:
|
|
||||||
return {
|
|
||||||
"_type": "SetExpr",
|
|
||||||
"object": expr.object.accept(self),
|
|
||||||
"name": expr.name,
|
|
||||||
"value": expr.value.accept(self),
|
|
||||||
}
|
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "CastExpr",
|
"_type": "CastExpr",
|
||||||
"type": expr.type.accept(self),
|
"type": expr.type.accept(self),
|
||||||
"expr": expr.expr.accept(self),
|
"expr": expr.expr.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "TernaryExpr",
|
||||||
|
"test": expr.test.accept(self),
|
||||||
|
"if_true": expr.if_true.accept(self),
|
||||||
|
"if_false": expr.if_false.accept(self),
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_list_expr(self, expr: ListExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ListExpr",
|
||||||
|
"items": [item.accept(self) for item in expr.items],
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "SubscriptExpr",
|
||||||
|
"object": expr.object.accept(self),
|
||||||
|
"index": expr.index.accept(self),
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user