Compare commits
101 Commits
35ceda99aa
...
feat/gener
| Author | SHA1 | Date | |
|---|---|---|---|
|
2886ffe00b
|
|||
|
def38e720b
|
|||
|
49274be2f4
|
|||
|
aec6b7aa7b
|
|||
|
530d93723e
|
|||
|
905132a18e
|
|||
|
9594c74952
|
|||
|
8df5607461
|
|||
|
757054d7af
|
|||
|
25a96d20e1
|
|||
|
04c0d683de
|
|||
|
ac620f318b
|
|||
|
947e9f0149
|
|||
|
c1ca254b51
|
|||
|
0bb862a1db
|
|||
|
54919a3565
|
|||
|
a5f0140013
|
|||
|
b0af01d906
|
|||
|
6048ee020f
|
|||
|
f815faa2f8
|
|||
|
f7d5d36d44
|
|||
|
503f2b6a0a
|
|||
|
778117664f
|
|||
|
afe3eefbbf
|
|||
|
96495e9f79
|
|||
|
77263139f6
|
|||
|
4f5967a151
|
|||
|
2a714a1021
|
|||
|
dafe0b471a
|
|||
|
a1f2937e16
|
|||
|
2063d94dce
|
|||
|
22fc8010d8
|
|||
|
aff1097d91
|
|||
|
12d034fd1e
|
|||
|
200709cca6
|
|||
|
700284296c
|
|||
|
0b53259b90
|
|||
|
0461a4184c
|
|||
|
01d6e41893
|
|||
|
80e611e49c
|
|||
|
c00915966f
|
|||
|
beaa4d95d8
|
|||
|
bfa0bb3ee0
|
|||
|
31158df2a9
|
|||
|
c6ead886ec
|
|||
|
9de03bf2b5
|
|||
|
a26b9293be
|
|||
|
efa5454776
|
|||
|
b8bb8190c4
|
|||
|
a4f5db7ece
|
|||
|
fc67f01f34
|
|||
|
0a748a36a3
|
|||
|
89fdd1b47e
|
|||
|
0cde53ac6e
|
|||
|
f3ec3606c2
|
|||
|
67ec029529
|
|||
|
e2aef7a811
|
|||
|
86ba4e658a
|
|||
|
7eccf59558
|
|||
|
9dd7801d2d
|
|||
|
154cb8b314
|
|||
|
c64ab434b5
|
|||
|
25e6410546
|
|||
|
8a22acc17c
|
|||
|
e0179bc442
|
|||
|
e665d03533
|
|||
|
b8cb2b4273
|
|||
|
d278dc5f5b
|
|||
|
59e73f0fd9
|
|||
|
3e0dc60283
|
|||
|
c24eb5125e
|
|||
|
25bd895dde
|
|||
|
bccd75317e
|
|||
|
f0e3f7574f
|
|||
|
5d44081847
|
|||
|
2a2bb0aec7
|
|||
|
67c40a3909
|
|||
|
1c30188122
|
|||
|
82a0f13242
|
|||
| 288d15a9bc | |||
|
504703d0f7
|
|||
|
e48895d0af
|
|||
| 13d32d0d27 | |||
| 19b9fdd623 | |||
|
ddcaebb51a
|
|||
|
f182312cd2
|
|||
|
73b21789d5
|
|||
|
5d7c724bc8
|
|||
|
74b297c89c
|
|||
|
822a74acce
|
|||
|
9a934fabfd
|
|||
|
828ec9a3fa
|
|||
|
63a43d79dd
|
|||
|
029caf4526
|
|||
|
1c5c418f1c
|
|||
|
a4139d4652
|
|||
|
2fd2071d40
|
|||
|
97b1ee8ab8
|
|||
|
dee479def5
|
|||
|
c8536e20d2
|
|||
|
d70137775f
|
79
README.md
79
README.md
@@ -5,3 +5,82 @@
|
||||
*Midas* aims at providing Python developers with a simple annotation system to enable compile-time integrity and data type checks, as well as generating runtime assertions.
|
||||
|
||||
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.11+
|
||||
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone the repository
|
||||
```shell
|
||||
git clone https://git.kb28.ch/HEL/midas.git
|
||||
```
|
||||
2. Go in the project directory
|
||||
```shell
|
||||
cd midas
|
||||
```
|
||||
3. Install the CLI as a user-wide tool
|
||||
```shell
|
||||
uv tool install .
|
||||
```
|
||||
4. You can now run the `midas` command from anywhere
|
||||
```shell
|
||||
midas --help
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Compiling
|
||||
|
||||
> [!NOTE]
|
||||
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
|
||||
|
||||
```shell
|
||||
midas compile -t types.midas source.py
|
||||
```
|
||||
|
||||
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
||||
|
||||
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
|
||||
|
||||
### Highlighting
|
||||
|
||||
```shell
|
||||
midas utils highlight source.py
|
||||
# or
|
||||
midas utils highlight types.midas
|
||||
```
|
||||
|
||||
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
||||
|
||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
||||
|
||||
### Dumping the AST
|
||||
|
||||
```shell
|
||||
midas utils dump-ast source.py
|
||||
# or
|
||||
midas utils dump-ast types.midas
|
||||
```
|
||||
|
||||
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
||||
|
||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
||||
|
||||
## Tests
|
||||
|
||||
Several snapshot tests are available to assert the good behaviour of the parsers and type checker. They can be run as follows:
|
||||
|
||||
```shell
|
||||
uv run -m tests.midas run -a
|
||||
uv run -m tests.python run -a
|
||||
uv run -m tests.checker run -a
|
||||
```
|
||||
|
||||
**Available subcommands:**
|
||||
- Run all tests: `run -a`
|
||||
- Run specific tests: `run tests/cases/test1.py tests/cases/test2.py ...`
|
||||
- Update all tests: `update -a`
|
||||
- Update specific tests: `update tests/cases/test1.py tests/cases/test2.py ...`
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
# Prototype of custom type import to use valid Python syntax
|
||||
import midas
|
||||
midas.using("02_custom_types.midas")
|
||||
|
||||
# A data-frame using a custom type
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
|
||||
@@ -9,3 +9,5 @@ d = True
|
||||
e = d + d
|
||||
|
||||
f: float = a
|
||||
|
||||
f = -f
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
type Meter(float)
|
||||
type Second(float)
|
||||
type MeterPerSecond(float)
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
def __add__: fn(Meter, /) -> Meter
|
||||
def __sub__: fn(Meter, /) -> Meter
|
||||
def __truediv__: fn(Second, /) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
def __add__: fn(Second, /) -> Second
|
||||
def __sub__: fn(Second, /) -> Second
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
midas.using("02_simple_types.midas")
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
|
||||
@@ -4,13 +4,20 @@ def minimum(x: int, y: int):
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
a = 15
|
||||
b = 72
|
||||
c = minimum(a, b)
|
||||
|
||||
|
||||
def factorial(n: int) -> int:
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
|
||||
category = "Category 1" if a < 10 else "Category 2"
|
||||
|
||||
category = "Category 1" if a < 10 else "Category 2"
|
||||
|
||||
|
||||
def foo() -> None:
|
||||
pass
|
||||
|
||||
21
examples/01_simple_type_checking/04_complex_types.midas
Normal file
21
examples/01_simple_type_checking/04_complex_types.midas
Normal file
@@ -0,0 +1,21 @@
|
||||
type Meter = float
|
||||
|
||||
extend Meter {
|
||||
def __add__: fn(Meter, /) -> Meter
|
||||
def __sub__: fn(Meter, /) -> Meter
|
||||
}
|
||||
|
||||
type Coordinate = object
|
||||
|
||||
extend Coordinate {
|
||||
prop x: Meter
|
||||
prop y: Meter
|
||||
}
|
||||
|
||||
type Difference[T <: float] = T
|
||||
type MeterDifference = Difference[Meter]
|
||||
|
||||
type CompDiff[T <: float] = {
|
||||
prop d1: Difference[T]
|
||||
prop d2: Difference[T]
|
||||
}
|
||||
37
examples/01_simple_type_checking/04_complex_types.py
Normal file
37
examples/01_simple_type_checking/04_complex_types.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
p1: Coordinate
|
||||
p2: Coordinate
|
||||
|
||||
diff_x = p2.x - p1.x
|
||||
diff_y = p2.y - p1.y
|
||||
|
||||
dist = diff_x + diff_y
|
||||
|
||||
p2.x += cast(Meter, 1)
|
||||
p2.y = True # invalid, wrong type
|
||||
p2.z = 3 # invalid, no property 'z' on Coordinate
|
||||
p2.x.a = 3 # invalid, no properties on Meter
|
||||
|
||||
foo: list[float] = []
|
||||
|
||||
append = foo.append
|
||||
|
||||
foo.append("") # invalid, must be float
|
||||
foo.append(2)
|
||||
append(True) # invalid, must be float
|
||||
append(2)
|
||||
|
||||
bar: list[list[Meter]]
|
||||
|
||||
bar.append([p2.x])
|
||||
|
||||
foo2 = foo + foo
|
||||
|
||||
a = foo[0]
|
||||
b = bar[0][1]
|
||||
c = bar[0][1][2] # invalid, not method __getitem__ on Meter
|
||||
c = bar[""] # invalid, wrong index type
|
||||
|
||||
d = foo[1:2]
|
||||
28
examples/01_simple_type_checking/05_functions.py
Normal file
28
examples/01_simple_type_checking/05_functions.py
Normal file
@@ -0,0 +1,28 @@
|
||||
def incr(value: int):
|
||||
return value + 1
|
||||
|
||||
|
||||
def decr(value: int):
|
||||
return value - 1
|
||||
|
||||
|
||||
def foo(a: int, /, b: float, *, c: str):
|
||||
return True
|
||||
|
||||
|
||||
r1 = foo() # foo() missing 2 required positional arguments: 'a' and 'b'
|
||||
r2 = foo(1) # foo() missing 1 required positional argument: 'b'
|
||||
r3 = foo(1, 2.0) # foo() missing 1 required keyword-only argument: 'c'
|
||||
r4 = foo(1, b=2.0) # foo() missing 1 required keyword-only argument: 'c'
|
||||
r5 = foo(1, 2.0, "test") # foo() takes 2 positional arguments but 3 were given
|
||||
r6 = foo(1, 2.0, b=3.0) # foo() got multiple values for argument 'b'
|
||||
r7 = foo(
|
||||
a=1
|
||||
) # foo() got some positional-only arguments passed as keyword arguments: 'a'
|
||||
r8 = foo(g="test") # foo() got an unexpected keyword argument 'g'
|
||||
|
||||
r9a = foo(1, 2.0, c="test")
|
||||
r9b = foo(1, b=2.0, c="test")
|
||||
r9c = foo(1, c="test", b=2.0)
|
||||
|
||||
r10 = foo("a", 3, c=False) # wrong argument types
|
||||
10
examples/01_simple_type_checking/06_overloads.midas
Normal file
10
examples/01_simple_type_checking/06_overloads.midas
Normal file
@@ -0,0 +1,10 @@
|
||||
type T1 = object
|
||||
type T2 = object
|
||||
type Foo = object
|
||||
type T2b = T2
|
||||
|
||||
extend Foo {
|
||||
def bar: fn(T1, /) -> int
|
||||
def bar: fn(T2, /) -> float
|
||||
def bar: fn(T2b, /) -> int
|
||||
}
|
||||
18
examples/01_simple_type_checking/06_overloads.py
Normal file
18
examples/01_simple_type_checking/06_overloads.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
foo: Foo
|
||||
t1: T1
|
||||
t2: T2
|
||||
|
||||
a = foo.bar(t1)
|
||||
b = foo.bar(t2)
|
||||
|
||||
func = foo.bar
|
||||
|
||||
c = func(t1)
|
||||
d = func(t2)
|
||||
|
||||
t2b: T2b
|
||||
|
||||
e = foo.bar(t2b)
|
||||
15
gen/gen.py
15
gen/gen.py
@@ -30,6 +30,7 @@ from __future__ import annotations
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
{preamble}
|
||||
{sections}
|
||||
"""
|
||||
|
||||
@@ -57,6 +58,11 @@ IMPORTS_REGEX = re.compile(
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
PREAMBLE_REGEX = re.compile(
|
||||
r"^###>\s*Preamble\s*?\n(?P<body>.*?)\n###<$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def snake_case(text: str) -> str:
|
||||
return re.sub(r"[A-Z]", lambda c: "_" + c.group().lower(), text).lower().strip("_")
|
||||
@@ -88,13 +94,14 @@ def make_banner(text: str) -> str:
|
||||
|
||||
|
||||
def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
print(f" Generating {full_name}")
|
||||
visitor_methods: list[str] = []
|
||||
classes: list[str] = []
|
||||
definitions: list[str] = body.strip("\n").split("\n\n\n")
|
||||
for cls in definitions:
|
||||
cls = cls.strip("\n")
|
||||
name: str = re.match("class (.*?):", cls).group(1) # type: ignore
|
||||
print(f"Processing {name}")
|
||||
print(f" Processing {name}")
|
||||
visitor_methods.append(make_visitor_method(name, param))
|
||||
classes.append(make_class(name, cls, base))
|
||||
|
||||
@@ -107,6 +114,7 @@ def make_section(full_name: str, base: str, param: str, body: str) -> str:
|
||||
|
||||
|
||||
def generate(definitions_path: Path, out_path: Path):
|
||||
print(f"Processing generating {out_path} from {definitions_path}")
|
||||
root_dir: Path = Path(__file__).parent.parent
|
||||
rel_path: Path = definitions_path.relative_to(root_dir)
|
||||
src: str = definitions_path.read_text()
|
||||
@@ -116,6 +124,10 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
if m := IMPORTS_REGEX.search(src):
|
||||
imports = m.group("body").strip("\n")
|
||||
|
||||
preamble: str = ""
|
||||
if m := PREAMBLE_REGEX.search(src):
|
||||
preamble = m.group("body")
|
||||
|
||||
for section_m in SECTION_REGEX.finditer(src):
|
||||
full_name: str = section_m.group("name")
|
||||
base: str = section_m.group("base")
|
||||
@@ -129,6 +141,7 @@ def generate(definitions_path: Path, out_path: Path):
|
||||
gen_path=Path(__file__).relative_to(root_dir),
|
||||
),
|
||||
imports=imports,
|
||||
preamble=preamble,
|
||||
sections="\n\n\n".join(sections),
|
||||
)
|
||||
out_path.write_text(result)
|
||||
|
||||
64
gen/midas.py
64
gen/midas.py
@@ -4,6 +4,7 @@
|
||||
###> Imports
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
@@ -12,33 +13,39 @@ from midas.lexer.token import Token
|
||||
###<
|
||||
|
||||
|
||||
###> Preamble
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class MemberKind(Enum):
|
||||
PROPERTY = auto()
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
###<
|
||||
|
||||
|
||||
###> Stmt | Statements
|
||||
class TypeStmt:
|
||||
name: Token
|
||||
params: list[Param]
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class PropertyStmt:
|
||||
class MemberStmt:
|
||||
name: Token
|
||||
type: Type
|
||||
kind: MemberKind
|
||||
|
||||
|
||||
class ExtendStmt:
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
|
||||
|
||||
class OpStmt:
|
||||
name: Token
|
||||
operand: Type
|
||||
result: Type
|
||||
params: list[TypeParam]
|
||||
members: list[MemberStmt]
|
||||
|
||||
|
||||
class PredicateStmt:
|
||||
@@ -103,7 +110,7 @@ class NamedType:
|
||||
|
||||
class GenericType:
|
||||
type: Type
|
||||
params: list[Type]
|
||||
args: list[Type]
|
||||
|
||||
|
||||
class ConstraintType:
|
||||
@@ -111,12 +118,27 @@ class ConstraintType:
|
||||
constraint: Expr
|
||||
|
||||
|
||||
class UnionType:
|
||||
types: list[Type]
|
||||
|
||||
|
||||
class ComplexType:
|
||||
properties: list[PropertyStmt]
|
||||
members: list[MemberStmt]
|
||||
|
||||
|
||||
class ExtensionType:
|
||||
base: Type
|
||||
extension: ComplexType
|
||||
|
||||
|
||||
class FunctionType:
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -128,12 +128,6 @@ class LogicalExpr:
|
||||
right: Expr
|
||||
|
||||
|
||||
class SetExpr:
|
||||
object: Expr
|
||||
name: str
|
||||
value: Expr
|
||||
|
||||
|
||||
class CastExpr:
|
||||
type: MidasType
|
||||
expr: Expr
|
||||
@@ -145,4 +139,19 @@ class TernaryExpr:
|
||||
if_false: Expr
|
||||
|
||||
|
||||
class ListExpr:
|
||||
items: list[Expr]
|
||||
|
||||
|
||||
class SubscriptExpr:
|
||||
object: Expr
|
||||
index: Expr
|
||||
|
||||
|
||||
class SliceExpr:
|
||||
lower: Optional[Expr]
|
||||
upper: Optional[Expr]
|
||||
step: Optional[Expr]
|
||||
|
||||
|
||||
###<
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from midas.ast.location import Location
|
||||
@@ -14,6 +15,18 @@ from midas.lexer.token import Token
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeParam:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
|
||||
class MemberKind(Enum):
|
||||
PROPERTY = auto()
|
||||
METHOD = auto()
|
||||
|
||||
|
||||
##############
|
||||
# Statements #
|
||||
##############
|
||||
@@ -31,14 +44,11 @@ class Stmt(ABC):
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> T: ...
|
||||
def visit_member_stmt(self, stmt: MemberStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> T: ...
|
||||
|
||||
@@ -46,47 +56,33 @@ class Stmt(ABC):
|
||||
@dataclass(frozen=True)
|
||||
class TypeStmt(Stmt):
|
||||
name: Token
|
||||
params: list[Param]
|
||||
params: list[TypeParam]
|
||||
type: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Param:
|
||||
location: Location
|
||||
name: Token
|
||||
bound: Optional[Type]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_type_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PropertyStmt(Stmt):
|
||||
class MemberStmt(Stmt):
|
||||
name: Token
|
||||
type: Type
|
||||
kind: MemberKind
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_property_stmt(self)
|
||||
return visitor.visit_member_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtendStmt(Stmt):
|
||||
type: Type
|
||||
operations: list[OpStmt]
|
||||
name: Token
|
||||
params: list[TypeParam]
|
||||
members: list[MemberStmt]
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_extend_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpStmt(Stmt):
|
||||
name: Token
|
||||
operand: Type
|
||||
result: Type
|
||||
|
||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||
return visitor.visit_op_stmt(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PredicateStmt(Stmt):
|
||||
name: Token
|
||||
@@ -229,10 +225,13 @@ class Type(ABC):
|
||||
def visit_constraint_type(self, type: ConstraintType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_union_type(self, type: UnionType) -> T: ...
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_complex_type(self, type: ComplexType) -> T: ...
|
||||
def visit_extension_type(self, type: ExtensionType) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -246,7 +245,7 @@ class NamedType(Type):
|
||||
@dataclass(frozen=True)
|
||||
class GenericType(Type):
|
||||
type: Type
|
||||
params: list[Type]
|
||||
args: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_generic_type(self)
|
||||
@@ -261,17 +260,36 @@ class ConstraintType(Type):
|
||||
return visitor.visit_constraint_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnionType(Type):
|
||||
types: list[Type]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_union_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComplexType(Type):
|
||||
properties: list[PropertyStmt]
|
||||
members: list[MemberStmt]
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_complex_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExtensionType(Type):
|
||||
base: Type
|
||||
extension: ComplexType
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_extension_type(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionType(Type):
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
location: Optional[Location] = None
|
||||
name: Optional[Token]
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||
return visitor.visit_function_type(self)
|
||||
|
||||
@@ -100,20 +100,21 @@ class MidasAstPrinter(
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_stmt_param(param)
|
||||
self._print_type_param(param)
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
|
||||
def _print_type_stmt_param(self, param: m.TypeStmt.Param) -> None:
|
||||
def _print_type_param(self, param: m.TypeParam) -> None:
|
||||
self._write_line("Param")
|
||||
with self._child_level():
|
||||
self._write_line(f'name: "{param.name.lexeme}"')
|
||||
self._write_optional_child("bound", param.bound, last=True)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
self._write_line("PropertyStmt")
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
self._write_line("MemberStmt")
|
||||
with self._child_level():
|
||||
self._write_line(f"kind: {stmt.kind.name}")
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
self._write_line("type", last=True)
|
||||
with self._child_level(single=True):
|
||||
@@ -122,29 +123,28 @@ class MidasAstPrinter(
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._write_line("ExtendStmt")
|
||||
with self._child_level():
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
stmt.type.accept(self)
|
||||
self._write_line("operations", last=True)
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, op in enumerate(stmt.operations):
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.operations) - 1:
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
op.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self._write_line("OpStmt")
|
||||
with self._child_level():
|
||||
self._print_type_param(param)
|
||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||
|
||||
self._write_line("operand")
|
||||
with self._child_level(single=True):
|
||||
stmt.operand.accept(self)
|
||||
|
||||
self._write_line("result", last=True)
|
||||
with self._child_level(single=True):
|
||||
stmt.result.accept(self)
|
||||
self._write_line("params")
|
||||
with self._child_level():
|
||||
for i, param in enumerate(stmt.params):
|
||||
self._idx = i
|
||||
if i == len(stmt.params) - 1:
|
||||
self._mark_last()
|
||||
self._print_type_param(param)
|
||||
self._write_line("members", last=True)
|
||||
with self._child_level():
|
||||
for i, member in enumerate(stmt.members):
|
||||
self._idx = i
|
||||
if i == len(stmt.members) - 1:
|
||||
self._mark_last()
|
||||
member.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
self._write_line("PredicateStmt")
|
||||
@@ -234,11 +234,11 @@ class MidasAstPrinter(
|
||||
self._write_line("type")
|
||||
with self._child_level():
|
||||
type.type.accept(self)
|
||||
self._write_line("params", last=True)
|
||||
self._write_line("args", last=True)
|
||||
with self._child_level():
|
||||
for i, param in enumerate(type.params):
|
||||
for i, param in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.params) - 1:
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
param.accept(self)
|
||||
|
||||
@@ -252,27 +252,69 @@ class MidasAstPrinter(
|
||||
with self._child_level(single=True):
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_union_type(self, type: m.UnionType) -> None:
|
||||
self._write_line("UnionType")
|
||||
with self._child_level():
|
||||
self._write_line("types", last=True)
|
||||
with self._child_level():
|
||||
for i, type_ in enumerate(type.types):
|
||||
self._idx = i
|
||||
if i == len(type.types) - 1:
|
||||
self._mark_last()
|
||||
type_.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self._write_line("ComplexType")
|
||||
with self._child_level():
|
||||
self._write_line("properties", last=True)
|
||||
self._write_line("members", last=True)
|
||||
with self._child_level():
|
||||
for i, prop in enumerate(type.properties):
|
||||
for i, member in enumerate(type.members):
|
||||
self._idx = i
|
||||
if i == len(type.properties) - 1:
|
||||
if i == len(type.members) - 1:
|
||||
self._mark_last()
|
||||
prop.accept(self)
|
||||
member.accept(self)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||
self._write_line("ExtensionType")
|
||||
with self._child_level():
|
||||
self._write_line("base")
|
||||
with self._child_level(single=True):
|
||||
type.base.accept(self)
|
||||
self._write_line("extension", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.extension.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self._write_line("FunctionType")
|
||||
with self._child_level():
|
||||
self._write_line("pos_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.pos_args):
|
||||
self._idx = i
|
||||
if i == len(type.pos_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.args):
|
||||
self._idx = i
|
||||
if i == len(type.args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("kw_args")
|
||||
with self._child_level():
|
||||
for i, arg in enumerate(type.kw_args):
|
||||
self._idx = i
|
||||
if i == len(type.kw_args) - 1:
|
||||
self._mark_last()
|
||||
self._print_function_arg(arg)
|
||||
|
||||
self._write_line("returns", last=True)
|
||||
with self._child_level(single=True):
|
||||
type.returns.accept(self)
|
||||
|
||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||
self._write_line("Argument")
|
||||
with self._child_level():
|
||||
name: str = "None"
|
||||
if arg.name is not None:
|
||||
name = f'"{arg.name.lexeme}"'
|
||||
self._write_line(f"name: {name}")
|
||||
self._write_line("type")
|
||||
with self._child_level(single=True):
|
||||
arg.type.accept(self)
|
||||
self._write_line(f"required: {arg.required}", last=True)
|
||||
|
||||
|
||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||
@@ -283,45 +325,46 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
def indented(self, text: str) -> str:
|
||||
return " " * (self.level * self.indent) + text
|
||||
|
||||
def print(self, expr: m.Expr | m.Stmt):
|
||||
def print(self, expr: m.Expr | m.Stmt | m.Type) -> str:
|
||||
self.level = 0
|
||||
return expr.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> str:
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [
|
||||
self._print_type_template_param(param) for param in stmt.params
|
||||
]
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = f"type {stmt.name.lexeme}{template} = {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def _print_type_template_param(self, param: m.TypeStmt.Param) -> str:
|
||||
def _print_type_param(self, param: m.TypeParam) -> str:
|
||||
res: str = param.name.lexeme
|
||||
if param.bound is not None:
|
||||
res += "<:" + param.bound.accept(self)
|
||||
return res
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt):
|
||||
res: str = f"{stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt):
|
||||
keyword: str = {
|
||||
m.MemberKind.PROPERTY: "prop",
|
||||
m.MemberKind.METHOD: "def",
|
||||
}.get(stmt.kind, "")
|
||||
res: str = f"{keyword} {stmt.name.lexeme}: {stmt.type.accept(self)}"
|
||||
return self.indented(res)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt):
|
||||
res: str = self.indented(f"extend {stmt.type.accept(self)}")
|
||||
template: str = ""
|
||||
if len(stmt.params) != 0:
|
||||
params: list[str] = [self._print_type_param(param) for param in stmt.params]
|
||||
template = f"[{', '.join(params)}]"
|
||||
res: str = self.indented(f"extend {stmt.name.lexeme}{template}")
|
||||
res += " {\n"
|
||||
self.level += 1
|
||||
for op in stmt.operations:
|
||||
res += op.accept(self)
|
||||
for member in stmt.members:
|
||||
res += member.accept(self) + "\n"
|
||||
self.level -= 1
|
||||
res += "\n" + self.indented("}")
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt):
|
||||
operand: str = stmt.operand.accept(self)
|
||||
result: str = stmt.result.accept(self)
|
||||
return self.indented(f"op {stmt.name.lexeme}({operand}) -> {result}")
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||
name: str = stmt.name.lexeme
|
||||
subject: str = stmt.subject.lexeme
|
||||
@@ -369,9 +412,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> str:
|
||||
res: str = type.type.accept(self)
|
||||
if len(type.params) != 0:
|
||||
params: list[str] = [param.accept(self) for param in type.params]
|
||||
res += f"[{', '.join(params)}]"
|
||||
if len(type.args) != 0:
|
||||
args: list[str] = [param.accept(self) for param in type.args]
|
||||
res += f"[{', '.join(args)}]"
|
||||
return res
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> str:
|
||||
@@ -379,20 +422,44 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
||||
res += " where " + type.constraint.accept(self)
|
||||
return res
|
||||
|
||||
def visit_union_type(self, type: m.UnionType) -> str:
|
||||
types: list[str] = [type_.accept(self) for type_ in type.types]
|
||||
return " | ".join(types)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> str:
|
||||
res: str = "{\n"
|
||||
self.level += 1
|
||||
for prop in type.properties:
|
||||
res += prop.accept(self)
|
||||
for member in type.members:
|
||||
res += member.accept(self)
|
||||
res += "\n"
|
||||
self.level -= 1
|
||||
res += self.indented("}")
|
||||
return res
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> str:
|
||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
||||
args: list[str] = pos_args
|
||||
|
||||
if len(pos_args) != 0:
|
||||
args.append("/")
|
||||
args += mixed_args
|
||||
if len(kw_args) != 0:
|
||||
args.append("*")
|
||||
args += kw_args
|
||||
|
||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
||||
|
||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||
res: str = ""
|
||||
if arg.name is not None:
|
||||
res += arg.name.lexeme
|
||||
res += ": "
|
||||
res += arg.type.accept(self)
|
||||
if not arg.required:
|
||||
res += "?"
|
||||
return res
|
||||
|
||||
|
||||
class PythonAstPrinter(
|
||||
AstPrinter,
|
||||
@@ -597,7 +664,7 @@ class PythonAstPrinter(
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> None:
|
||||
self._write_line("LiteralExpr")
|
||||
with self._child_level(single=True):
|
||||
self._write_line(f"value: {expr.value}")
|
||||
self._write_line(f"value: {expr.value!r}")
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> None:
|
||||
self._write_line("VariableExpr")
|
||||
@@ -617,17 +684,6 @@ class PythonAstPrinter(
|
||||
with self._child_level(single=True):
|
||||
expr.right.accept(self)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self._write_line("SetExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line(f"name: {expr.name}")
|
||||
self._write_line("value", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.value.accept(self)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self._write_line("CastExpr")
|
||||
with self._child_level():
|
||||
@@ -652,3 +708,31 @@ class PythonAstPrinter(
|
||||
self._write_line("if_false", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.if_false.accept(self)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
self._write_line("ListExpr")
|
||||
with self._child_level():
|
||||
self._write_line("items", last=True)
|
||||
with self._child_level():
|
||||
for i, item in enumerate(expr.items):
|
||||
self._idx = i
|
||||
if i == len(expr.items) - 1:
|
||||
self._mark_last()
|
||||
item.accept(self)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||
self._write_line("SubscriptExpr")
|
||||
with self._child_level():
|
||||
self._write_line("object")
|
||||
with self._child_level(single=True):
|
||||
expr.object.accept(self)
|
||||
self._write_line("index", last=True)
|
||||
with self._child_level(single=True):
|
||||
expr.index.accept(self)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
|
||||
self._write_line("SliceExpr")
|
||||
with self._child_level():
|
||||
self._write_optional_child("lower", expr.lower)
|
||||
self._write_optional_child("upper", expr.upper)
|
||||
self._write_optional_child("step", expr.step, last=True)
|
||||
|
||||
@@ -14,6 +14,7 @@ from midas.ast.location import Location
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
####################
|
||||
# Type annotations #
|
||||
####################
|
||||
@@ -214,15 +215,21 @@ class Expr(ABC):
|
||||
@abstractmethod
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_set_expr(self, expr: SetExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_cast_expr(self, expr: CastExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_list_expr(self, expr: ListExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_subscript_expr(self, expr: SubscriptExpr) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinaryExpr(Expr):
|
||||
@@ -298,16 +305,6 @@ class LogicalExpr(Expr):
|
||||
return visitor.visit_logical_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetExpr(Expr):
|
||||
object: Expr
|
||||
name: str
|
||||
value: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_set_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CastExpr(Expr):
|
||||
type: MidasType
|
||||
@@ -325,3 +322,30 @@ class TernaryExpr(Expr):
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_ternary_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListExpr(Expr):
|
||||
items: list[Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_list_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SubscriptExpr(Expr):
|
||||
object: Expr
|
||||
index: Expr
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_subscript_expr(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SliceExpr(Expr):
|
||||
lower: Optional[Expr]
|
||||
upper: Optional[Expr]
|
||||
step: Optional[Expr]
|
||||
|
||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||
return visitor.visit_slice_expr(self)
|
||||
|
||||
152
midas/checker/builtins.midas
Normal file
152
midas/checker/builtins.midas
Normal file
@@ -0,0 +1,152 @@
|
||||
extend float {
|
||||
def hex: fn() -> str
|
||||
def is_integer: fn() -> bool
|
||||
prop real: float
|
||||
prop imag: float
|
||||
def conjugate: fn() -> float
|
||||
def __add__: fn(value: float, /) -> float
|
||||
def __sub__: fn(value: float, /) -> float
|
||||
def __mul__: fn(value: float, /) -> float
|
||||
def __floordiv__: fn(value: float, /) -> float
|
||||
def __truediv__: fn(value: float, /) -> float
|
||||
def __mod__: fn(value: float, /) -> float
|
||||
// def __divmod__: fn(value: float, /) -> tuple[float, float]
|
||||
|
||||
def __pow__: fn(value: int, /) -> float
|
||||
// positive __value -> float; negative __value -> complex
|
||||
// return type must be Any as `float | complex` causes too many false-positive errors
|
||||
def __pow__: fn(value: float, /) -> Any
|
||||
def __radd__: fn(value: float, /) -> float
|
||||
def __rsub__: fn(value: float, /) -> float
|
||||
def __rmul__: fn(value: float, /) -> float
|
||||
def __rfloordiv__: fn(value: float, /) -> float
|
||||
def __rtruediv__: fn(value: float, /) -> float
|
||||
def __rmod__: fn(value: float, /) -> float
|
||||
// def __rdivmod__: fn(value: float, /) -> tuple[float, float]
|
||||
// def __rpow__: fn(value: _PositiveInteger, mod: None = None, /) -> float
|
||||
// def __rpow__: fn(value: _NegativeInteger, mod: None = None, /) -> complex
|
||||
// Returning `complex` for the general case gives too many false-positive errors.
|
||||
// def __rpow__: fn(value: float, mod: None = None, /) -> Any
|
||||
// def __getnewargs__: fn() -> tuple[float]
|
||||
def __trunc__: fn() -> int
|
||||
def __ceil__: fn() -> int
|
||||
def __floor__: fn() -> int
|
||||
def __round__: fn(ndigits: None?, /) -> int
|
||||
def __round__: fn(ndigits: int, /) -> float
|
||||
def __eq__: fn(value: object, /) -> bool
|
||||
def __ne__: fn(value: object, /) -> bool
|
||||
def __lt__: fn(value: float, /) -> bool
|
||||
def __le__: fn(value: float, /) -> bool
|
||||
def __gt__: fn(value: float, /) -> bool
|
||||
def __ge__: fn(value: float, /) -> bool
|
||||
def __neg__: fn() -> float
|
||||
def __pos__: fn() -> float
|
||||
def __int__: fn() -> int
|
||||
def __float__: fn() -> float
|
||||
def __abs__: fn() -> float
|
||||
def __hash__: fn() -> int
|
||||
def __bool__: fn() -> bool
|
||||
def __format__: fn(format_spec: str, /) -> str
|
||||
}
|
||||
|
||||
extend int {
|
||||
prop real: int
|
||||
prop imag: int
|
||||
prop numerator: int
|
||||
prop denominator: int
|
||||
def conjugate: fn() -> int
|
||||
def bit_length: fn() -> int
|
||||
def bit_count: fn() -> int
|
||||
// def to_bytes: fn(length: int?, byteorder: str?, *, signed: bool?) -> bytes
|
||||
|
||||
def __add__: fn(value: int, /) -> int
|
||||
def __sub__: fn(value: int, /) -> int
|
||||
def __mul__: fn(value: int, /) -> int
|
||||
def __floordiv__: fn(value: int, /) -> int
|
||||
def __truediv__: fn(value: int, /) -> float
|
||||
def __mod__: fn(value: int, /) -> int
|
||||
// def __divmod__: fn(value: int, /) -> tuple[int, int]
|
||||
def __radd__: fn(value: int, /) -> int
|
||||
def __rsub__: fn(value: int, /) -> int
|
||||
def __rmul__: fn(value: int, /) -> int
|
||||
def __rfloordiv__: fn(value: int, /) -> int
|
||||
def __rtruediv__: fn(value: int, /) -> float
|
||||
def __rmod__: fn(value: int, /) -> int
|
||||
// def __rdivmod__: fn(value: int, /) -> tuple[int, int]
|
||||
def __pow__: fn(value: int, /) -> int
|
||||
// def __pow__: fn(value: _PositiveInteger, mod: None = None, /) -> int
|
||||
// def __pow__: fn(value: _NegativeInteger, mod: None = None, /) -> float
|
||||
// positive __value -> int; negative __value -> float
|
||||
// return type must be Any as `int | float` causes too many false-positive errors
|
||||
// def __pow__: fn(value: int, mod: None = None, /) -> Any
|
||||
// def __pow__: fn(value: int, mod: int, /) -> int
|
||||
def __rpow__: fn(value: int, /) -> Any
|
||||
def __and__: fn(value: int, /) -> int
|
||||
def __or__: fn(value: int, /) -> int
|
||||
def __xor__: fn(value: int, /) -> int
|
||||
def __lshift__: fn(value: int, /) -> int
|
||||
def __rshift__: fn(value: int, /) -> int
|
||||
def __rand__: fn(value: int, /) -> int
|
||||
def __ror__: fn(value: int, /) -> int
|
||||
def __rxor__: fn(value: int, /) -> int
|
||||
def __rlshift__: fn(value: int, /) -> int
|
||||
def __rrshift__: fn(value: int, /) -> int
|
||||
def __neg__: fn() -> int
|
||||
def __pos__: fn() -> int
|
||||
def __invert__: fn() -> int
|
||||
def __trunc__: fn() -> int
|
||||
def __ceil__: fn() -> int
|
||||
def __floor__: fn() -> int
|
||||
def __round__: fn(ndigits: None?, /) -> int
|
||||
def __round__: fn(ndigits: int, /) -> int
|
||||
|
||||
// def __getnewargs__: fn() -> tuple[int]
|
||||
def __eq__: fn(value: object, /) -> bool
|
||||
def __ne__: fn(value: object, /) -> bool
|
||||
def __lt__: fn(value: int, /) -> bool
|
||||
def __le__: fn(value: int, /) -> bool
|
||||
def __gt__: fn(value: int, /) -> bool
|
||||
def __ge__: fn(value: int, /) -> bool
|
||||
def __float__: fn() -> float
|
||||
def __int__: fn() -> int
|
||||
def __abs__: fn() -> int
|
||||
def __hash__: fn() -> int
|
||||
def __bool__: fn() -> bool
|
||||
def __index__: fn() -> int
|
||||
def __format__: fn(format_spec: str, /) -> str
|
||||
}
|
||||
|
||||
extend list[T] {
|
||||
def copy: fn () -> list[T]
|
||||
def append: fn (object: T, /) -> None
|
||||
def extend: fn (iterable: list[T], /) -> None
|
||||
def pop: fn (index: int?, /) -> T
|
||||
def index: fn (value: T, start: int?, stop: int?, /) -> int
|
||||
def count: fn (value: T, /) -> int
|
||||
def insert: fn (index: int, object: T, /) -> None
|
||||
def remove: fn (value: T, /) -> None
|
||||
def sort: fn (*, reverse: bool?) -> None
|
||||
def __len__: fn () -> int
|
||||
// def __iter__: fn () -> Iterator[T]
|
||||
def __getitem__: fn (i: int, /) -> T
|
||||
def __getitem__: fn (s: slice, /) -> list[T]
|
||||
def __setitem__: fn (key: int, value: T, /) -> None
|
||||
def __setitem__: fn (key: slice, value: list[T], /) -> None
|
||||
def __delitem__: fn (key: int, /) -> None
|
||||
def __delitem__: fn (key: slice, /) -> None
|
||||
// def __add__: fn[S <: T] (value: list[S], /) -> list[T]
|
||||
def __add__: fn (value: list[T], /) -> list[T]
|
||||
def __iadd__: fn (value: list[T], /) -> list[T]
|
||||
def __mul__: fn (value: int, /) -> list[T]
|
||||
def __rmul__: fn (value: int, /) -> list[T]
|
||||
def __imul__: fn (value: int, /) -> list[T]
|
||||
def __contains__: fn (key: object, /) -> bool
|
||||
// def __reversed__: fn (self) -> Iterator[_T]
|
||||
def __gt__: fn (value: list[T], /) -> bool
|
||||
def __ge__: fn (value: list[T], /) -> bool
|
||||
def __lt__: fn (value: list[T], /) -> bool
|
||||
def __le__: fn (value: list[T], /) -> bool
|
||||
def __eq__: fn (value: object, /) -> bool
|
||||
|
||||
prop __doc__: str
|
||||
}
|
||||
41
midas/checker/builtins.py
Normal file
41
midas/checker/builtins.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import (
|
||||
BaseType,
|
||||
GenericType,
|
||||
TopType,
|
||||
TypeVar,
|
||||
UnitType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.registry import TypesRegistry
|
||||
|
||||
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
|
||||
|
||||
def define_builtins(reg: TypesRegistry):
|
||||
"""Define builtin types and operations"""
|
||||
any = reg.define_type("Any", TopType())
|
||||
unit = reg.define_type("None", UnitType())
|
||||
object = reg.define_type("object", BaseType(name="object"))
|
||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||
int = reg.define_type("int", BaseType(name="int"))
|
||||
float = reg.define_type("float", BaseType(name="float"))
|
||||
str = reg.define_type("str", BaseType(name="str"))
|
||||
slice = reg.define_type("slice", BaseType(name="slice"))
|
||||
|
||||
list = reg.define_type(
|
||||
"list",
|
||||
GenericType(
|
||||
name="list",
|
||||
params=[TypeVar(name="T", bound=None)],
|
||||
body=BaseType(name="list"),
|
||||
),
|
||||
)
|
||||
@@ -1,549 +1,35 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS
|
||||
from midas.checker.types import Function, Type, UnitType, UnknownType
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.resolver.midas import MidasResolver
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.checker.midas import MidasTyper
|
||||
from midas.checker.python import PythonTyper
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import Reporter
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
class TypeChecker:
|
||||
def __init__(self):
|
||||
self.types: TypesRegistry = TypesRegistry()
|
||||
self.reporter: Reporter = Reporter()
|
||||
|
||||
self.midas_typer = MidasTyper(self.types, self.reporter)
|
||||
self.python_typer = PythonTyper(self.types, self.reporter)
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
def import_midas(self, path: Path):
|
||||
source: str = path.read_text()
|
||||
return self.import_midas_source(source, path=str(path))
|
||||
|
||||
def import_midas_source(self, source: str, path: Optional[str] = None):
|
||||
self.midas_typer.process(source, path)
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
def type_check(self, path: Path):
|
||||
source: str = path.read_text()
|
||||
return self.type_check_source(source, path=str(path))
|
||||
|
||||
def __init__(self, locals: dict[p.Expr, int], file_path: Path):
|
||||
self.logger: logging.Logger = logging.getLogger("Checker")
|
||||
self.file_path: Path = file_path
|
||||
self.ctx: MidasResolver = MidasResolver()
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = locals
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
def type_check_source(self, source: str, path: Optional[str] = None):
|
||||
self.python_typer.process(source, path)
|
||||
|
||||
def diagnostic(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=self.file_path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.diagnostic(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def evaluate(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
return expr.accept(self)
|
||||
|
||||
def evaluate_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
stmt.accept(self)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.warning(block[i + 1].location, "Unreachable statement")
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> list[Diagnostic]:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
|
||||
Returns:
|
||||
list[Diagnostic]: the list of diagnostics (errors, warning, etc.)
|
||||
"""
|
||||
self.diagnostics = []
|
||||
for stmt in statements:
|
||||
stmt.accept(self)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
return self.diagnostics
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def parse_midas_import(self, expr: p.CallExpr) -> Optional[Path]:
|
||||
"""Parse a Midas import statement
|
||||
|
||||
The statement should be written as `midas.using("path/to/types.midas")`
|
||||
|
||||
Args:
|
||||
expr (p.CallExpr): the import call expression
|
||||
|
||||
Returns:
|
||||
Optional[Path]: the path to the imported file, or None if the expression is malformed
|
||||
"""
|
||||
match expr:
|
||||
case p.CallExpr(
|
||||
callee=p.GetExpr(
|
||||
object=p.VariableExpr(name="midas"),
|
||||
name="using",
|
||||
),
|
||||
arguments=[
|
||||
p.LiteralExpr(value=path),
|
||||
],
|
||||
):
|
||||
return Path(path)
|
||||
return None
|
||||
|
||||
def import_midas(self, path: Path) -> None:
|
||||
"""Import Midas definitions from a path
|
||||
|
||||
Args:
|
||||
path (Path): the import path
|
||||
"""
|
||||
self.logger.debug(f"Importing type definitions from {path}")
|
||||
path = (self.file_path.parent / path).resolve()
|
||||
lexer: MidasLexer = MidasLexer(path.read_text())
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
self.ctx.resolve(stmts)
|
||||
self.logger.debug(f"Midas types: {self.ctx._types}")
|
||||
self.logger.debug(f"Midas operations: {self.ctx._operations}")
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.evaluate(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = stmt.returns.accept(self)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.evaluate_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: set[Type] = set(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = list(return_types)[0]
|
||||
elif len(return_types) > 1:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {env.return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = stmt.type.accept(self)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value: Type = self.evaluate(stmt.value)
|
||||
for target in stmt.targets:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.warning(target.location, f"Unsupported assignment to {target}")
|
||||
continue
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value)
|
||||
else:
|
||||
# TODO: implement real comparison method
|
||||
if var_type != value:
|
||||
self.error(
|
||||
stmt.location,
|
||||
f"Cannot assign {value} to {name} of type {var_type}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = stmt.value.accept(self) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = stmt.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.evaluate_block(stmt.body, env)
|
||||
else_returned: bool = self.evaluate_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.evaluate(expr.left)
|
||||
right: Type = self.evaluate(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.warning(expr.location, f"Unsupported operator {expr.operator}")
|
||||
return UnknownType()
|
||||
left: Type = self.evaluate(expr.left)
|
||||
right: Type = self.evaluate(expr.right)
|
||||
|
||||
result: Optional[Type] = self.ctx.get_operation_result(left, method, right)
|
||||
if result is None:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
return result
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type: ...
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
if path := self.parse_midas_import(expr):
|
||||
self.import_midas(path)
|
||||
return UnknownType()
|
||||
callee: Type = self.evaluate(expr.callee)
|
||||
if not isinstance(callee, Function):
|
||||
self.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if arg.type != arg.argument.type:
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.ctx.get_type("bool")
|
||||
case int():
|
||||
return self.ctx.get_type("int")
|
||||
case float():
|
||||
return self.ctx.get_type("float")
|
||||
case str():
|
||||
return self.ctx.get_type("str")
|
||||
case _:
|
||||
self.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
return self.look_up_variable(expr.name, expr) or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type: ...
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> Type: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return expr.type.accept(self)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = expr.test.accept(self)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.ctx.get_type("bool"):
|
||||
self.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = expr.if_true.accept(self)
|
||||
false_type: Type = expr.if_false.accept(self)
|
||||
if true_type != false_type:
|
||||
self.error(
|
||||
expr.location,
|
||||
f"Type mismatch in ternary if branches: true={true_type} != false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
return true_type
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
return self.ctx.get_type(node.base)
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type: ...
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
"""Map call arguments to function parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
call (p.CallExpr): the call expression
|
||||
|
||||
Returns:
|
||||
list[MappedArgument]: the list of mapped arguments
|
||||
"""
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.evaluate(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
@property
|
||||
def diagnostics(self) -> list[Diagnostic]:
|
||||
return self.reporter.diagnostics
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
@@ -14,12 +13,13 @@ class DiagnosticType(StrEnum):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Diagnostic:
|
||||
file_path: Path
|
||||
file_path: Optional[str]
|
||||
location: Location
|
||||
type: DiagnosticType
|
||||
message: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
@property
|
||||
def location_str(self) -> str:
|
||||
start_loc: str = f"L{self.location.lineno}:{self.location.col_offset+1}"
|
||||
end_loc: Optional[str] = ""
|
||||
if (
|
||||
@@ -27,7 +27,16 @@ class Diagnostic:
|
||||
and self.location.end_col_offset is not None
|
||||
):
|
||||
end_loc = f"L{self.location.end_lineno}:{self.location.end_col_offset+1}"
|
||||
loc: str = (
|
||||
f"at {start_loc}" if end_loc is None else f"from {start_loc} to {end_loc}"
|
||||
)
|
||||
return f"{self.type} in {self.file_path} {loc}: {self.message}"
|
||||
|
||||
loc: str = ""
|
||||
if self.file_path is not None:
|
||||
loc += f" in {self.file_path}"
|
||||
if end_loc is None:
|
||||
loc += f" at {start_loc}"
|
||||
else:
|
||||
loc += f" from {start_loc} to {end_loc}"
|
||||
|
||||
return f"{self.type}{loc}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.location_str}: {self.message}"
|
||||
|
||||
206
midas/checker/midas.py
Normal file
206
midas/checker/midas.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.builtins import define_builtins
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
ComplexType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token
|
||||
from midas.parser.midas import MidasParser
|
||||
|
||||
|
||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("MidasTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
|
||||
self.types: TypesRegistry = types
|
||||
self._local_variables: dict[str, TypeVar] = {}
|
||||
|
||||
self._current_name: Optional[str] = None
|
||||
|
||||
define_builtins(self.types)
|
||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||
self.process(builtins_path.read_text(), str(builtins_path))
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
lexer: MidasLexer = MidasLexer(source)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser: MidasParser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
for error in parser.errors:
|
||||
self.reporter.error(error.token.get_location(), error.message)
|
||||
self.resolve(stmts)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
if name in self._local_variables:
|
||||
return self._local_variables[name]
|
||||
return self.types.get_type(name)
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
name: str = stmt.name.lexeme
|
||||
self._current_name = name
|
||||
params: list[TypeVar] = self._resolve_type_params(stmt.params)
|
||||
|
||||
type: Type = stmt.type.accept(self)
|
||||
if len(params) != 0:
|
||||
type = GenericType(name=name, params=params, body=type)
|
||||
else:
|
||||
type = AliasType(name=name, type=type)
|
||||
self.types.define_type(name, type)
|
||||
self._local_variables.clear()
|
||||
self._current_name = None
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self._resolve_type_params(stmt.params)
|
||||
base_name: str = stmt.name.lexeme
|
||||
try:
|
||||
_ = self.get_type(base_name)
|
||||
except NameError:
|
||||
self.reporter.error(stmt.name.get_location(), f"Unknown type '{base_name}'")
|
||||
|
||||
for member in stmt.members:
|
||||
member_type: Type = member.type.accept(self)
|
||||
self.types.define_member(
|
||||
base_name,
|
||||
member.name.lexeme,
|
||||
member_type,
|
||||
member.kind == m.MemberKind.METHOD,
|
||||
)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
name: str = type.name.lexeme
|
||||
try:
|
||||
return self.get_type(name)
|
||||
except NameError:
|
||||
msg: str = f"Undefined type {name}"
|
||||
if self._current_name == name:
|
||||
msg += ". Recursive types are not supported, use an extend block"
|
||||
self.reporter.error(type.name.get_location(), msg)
|
||||
return UnknownType()
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
args: list[Type] = [arg.accept(self) for arg in type.args]
|
||||
try:
|
||||
return self.types.apply_generic(type_, args)
|
||||
except Exception as e:
|
||||
self.reporter.error(type.location, f"Cannot apply generic type: {e}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||
return ComplexType(
|
||||
members={
|
||||
member.name.lexeme: member.type.accept(self) for member in type.members
|
||||
}
|
||||
)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> Type:
|
||||
return ExtensionType(
|
||||
base=type.base.accept(self),
|
||||
extension=self.visit_complex_type(type.extension),
|
||||
)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||
n_pos_args: int = len(type.pos_args)
|
||||
n_args: int = len(type.args)
|
||||
|
||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||
return Function.Argument(
|
||||
pos=i,
|
||||
name=arg.name.lexeme if arg.name is not None else str(i),
|
||||
type=arg.type.accept(self),
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
return Function(
|
||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
||||
kw_args=[
|
||||
process_arg(arg, i + n_pos_args + n_args)
|
||||
for i, arg in enumerate(type.kw_args)
|
||||
],
|
||||
returns=type.returns.accept(self),
|
||||
)
|
||||
|
||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||
vars: list[TypeVar] = []
|
||||
for param in params:
|
||||
name: str = param.name.lexeme
|
||||
bound: Optional[Type] = None
|
||||
if param.bound is not None:
|
||||
bound = param.bound.accept(self)
|
||||
var = TypeVar(name=name, bound=bound)
|
||||
self._local_variables[name] = var
|
||||
vars.append(var)
|
||||
return vars
|
||||
@@ -29,3 +29,10 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||
# ast.In: "__in__",
|
||||
# ast.NotIn: "__notin__",
|
||||
}
|
||||
|
||||
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||
ast.Invert: "__invert__",
|
||||
# ast.Not: "",
|
||||
ast.UAdd: "__pos__",
|
||||
ast.USub: "__neg__",
|
||||
}
|
||||
|
||||
859
midas/checker/python.py
Normal file
859
midas/checker/python.py
Normal file
@@ -0,0 +1,859 @@
|
||||
import ast
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter, Reporter
|
||||
from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
Function,
|
||||
OverloadedFunction,
|
||||
Type,
|
||||
UnitType,
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.parser.python import PythonParser
|
||||
|
||||
TypedExpr = tuple[p.Expr, Type]
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadCandidate:
|
||||
function: Function
|
||||
mapped: list[MappedArgument]
|
||||
|
||||
|
||||
class PythonTyper(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
p.MidasType.Visitor[Type],
|
||||
):
|
||||
"""A type checker which can use custom type definitions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
types: TypesRegistry,
|
||||
reporter: Reporter,
|
||||
):
|
||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
self.types: TypesRegistry = types
|
||||
self.global_env: Environment = Environment()
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||
|
||||
def process(self, source: str, path: Optional[str]):
|
||||
self.reporter = self.reporter.for_file(path)
|
||||
|
||||
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
|
||||
self.env = self.global_env
|
||||
self.locals = resolver.locals
|
||||
self.judgements = []
|
||||
|
||||
self.check(stmts)
|
||||
|
||||
def type_of(self, expr: p.Expr) -> Type:
|
||||
"""Evaluate the type of an expression
|
||||
|
||||
Args:
|
||||
expr (p.Expr): the expression to evaluate
|
||||
|
||||
Returns:
|
||||
Type: the type of the given expression
|
||||
"""
|
||||
type: Type = expr.accept(self)
|
||||
self.judgements.append((expr, type))
|
||||
return type
|
||||
|
||||
def resolve_type_expr(self, expr: p.MidasType) -> Type:
|
||||
return expr.accept(self)
|
||||
|
||||
def process_stmt(self, stmt: p.Stmt) -> None:
|
||||
stmt.accept(self)
|
||||
|
||||
def process_block(self, block: list[p.Stmt], env: Environment) -> bool:
|
||||
"""Evaluate a sequence of statements
|
||||
|
||||
Args:
|
||||
block (list[p.Stmt]): the statements to evaluate
|
||||
env (Environment): the environment in which to evaluate
|
||||
|
||||
Returns:
|
||||
bool: whether a return statement is present in the block
|
||||
"""
|
||||
previous_env: Environment = self.env
|
||||
self.env = env
|
||||
returned: bool = False
|
||||
for i, stmt in enumerate(block):
|
||||
try:
|
||||
self.process_stmt(stmt)
|
||||
except ReturnException:
|
||||
returned = True
|
||||
if i < len(block) - 1:
|
||||
self.reporter.warning(
|
||||
block[i + 1].location, "Unreachable statement"
|
||||
)
|
||||
break
|
||||
self.env = previous_env
|
||||
return returned
|
||||
|
||||
def check(self, statements: list[p.Stmt]) -> None:
|
||||
"""Type check a sequence of statements and returns diagnostics
|
||||
|
||||
Args:
|
||||
statements (list[p.Stmt]): the statements to evaluate and check
|
||||
"""
|
||||
for stmt in statements:
|
||||
self.process_stmt(stmt)
|
||||
|
||||
self.logger.debug(f"Final environment: {self.env.flat_dict()}")
|
||||
|
||||
def look_up_variable(self, name: str, expr: p.Expr) -> Optional[Type]:
|
||||
"""Look up a variable in the environment it was declared
|
||||
|
||||
Args:
|
||||
name (str): the name of the variable
|
||||
expr (p.Expr): the variable expression, used to lookup the scope distance
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the type of the variable, or None if it was not found
|
||||
"""
|
||||
distance: Optional[int] = self.locals.get(expr)
|
||||
if distance is not None:
|
||||
return self.env.get_at(distance, name)
|
||||
return self.global_env.get(name)
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
return self.types.is_subtype(type1, type2)
|
||||
|
||||
def visit_expression_stmt(self, stmt: p.ExpressionStmt) -> None:
|
||||
self.type_of(stmt.expr)
|
||||
|
||||
def visit_function(self, stmt: p.Function) -> None:
|
||||
env: Environment = Environment(self.env)
|
||||
pos_args: list[Function.Argument] = []
|
||||
args: list[Function.Argument] = []
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is not None:
|
||||
return self.resolve_type_expr(arg.type)
|
||||
if arg.default is not None:
|
||||
return self.type_of(arg.default)
|
||||
return UnknownType()
|
||||
|
||||
pos: int = 0
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.args:
|
||||
args.append(
|
||||
Function.Argument(
|
||||
pos=pos,
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
for arg in stmt.kwonlyargs:
|
||||
kw_args.append(
|
||||
Function.Argument(
|
||||
pos=pos, # not relevant
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
pos += 1
|
||||
|
||||
for arg in pos_args + args + kw_args:
|
||||
env.define(arg.name, arg.type)
|
||||
|
||||
returns_hint: Optional[Type] = None
|
||||
if stmt.returns is not None:
|
||||
returns_hint = self.resolve_type_expr(stmt.returns)
|
||||
# Early define to handle simple fully-typed recursion
|
||||
inside_function: Function = Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns_hint,
|
||||
)
|
||||
self.env.define(stmt.name, inside_function)
|
||||
|
||||
returned: bool = self.process_block(stmt.body, env)
|
||||
inferred_return: Type = UnknownType()
|
||||
if not returned:
|
||||
env.return_types.append(UnitType())
|
||||
return_types: list[Type] = self.types.reduce_types(env.return_types)
|
||||
if len(return_types) == 1:
|
||||
inferred_return = return_types[0]
|
||||
elif len(return_types) > 1:
|
||||
self.reporter.error(
|
||||
stmt.location,
|
||||
f"Mixed return types: {return_types}",
|
||||
)
|
||||
|
||||
returns: Type = UnknownType()
|
||||
if returns_hint is not None:
|
||||
assert stmt.returns is not None
|
||||
returns = returns_hint
|
||||
if returns != inferred_return:
|
||||
self.reporter.error(
|
||||
stmt.returns.location,
|
||||
f"Return type mismatch, annotated {returns} but returns {inferred_return}",
|
||||
)
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
)
|
||||
self.env.define(stmt.name, function)
|
||||
|
||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||
# TODO check not yet defined locally
|
||||
type: Type = self.resolve_type_expr(stmt.type)
|
||||
self.env.define(stmt.name, type)
|
||||
|
||||
def visit_assign_stmt(self, stmt: p.AssignStmt) -> None:
|
||||
value_type: Type = self.type_of(stmt.value)
|
||||
for target in stmt.targets:
|
||||
self._assign(stmt.location, target, value_type)
|
||||
|
||||
def _assign(self, location: Location, target: p.Expr, value_type: Type):
|
||||
match target:
|
||||
case p.VariableExpr():
|
||||
self._assign_var(location, target, value_type)
|
||||
|
||||
case p.GetExpr(object=object, name=name):
|
||||
self._assign_attr(location, object, name, value_type)
|
||||
|
||||
case _:
|
||||
if not isinstance(target, p.VariableExpr):
|
||||
self.logger.warning(f"Unsupported assignment to {target}")
|
||||
self.reporter.warning(
|
||||
target.location, f"Unsupported assignment to {target}"
|
||||
)
|
||||
|
||||
def _assign_var(self, location: Location, target: p.VariableExpr, value_type: Type):
|
||||
name: str = target.name
|
||||
var_type: Optional[Type] = self.look_up_variable(name, target)
|
||||
|
||||
if var_type is None:
|
||||
self.env.define(name, value_type)
|
||||
else:
|
||||
# S <: T
|
||||
# Γ, x: T v: S
|
||||
# x = v
|
||||
if not self.is_subtype(value_type, var_type):
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Cannot assign {value_type} to variable '{name}' of type {var_type}",
|
||||
)
|
||||
|
||||
def _assign_attr(
|
||||
self, location: Location, object: p.Expr, name: str, value_type: Type
|
||||
):
|
||||
object_type: Type = self.type_of(object)
|
||||
member: Optional[Type] = self.types.lookup_member(object_type, name)
|
||||
if member is None:
|
||||
self.reporter.error(location, f"Unknown member '{name}' of {object_type}")
|
||||
return
|
||||
self.logger.debug(f"Member '{name}' of {object_type} has type {member}")
|
||||
if not self.is_subtype(value_type, member):
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
||||
)
|
||||
|
||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
||||
self.env.return_types.append(type)
|
||||
raise ReturnException()
|
||||
|
||||
def visit_if_stmt(self, stmt: p.IfStmt) -> None:
|
||||
# Not evaluated in sub-environment because assignments in the test leak out of the if
|
||||
# For example:
|
||||
# if (m := 1 + 1) < 2:
|
||||
# ...
|
||||
# print(m) # <- m is still defined
|
||||
test_type: Type = self.type_of(stmt.test)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.types.get_type("bool"):
|
||||
self.reporter.error(
|
||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
env: Environment = Environment(self.env)
|
||||
body_returned: bool = self.process_block(stmt.body, env)
|
||||
else_returned: bool = self.process_block(stmt.orelse, env)
|
||||
self.env.return_types.extend(env.return_types)
|
||||
if body_returned and else_returned:
|
||||
raise ReturnException()
|
||||
|
||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
|
||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||
|
||||
def _visit_binary_expr(
|
||||
self, location: Location, left_expr: p.Expr, right_expr: p.Expr, method: str
|
||||
) -> Type:
|
||||
left: Type = self.type_of(left_expr)
|
||||
right: Type = self.type_of(right_expr)
|
||||
|
||||
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Undefined operation {method} between {left} and {right}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._get_call_result(location, operation, [(right_expr, right)], {})
|
||||
|
||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
||||
if method is None:
|
||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unsupported operator {expr.operator}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
operand: Type = self.type_of(expr.right)
|
||||
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Undefined operation {method} for {operand}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
return self._get_call_result(
|
||||
expr.location, operation, [(expr.right, operand)], {}
|
||||
)
|
||||
|
||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
positional: list[TypedExpr] = [
|
||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||
]
|
||||
keywords: dict[str, TypedExpr] = {
|
||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return self._get_call_result(
|
||||
location=expr.location,
|
||||
callee=callee,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||
if member is None:
|
||||
self.reporter.error(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
self.logger.debug(f"Member '{expr.name}' of {object} has type {member}")
|
||||
return member
|
||||
|
||||
def visit_literal_expr(self, expr: p.LiteralExpr) -> Type:
|
||||
match expr.value:
|
||||
case bool(): # Must be before int
|
||||
return self.types.get_type("bool")
|
||||
case int():
|
||||
return self.types.get_type("int")
|
||||
case float():
|
||||
return self.types.get_type("float")
|
||||
case str():
|
||||
return self.types.get_type("str")
|
||||
case _:
|
||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||
return UnknownType()
|
||||
|
||||
def visit_variable_expr(self, expr: p.VariableExpr) -> Type:
|
||||
type: Optional[Type] = self.look_up_variable(expr.name, expr)
|
||||
if type is None:
|
||||
self.logger.debug(f"Unknown variable {expr.name} in {self.env.flat_dict()}")
|
||||
self.reporter.warning(expr.location, "Unknown variable")
|
||||
return type or UnknownType()
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> Type:
|
||||
left: Type = self.type_of(expr.left)
|
||||
right: Type = self.type_of(expr.right)
|
||||
|
||||
if self.is_subtype(left, right):
|
||||
return right
|
||||
if self.is_subtype(right, left):
|
||||
return left
|
||||
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Incompatible operand types, {left=} and {right=}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||
return self.resolve_type_expr(expr.type)
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||
test_type: Type = self.type_of(expr.test)
|
||||
|
||||
# TODO Allow subtypes or any type
|
||||
if test_type != self.types.get_type("bool"):
|
||||
self.reporter.error(
|
||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||
)
|
||||
|
||||
true_type: Type = self.type_of(expr.if_true)
|
||||
false_type: Type = self.type_of(expr.if_false)
|
||||
if self.is_subtype(true_type, false_type):
|
||||
return false_type
|
||||
if self.is_subtype(false_type, true_type):
|
||||
return true_type
|
||||
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Incompatible types in ternary if branches: true={true_type} and false={false_type}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> Type:
|
||||
list_type: Type = self.types.get_type("list")
|
||||
item_types: list[Type] = [self.type_of(item) for item in expr.items]
|
||||
item_types = self.types.reduce_types(item_types)
|
||||
|
||||
if len(item_types) == 0:
|
||||
return list_type
|
||||
|
||||
if len(item_types) == 1:
|
||||
item_type: Type = item_types[0]
|
||||
return self.types.apply_generic(list_type, [item_type])
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Heterogeneous list items: {item_types}",
|
||||
)
|
||||
return self.types.apply_generic(list_type, [UnknownType()])
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
||||
object: Type = self.type_of(expr.object)
|
||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||
if operation is None:
|
||||
self.reporter.error(
|
||||
expr.location,
|
||||
f"Undefined method __getitem__ on {object}",
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
index: Type = self.type_of(expr.index)
|
||||
return self._get_call_result(
|
||||
expr.location, operation, [(expr.index, index)], {}
|
||||
)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
||||
return self.types.get_type("slice")
|
||||
|
||||
def visit_base_type(self, node: p.BaseType) -> Type:
|
||||
base: Type
|
||||
try:
|
||||
base = self.types.get_type(node.base)
|
||||
except NameError:
|
||||
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
|
||||
return UnknownType()
|
||||
|
||||
if node.param is not None:
|
||||
param: Type = self.resolve_type_expr(node.param)
|
||||
return self.types.apply_generic(base, [param])
|
||||
return base
|
||||
|
||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
|
||||
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type:
|
||||
self.reporter.warning(node.location, "FrameColumn not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type:
|
||||
self.reporter.warning(node.location, "FrameType not yet supported")
|
||||
return UnknownType()
|
||||
|
||||
def _get_call_result(
|
||||
self,
|
||||
location: Location,
|
||||
callee: Type,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
"""Get the result type of a function call
|
||||
|
||||
If the function has overloads, the function will try to resolve the
|
||||
appropriate signature.
|
||||
Argument types are matched to the defined parameters.
|
||||
The function doesn't take the raw expression as a parameter to accomodate
|
||||
for desugared calls such as for operators.
|
||||
|
||||
Args:
|
||||
location (Location): the call location
|
||||
callee (Type): the called function
|
||||
positional (list[TypedExpr]): the list positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
|
||||
Returns:
|
||||
Type: the return type of the call, or `UnknownType` if either
|
||||
the call is invalid or no overload matched the arguments uniquely
|
||||
"""
|
||||
match callee:
|
||||
case Function() as function:
|
||||
valid: bool
|
||||
mapped: list[MappedArgument]
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function, location, positional, keywords
|
||||
)
|
||||
valid = valid and self._are_arguments_valid(mapped)
|
||||
if not valid:
|
||||
return UnknownType()
|
||||
return function.returns
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
function = self._match_overload(
|
||||
overloads, location, positional, keywords
|
||||
)
|
||||
if function is None:
|
||||
return UnknownType()
|
||||
return function.returns
|
||||
case _:
|
||||
self.reporter.error(location, f"{callee} is not callable")
|
||||
return UnknownType()
|
||||
|
||||
def _are_arguments_valid(
|
||||
self,
|
||||
arguments: list[MappedArgument],
|
||||
report_errors: bool = True,
|
||||
) -> bool:
|
||||
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||
|
||||
Args:
|
||||
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||
"""
|
||||
valid: bool = True
|
||||
for arg in arguments:
|
||||
if not self.is_subtype(arg.type, arg.argument.type):
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
valid = False
|
||||
return valid
|
||||
|
||||
def _match_overload(
|
||||
self,
|
||||
overloads: list[Type],
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Optional[Function]:
|
||||
"""Try and resolve the appropriate overload for the given arguments
|
||||
|
||||
Args:
|
||||
overloads (list[Type]): the list of possible overloads
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||
|
||||
Returns:
|
||||
Optional[Function]: the resolved function signature if it can be
|
||||
determined unambigously, or `None`.
|
||||
"""
|
||||
candidates: list[OverloadCandidate] = []
|
||||
for overload in overloads:
|
||||
function: Type = unfold_type(overload)
|
||||
if not isinstance(function, Function):
|
||||
self.logger.error(
|
||||
f"Overload is not a function: {overload} is {function}"
|
||||
)
|
||||
continue
|
||||
valid, mapped = self.map_call_arguments(
|
||||
function=function,
|
||||
location=location,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
report_errors=False,
|
||||
)
|
||||
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||
candidates.append(
|
||||
OverloadCandidate(
|
||||
function=function,
|
||||
mapped=mapped,
|
||||
)
|
||||
)
|
||||
|
||||
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||
kw_types: str = ", ".join(
|
||||
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||
)
|
||||
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||
|
||||
n_candidates: int = len(candidates)
|
||||
|
||||
# Exactly 1 match -> return it
|
||||
if n_candidates == 1:
|
||||
return candidates[0].function
|
||||
|
||||
# No match -> invalid call
|
||||
if n_candidates == 0:
|
||||
overloads_str: str = ", ".join(map(str, overloads))
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"No matching overload in [{overloads_str}] {for_args}",
|
||||
)
|
||||
return None
|
||||
|
||||
# Multiple matches -> see if one <: all others (more specific)
|
||||
for i1, c1 in enumerate(candidates):
|
||||
mapped1: list[MappedArgument] = c1.mapped
|
||||
best_match: bool = True
|
||||
for i2, c2 in enumerate(candidates):
|
||||
if i1 == i2:
|
||||
continue
|
||||
mapped2: list[MappedArgument] = c2.mapped
|
||||
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||
best_match = False
|
||||
break
|
||||
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||
if best_match:
|
||||
return c1.function
|
||||
|
||||
candidates_str: str = ", ".join(
|
||||
str(candidate.function) for candidate in candidates
|
||||
)
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||
)
|
||||
return None
|
||||
|
||||
def map_call_arguments(
|
||||
self,
|
||||
function: Function,
|
||||
location: Location,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
report_errors: bool = True,
|
||||
) -> tuple[bool, list[MappedArgument]]:
|
||||
"""Map call arguments to a function's parameters as defined in its signature
|
||||
|
||||
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||
with the arguments passed at the call site
|
||||
|
||||
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||
unless `report_errors` is set to `False`
|
||||
|
||||
Args:
|
||||
function (Function): the function definition
|
||||
location (Location): the call location
|
||||
positional (list[TypedExpr]): the list of positional arguments
|
||||
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||
the call is valid and the list of mapped arguments
|
||||
"""
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: list[str] = [
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
]
|
||||
required_keyword: list[str] = [
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
]
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
valid_call: bool = True
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
arg[0].location, "Too many positional arguments"
|
||||
)
|
||||
valid_call = False
|
||||
break
|
||||
name: str = param.name
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if report_errors:
|
||||
if name in set_args:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.reporter.error(
|
||||
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||
)
|
||||
valid_call = False
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
if name in required_positional:
|
||||
required_positional.remove(name)
|
||||
if name in required_keyword:
|
||||
required_keyword.remove(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
def join_args(args: list[str]) -> str:
|
||||
args = list(map(lambda a: f"'{a}'", args))
|
||||
if len(args) == 0:
|
||||
return ""
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||
|
||||
if len(required_positional) != 0:
|
||||
plural: str = "" if len(required_positional) == 1 else "s"
|
||||
args: str = join_args(required_positional)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required positional argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||
args: str = join_args(required_keyword)
|
||||
if report_errors:
|
||||
self.reporter.error(
|
||||
location,
|
||||
f"Missing required keyword argument{plural}: {args}",
|
||||
)
|
||||
valid_call = False
|
||||
|
||||
return valid_call, mapped
|
||||
|
||||
def _are_mapped_subtypes(
|
||||
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||
) -> bool:
|
||||
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||
|
||||
This function checks whether the argument mappings `mapped1` are subtypes
|
||||
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||
|
||||
This is used to check whether a given overload is
|
||||
a more specific function/ a subtype of another.
|
||||
|
||||
Args:
|
||||
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||
|
||||
Returns:
|
||||
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||
"""
|
||||
by_expr: dict[p.Expr, Type] = {}
|
||||
for arg in mapped1:
|
||||
by_expr[arg.expr] = arg.argument.type
|
||||
|
||||
for arg in mapped2:
|
||||
type2: Type = arg.argument.type
|
||||
type1: Type = by_expr[arg.expr]
|
||||
if not self.is_subtype(type1, type2):
|
||||
return False
|
||||
return True
|
||||
347
midas/checker/registry.py
Normal file
347
midas/checker/registry.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||
from midas.checker.types import (
|
||||
AliasType,
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ComplexType,
|
||||
ExtensionType,
|
||||
Function,
|
||||
GenericType,
|
||||
OverloadedFunction,
|
||||
TopType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
substitute_typevars,
|
||||
)
|
||||
|
||||
|
||||
class TypesRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||
self._types: dict[str, Type] = {}
|
||||
self._members: dict[str, dict[str, Type]] = {}
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
if name in self._types:
|
||||
return self._types[name]
|
||||
raise NameError(f"Undefined type {name}")
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_member(
|
||||
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
||||
):
|
||||
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
||||
if member_name in members:
|
||||
if not is_method:
|
||||
self.logger.error(
|
||||
f"Member '{member_name}' already defined for type {type_name}"
|
||||
)
|
||||
return
|
||||
current: Type = members[member_name]
|
||||
combined: Type
|
||||
match current:
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||
case _:
|
||||
combined = OverloadedFunction(overloads=[current, member_type])
|
||||
members[member_name] = combined
|
||||
|
||||
else:
|
||||
members[member_name] = member_type
|
||||
|
||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||
"""Check whether `type1` is a subtype of `type2`
|
||||
|
||||
For more details on the rules checked here, see TAPL Chap. 15-16-17
|
||||
|
||||
Args:
|
||||
type1 (Type): the potential subtype
|
||||
type2 (Type): the potential supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `type1` is a subtype of `type2`
|
||||
"""
|
||||
|
||||
if type1 == type2:
|
||||
return True
|
||||
|
||||
match (type1, type2):
|
||||
case (_, TopType()):
|
||||
return True
|
||||
|
||||
case (AliasType(type=base1), _):
|
||||
return self.is_subtype(base1, type2)
|
||||
|
||||
case (BaseType(name=name1), BaseType(name=name2)):
|
||||
return name1 in BUILTIN_SUBTYPES.get(name2, set())
|
||||
|
||||
case (ComplexType(properties=props1), ComplexType(properties=props2)):
|
||||
for k, t in props2.items():
|
||||
if k not in props1:
|
||||
return False
|
||||
if not self.is_subtype(props1[k], t):
|
||||
return False
|
||||
return True
|
||||
|
||||
case (Function(), Function()):
|
||||
return self.is_func_subtype(type1, type2)
|
||||
|
||||
case (TypeVar(bound=bound), _):
|
||||
if bound is None:
|
||||
return False
|
||||
return self.is_subtype(bound, type2)
|
||||
|
||||
return False
|
||||
|
||||
# TODO: verify the logic in here
|
||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||
"""Check whether a function is a subtype of another
|
||||
|
||||
Args:
|
||||
func1 (Function): the potential function subtype
|
||||
func2 (Function): the potential function supertype
|
||||
|
||||
Returns:
|
||||
bool: whether `func1` is a subtype of `func2`
|
||||
"""
|
||||
if not self.is_subtype(func1.returns, func2.returns):
|
||||
return False
|
||||
|
||||
pos1: list[Function.Argument] = func1.pos_args
|
||||
mixed1: list[Function.Argument] = func1.args
|
||||
kw1: dict[str, Function.Argument] = {a.name: a for a in func1.kw_args}
|
||||
pos2: list[Function.Argument] = func2.pos_args
|
||||
mixed2: list[Function.Argument] = func2.args
|
||||
kw2: dict[str, Function.Argument] = {a.name: a for a in func2.kw_args}
|
||||
|
||||
mixed_by_pos: dict[int, Function.Argument] = {arg.pos: arg for arg in mixed2}
|
||||
mixed_by_name: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2}
|
||||
|
||||
def is_arg_subtype(sub: Function.Argument, sup: Function.Argument) -> bool:
|
||||
if not self.is_subtype(sub.type, sup.type):
|
||||
return False
|
||||
if not sup.required and sub.required:
|
||||
return False
|
||||
return True
|
||||
|
||||
for arg1 in pos1:
|
||||
arg2: Function.Argument
|
||||
if arg1.pos < len(pos2):
|
||||
arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
arg2 = mixed_by_pos[arg1.pos]
|
||||
elif not arg1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
return False
|
||||
|
||||
for name, arg1 in kw1.items():
|
||||
arg2: Function.Argument
|
||||
if name in kw2:
|
||||
arg2 = kw2[name]
|
||||
elif name in mixed_by_name:
|
||||
arg2 = mixed_by_name[name]
|
||||
elif not arg1.required:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
if not is_arg_subtype(arg2, arg1):
|
||||
return False
|
||||
|
||||
for arg1 in mixed1:
|
||||
pos_arg2: Optional[Function.Argument] = None
|
||||
kw_arg2: Optional[Function.Argument] = None
|
||||
if arg1.name in kw2:
|
||||
kw_arg2 = kw2[arg1.name]
|
||||
elif arg1.name in mixed_by_name:
|
||||
kw_arg2 = mixed_by_name[arg1.name]
|
||||
if arg1.pos < len(pos2):
|
||||
pos_arg2 = pos2[arg1.pos]
|
||||
elif arg1.pos in mixed_by_pos:
|
||||
pos_arg2 = mixed_by_pos[arg1.pos]
|
||||
|
||||
# No match in func2 and arg is required
|
||||
if pos_arg2 is None and kw_arg2 is None and arg1.required:
|
||||
return False
|
||||
|
||||
# Matching keyword argument
|
||||
if kw_arg2 is not None and not is_arg_subtype(kw_arg2, arg1):
|
||||
return False
|
||||
|
||||
# Matching positional argument
|
||||
if pos_arg2 is not None and not is_arg_subtype(pos_arg2, arg1):
|
||||
return False
|
||||
|
||||
mixed_positions: set[int] = {a.pos for a in mixed1}
|
||||
mixed_names: set[str] = {a.name for a in mixed1}
|
||||
for arg2 in pos2:
|
||||
if not arg2.required:
|
||||
continue
|
||||
if arg2.pos >= len(pos1) and arg2.pos not in mixed_positions:
|
||||
return False
|
||||
|
||||
for name, arg2 in kw2.items():
|
||||
if not arg2.required:
|
||||
continue
|
||||
if name not in kw1 and name not in mixed_names:
|
||||
return False
|
||||
|
||||
for arg2 in mixed2:
|
||||
if arg2.required:
|
||||
continue
|
||||
pos_match: bool = arg2.pos < len(pos1) or arg2.pos in mixed_positions
|
||||
kw_match: bool = arg2.name in kw1 or arg2.name in mixed_names
|
||||
if not pos_match or not kw_match:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply_generic(self, type: Type, args: list[Type]) -> Type:
|
||||
match type:
|
||||
case AliasType(name=name, type=base):
|
||||
return AliasType(name=name, type=self.apply_generic(base, args))
|
||||
|
||||
case GenericType(name=name, params=type_vars, body=body):
|
||||
n_args: int = len(args)
|
||||
n_type_vars: int = len(type_vars)
|
||||
if n_args < n_type_vars:
|
||||
raise ValueError(
|
||||
f"Missing type arguments, expected {n_type_vars} but only {n_args} provided"
|
||||
)
|
||||
if n_args > n_type_vars:
|
||||
raise ValueError(
|
||||
f"Too many type arguments, expected {n_type_vars} but {n_args} provided"
|
||||
)
|
||||
substitutions: dict[str, Type] = {}
|
||||
for arg, type_var in zip(args, type_vars):
|
||||
if type_var.bound is not None and not self.is_subtype(
|
||||
arg, type_var.bound
|
||||
):
|
||||
raise ValueError(
|
||||
f"Type argument {arg} is not a subtype of {type_var.bound}"
|
||||
)
|
||||
substitutions[type_var.name] = arg
|
||||
return AppliedType(
|
||||
name=name,
|
||||
args=args,
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"{type} is not a generic type")
|
||||
|
||||
def reduce_types(self, types: list[Type]) -> list[Type]:
|
||||
"""Reduce a list of types to remove subtypes and only keep the highest types
|
||||
|
||||
Args:
|
||||
types (list[Type]): the types to reduce
|
||||
|
||||
Returns:
|
||||
list[Type]: the reduced list of types
|
||||
"""
|
||||
|
||||
reduced: bool = True
|
||||
keep: list[int] = list(range(len(types)))
|
||||
while reduced:
|
||||
reduced = False
|
||||
for i, i1 in enumerate(keep):
|
||||
type1: Type = types[i1]
|
||||
for i2 in keep[i + 1 :]:
|
||||
type2 = types[i2]
|
||||
if self.is_subtype(type1, type2):
|
||||
keep.remove(i1)
|
||||
elif self.is_subtype(type2, type1):
|
||||
keep.remove(i2)
|
||||
else:
|
||||
continue
|
||||
reduced = True
|
||||
break
|
||||
return [types[i] for i in keep]
|
||||
|
||||
def lookup_member(self, type: Type, member_name: str) -> Optional[Type]:
|
||||
match type:
|
||||
case BaseType(name=name):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name]
|
||||
return None
|
||||
|
||||
case AliasType(name=name, type=base):
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
return self._members[name][member_name]
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case AppliedType(name=name, body=body, args=args):
|
||||
generic: Type = self.get_type(name)
|
||||
|
||||
if not isinstance(generic, GenericType):
|
||||
raise ValueError("AppliedType not derived from a GenericType")
|
||||
|
||||
substitutions = {
|
||||
type_var.name: arg for arg, type_var in zip(args, generic.params)
|
||||
}
|
||||
if name in self._members:
|
||||
if member_name in self._members[name]:
|
||||
member_type: Type = self._members[name][member_name]
|
||||
return substitute_typevars(member_type, substitutions)
|
||||
|
||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||
if member_type2 is not None:
|
||||
member_type2 = substitute_typevars(member_type2, substitutions)
|
||||
return member_type2
|
||||
|
||||
case ComplexType(members=members):
|
||||
if member_name in members:
|
||||
return members[member_name]
|
||||
self.logger.debug(f"No member '{member_name}' in {type}")
|
||||
return None
|
||||
|
||||
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
||||
if member_name in members:
|
||||
return members[member_name]
|
||||
self.logger.debug(
|
||||
f"No member '{member_name}' on {type}, looking up in base"
|
||||
)
|
||||
return self.lookup_member(base, member_name)
|
||||
|
||||
case UnknownType():
|
||||
return UnknownType()
|
||||
|
||||
case _:
|
||||
self.logger.debug(f"Can't get member on {type}")
|
||||
return None
|
||||
63
midas/checker/reporter.py
Normal file
63
midas/checker/reporter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
|
||||
|
||||
class Reporter:
|
||||
def __init__(self):
|
||||
self.diagnostics: list[Diagnostic] = []
|
||||
|
||||
def report(
|
||||
self,
|
||||
path: Optional[str],
|
||||
type: DiagnosticType,
|
||||
location: Location,
|
||||
message: str,
|
||||
):
|
||||
self.diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=path,
|
||||
location=location,
|
||||
type=type,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||
return FileReporter(self, path)
|
||||
|
||||
|
||||
class FileReporter:
|
||||
def __init__(self, base_reporter: Reporter, path: Optional[str]) -> None:
|
||||
self.base_reporter: Reporter = base_reporter
|
||||
self.path: Optional[str] = path
|
||||
|
||||
def for_file(self, path: Optional[str]) -> FileReporter:
|
||||
return FileReporter(self.base_reporter, path)
|
||||
|
||||
def report(self, type: DiagnosticType, location: Location, message: str):
|
||||
self.base_reporter.report(self.path, type, location, message)
|
||||
|
||||
def error(self, location: Location, message: str):
|
||||
self.report(
|
||||
type=DiagnosticType.ERROR,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def warning(self, location: Location, message: str):
|
||||
self.report(
|
||||
type=DiagnosticType.WARNING,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
|
||||
def info(self, location: Location, message: str):
|
||||
self.report(
|
||||
type=DiagnosticType.INFO,
|
||||
location=location,
|
||||
message=message,
|
||||
)
|
||||
@@ -13,7 +13,7 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
|
||||
def __init__(self):
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
self.scopes: list[dict[str, bool]] = []
|
||||
self.scopes: list[dict[str, bool]] = [{}]
|
||||
|
||||
def resolve(self, *objects: p.Stmt | p.Expr) -> None:
|
||||
"""Resolve the given statements or expressions"""
|
||||
@@ -77,6 +77,12 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
self.locals[expr] = i
|
||||
return
|
||||
|
||||
def is_defined(self, name: str) -> bool:
|
||||
for scope in self.scopes:
|
||||
if name in scope:
|
||||
return True
|
||||
return False
|
||||
|
||||
def resolve_function(self, function: p.Function) -> None:
|
||||
"""Resolve a function definition
|
||||
|
||||
@@ -112,8 +118,13 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
for target in stmt.targets:
|
||||
match target:
|
||||
case p.VariableExpr(name=name):
|
||||
self.resolve_local(target, name)
|
||||
# TODO: declare if not found
|
||||
if not self.is_defined(name):
|
||||
self.declare(name)
|
||||
self.define(name)
|
||||
target.accept(self)
|
||||
|
||||
case p.GetExpr():
|
||||
target.accept(self)
|
||||
case _:
|
||||
raise Exception(f"Unsupported assignment to {target}")
|
||||
|
||||
@@ -174,10 +185,6 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
self.resolve(expr.left)
|
||||
self.resolve(expr.right)
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None:
|
||||
self.resolve(expr.value)
|
||||
self.resolve(expr.object)
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None:
|
||||
self.resolve(expr.expr)
|
||||
|
||||
@@ -185,3 +192,19 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
||||
self.resolve(expr.test)
|
||||
self.resolve(expr.if_true)
|
||||
self.resolve(expr.if_false)
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
for item in expr.items:
|
||||
self.resolve(item)
|
||||
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||
self.resolve(expr.object)
|
||||
self.resolve(expr.index)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
|
||||
if expr.lower is not None:
|
||||
self.resolve(expr.lower)
|
||||
if expr.upper is not None:
|
||||
self.resolve(expr.upper)
|
||||
if expr.step is not None:
|
||||
self.resolve(expr.step)
|
||||
@@ -1,54 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TopType:
|
||||
def __str__(self) -> str:
|
||||
return "Any"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class BaseType:
|
||||
name: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AliasType:
|
||||
name: str
|
||||
type: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnknownType:
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return "<Unknown>"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnitType:
|
||||
pass
|
||||
def __str__(self) -> str:
|
||||
return "None"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
name: str
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
returns: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
args: list[str] = []
|
||||
if len(self.pos_args) != 0:
|
||||
args += list(map(str, self.pos_args))
|
||||
if len(self.args) + len(self.kw_args) != 0:
|
||||
args.append("/")
|
||||
|
||||
if len(self.args) != 0:
|
||||
args += list(map(str, self.args))
|
||||
|
||||
if len(self.kw_args) != 0:
|
||||
if len(args) != 0:
|
||||
args.append("*")
|
||||
args += list(map(str, self.kw_args))
|
||||
|
||||
return f"({', '.join(args)}) -> {self.returns}"
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Argument:
|
||||
pos: int
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
opt: str = "" if self.required else "?"
|
||||
return f"{self.name}: {self.type}{opt}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OverloadedFunction:
|
||||
overloads: list[Type]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "<overloaded function>"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ComplexType:
|
||||
properties: dict[str, Type]
|
||||
members: dict[str, Type]
|
||||
|
||||
def __str__(self) -> str:
|
||||
props: list[str] = [f"{name}: {type}" for name, type in self.members.items()]
|
||||
return f"{{{', '.join(props)}}}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class UnionType:
|
||||
alternatives: list[Type]
|
||||
class ExtensionType:
|
||||
base: Type
|
||||
extension: ComplexType
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.base} & {self.extension}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TypeVar:
|
||||
name: str
|
||||
bound: Optional[Type]
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.bound is not None:
|
||||
return f"{self.name} <: {self.bound}"
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class GenericType:
|
||||
name: str
|
||||
params: list[TypeVar]
|
||||
body: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}[{', '.join(map(str, self.params))}]"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AppliedType:
|
||||
name: str
|
||||
args: list[Type]
|
||||
body: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||
|
||||
|
||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||
def sub_argument(arg: Function.Argument):
|
||||
return Function.Argument(
|
||||
pos=arg.pos,
|
||||
name=arg.name,
|
||||
type=substitute_typevars(arg.type, substitutions),
|
||||
required=arg.required,
|
||||
)
|
||||
|
||||
match type:
|
||||
case BaseType(name=name) if name in substitutions:
|
||||
return substitutions[name]
|
||||
|
||||
case BaseType():
|
||||
return type
|
||||
|
||||
case AliasType(name=name, type=type2):
|
||||
return AliasType(name=name, type=substitute_typevars(type2, substitutions))
|
||||
|
||||
case Function(
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=returns,
|
||||
):
|
||||
return Function(
|
||||
pos_args=list(map(sub_argument, pos_args)),
|
||||
args=list(map(sub_argument, args)),
|
||||
kw_args=list(map(sub_argument, kw_args)),
|
||||
returns=substitute_typevars(returns, substitutions),
|
||||
)
|
||||
|
||||
case OverloadedFunction(overloads=overloads):
|
||||
return OverloadedFunction(
|
||||
overloads=[
|
||||
substitute_typevars(overload, substitutions)
|
||||
for overload in overloads
|
||||
]
|
||||
)
|
||||
|
||||
case ComplexType(members=members):
|
||||
members2: dict[str, Type] = {
|
||||
name: substitute_typevars(prop, substitutions)
|
||||
for name, prop in members.items()
|
||||
}
|
||||
return ComplexType(members=members2)
|
||||
|
||||
case ExtensionType(base=base, extension=ComplexType(members=members)):
|
||||
return ExtensionType(
|
||||
base=substitute_typevars(base, substitutions),
|
||||
extension=ComplexType(
|
||||
members={
|
||||
name: substitute_typevars(prop, substitutions)
|
||||
for name, prop in members.items()
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
case AppliedType(name=name, args=args, body=body):
|
||||
return AppliedType(
|
||||
name=name,
|
||||
args=[substitute_typevars(arg, substitutions) for arg in args],
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case TypeVar(name=name):
|
||||
if name in substitutions:
|
||||
return substitutions[name]
|
||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||
|
||||
case UnknownType() | UnitType():
|
||||
return type
|
||||
|
||||
case _:
|
||||
raise NotImplementedError(f"Unsupported type {type}")
|
||||
|
||||
|
||||
def unfold_type(type: Type) -> Type:
|
||||
match type:
|
||||
case AliasType(type=ref_type):
|
||||
return unfold_type(ref_type)
|
||||
case _:
|
||||
return type
|
||||
|
||||
|
||||
Type = (
|
||||
BaseType | AliasType | UnknownType | UnitType | Function | ComplexType | UnionType
|
||||
TopType
|
||||
| BaseType
|
||||
| AliasType
|
||||
| UnknownType
|
||||
| UnitType
|
||||
| Function
|
||||
| OverloadedFunction
|
||||
| ComplexType
|
||||
| ExtensionType
|
||||
| TypeVar
|
||||
| GenericType
|
||||
| AppliedType
|
||||
)
|
||||
|
||||
41
midas/cli/ansi.py
Normal file
41
midas/cli/ansi.py
Normal file
@@ -0,0 +1,41 @@
|
||||
class Ansi:
|
||||
CTRL = "\x1b["
|
||||
RESET = CTRL + "0m"
|
||||
BOLD = CTRL + "1m"
|
||||
DIM = CTRL + "2m"
|
||||
ITALIC = CTRL + "3m"
|
||||
UNDERLINE = CTRL + "4m"
|
||||
|
||||
BLACK = 0
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
YELLOW = 3
|
||||
BLUE = 4
|
||||
MAGENTA = 5
|
||||
CYAN = 6
|
||||
WHITE = 7
|
||||
|
||||
BRIGHT_BLACK = 60
|
||||
BRIGHT_RED = 61
|
||||
BRIGHT_GREEN = 62
|
||||
BRIGHT_YELLOW = 63
|
||||
BRIGHT_BLUE = 64
|
||||
BRIGHT_MAGENTA = 65
|
||||
BRIGHT_CYAN = 66
|
||||
BRIGHT_WHITE = 67
|
||||
|
||||
@classmethod
|
||||
def FG(cls, col: int) -> str:
|
||||
return f"{cls.CTRL}{30 + col}m"
|
||||
|
||||
@classmethod
|
||||
def BG(cls, col: int) -> str:
|
||||
return f"{cls.CTRL}{40 + col}m"
|
||||
|
||||
@classmethod
|
||||
def FG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||
return f"{cls.CTRL}38;2;{r};{g};{b}m"
|
||||
|
||||
@classmethod
|
||||
def BG_RGB(cls, r: int, g: int, b: int) -> str:
|
||||
return f"{cls.CTRL}48;2;{r};{g};{b}m"
|
||||
@@ -53,5 +53,6 @@ span {
|
||||
|
||||
&.keyword {
|
||||
color: rgb(211, 72, 9);
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, Protocol, TextIO, TypeVar
|
||||
|
||||
@@ -8,6 +9,7 @@ import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.lexer.token import Token
|
||||
|
||||
H = TypeVar("H", bound="Highlighter", contravariant=True)
|
||||
|
||||
@@ -22,6 +24,15 @@ class Locatable(Protocol):
|
||||
def location(self) -> Optional[Location]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocatableToken:
|
||||
token: Token
|
||||
|
||||
@property
|
||||
def location(self) -> Location:
|
||||
return self.token.get_location()
|
||||
|
||||
|
||||
class Highlighter(ABC):
|
||||
BASE_CSS_PATH: Path = Path(__file__).parent / "highlight.css"
|
||||
EXTRA_CSS_PATH: Optional[Path] = None
|
||||
@@ -199,61 +210,55 @@ class PythonHighlighter(
|
||||
|
||||
def visit_logical_expr(self, expr: p.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_set_expr(self, expr: p.SetExpr) -> None: ...
|
||||
|
||||
def visit_cast_expr(self, expr: p.CastExpr) -> None: ...
|
||||
|
||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None: ...
|
||||
|
||||
def visit_list_expr(self, expr: p.ListExpr) -> None:
|
||||
for item in expr.items:
|
||||
item.accept(self)
|
||||
|
||||
class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||
expr.object.accept(self)
|
||||
expr.index.accept(self)
|
||||
|
||||
def visit_slice_expr(self, expr: p.SliceExpr) -> None:
|
||||
if expr.lower is not None:
|
||||
expr.lower.accept(self)
|
||||
if expr.upper is not None:
|
||||
expr.upper.accept(self)
|
||||
if expr.step is not None:
|
||||
expr.step.accept(self)
|
||||
|
||||
|
||||
class MidasHighlighter(
|
||||
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||
):
|
||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_midas.css"
|
||||
|
||||
def highlight(self, node: Highlightable[MidasHighlighter]):
|
||||
node.accept(self)
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: m.SimpleTypeStmt) -> None:
|
||||
self.wrap(stmt, "simple-type")
|
||||
if stmt.template is not None:
|
||||
stmt.template.accept(self)
|
||||
stmt.base.accept(self)
|
||||
if stmt.constraint is not None:
|
||||
self.wrap(stmt.constraint, "constraint")
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: m.ComplexTypeStmt) -> None:
|
||||
self.wrap(stmt, "complex-type")
|
||||
if stmt.template is not None:
|
||||
stmt.template.accept(self)
|
||||
for prop in stmt.properties:
|
||||
prop.accept(self)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None:
|
||||
self.wrap(stmt, "property")
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
self.wrap(stmt, "type-stmt")
|
||||
self.wrap(LocatableToken(stmt.name), "type-name")
|
||||
stmt.type.accept(self)
|
||||
|
||||
def visit_member_stmt(self, stmt: m.MemberStmt) -> None:
|
||||
self.wrap(stmt, "member")
|
||||
stmt.type.accept(self)
|
||||
if stmt.constraint is not None:
|
||||
self.wrap(stmt.constraint, "constraint")
|
||||
stmt.constraint.accept(self)
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
self.wrap(stmt, "extend")
|
||||
stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
op.accept(self)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None:
|
||||
self.wrap(stmt, "op")
|
||||
stmt.operand.accept(self)
|
||||
stmt.result.accept(self)
|
||||
for member in stmt.members:
|
||||
member.accept(self)
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||
self.wrap(stmt, "predicate")
|
||||
self.wrap(LocatableToken(stmt.name), "predicate-name")
|
||||
stmt.type.accept(self)
|
||||
stmt.condition.accept(self)
|
||||
|
||||
def visit_simple_type_expr(self, expr: m.SimpleTypeExpr) -> None:
|
||||
self.wrap(expr, "simple-type-expr")
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||
self.wrap(expr, "logical-expr")
|
||||
expr.left.accept(self)
|
||||
@@ -282,14 +287,35 @@ class MidasHighlighter(Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None]):
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_template_expr(self, expr: m.TemplateExpr) -> None:
|
||||
self.wrap(expr, "template")
|
||||
expr.type.accept(self)
|
||||
def visit_named_type(self, type: m.NamedType) -> None:
|
||||
self.wrap(type, "named-type")
|
||||
|
||||
def visit_type_expr(self, expr: m.TypeExpr) -> None:
|
||||
self.wrap(expr, "type")
|
||||
if expr.template is not None:
|
||||
expr.template.accept(self)
|
||||
def visit_generic_type(self, type: m.GenericType) -> None:
|
||||
self.wrap(type, "generic-type")
|
||||
type.type.accept(self)
|
||||
for arg in type.args:
|
||||
arg.accept(self)
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> None:
|
||||
self.wrap(type, "constraint-type")
|
||||
type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> None:
|
||||
self.wrap(type, "complex-type")
|
||||
for member in type.members:
|
||||
member.accept(self)
|
||||
|
||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||
self.wrap(type, "function")
|
||||
for arg in type.pos_args + type.args + type.kw_args:
|
||||
arg.type.accept(self)
|
||||
type.returns.accept(self)
|
||||
|
||||
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||
self.wrap(type, "extension")
|
||||
type.base.accept(self)
|
||||
type.extension.accept(self)
|
||||
|
||||
|
||||
class DiagnosticsHighlighter(Highlighter):
|
||||
|
||||
@@ -5,12 +5,11 @@ span {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
&.simple-type {
|
||||
--col: 108, 233, 108;
|
||||
}
|
||||
|
||||
&.named-type,
|
||||
&.generic-type,
|
||||
&.constraint-type,
|
||||
&.complex-type {
|
||||
--col: 233, 206, 108;
|
||||
--col: 150, 150, 150;
|
||||
}
|
||||
|
||||
&.constraint {
|
||||
@@ -33,10 +32,6 @@ span {
|
||||
--col: 193, 108, 233;
|
||||
}
|
||||
|
||||
&.simple-type-expr {
|
||||
--col: 150, 150, 150;
|
||||
}
|
||||
|
||||
&.logical-expr,
|
||||
&.binary-expr,
|
||||
&.unary-expr,
|
||||
@@ -48,7 +43,9 @@ span {
|
||||
--col: 163, 117, 71;
|
||||
}
|
||||
|
||||
&.type {
|
||||
&.type-name,
|
||||
&.op-name,
|
||||
&.predicate-name {
|
||||
--col: 200, 200, 200;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, TextIO, get_args
|
||||
|
||||
@@ -10,13 +9,15 @@ import click
|
||||
import midas.ast.midas as m
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasAstPrinter, PythonAstPrinter
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.ast.printer import MidasAstPrinter, MidasPrinter, PythonAstPrinter
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic, DiagnosticType
|
||||
from midas.checker.types import Type
|
||||
from midas.cli.ansi import Ansi
|
||||
from midas.cli.highlighter import (
|
||||
DiagnosticsHighlighter,
|
||||
Highlighter,
|
||||
LocatableToken,
|
||||
MidasHighlighter,
|
||||
PythonHighlighter,
|
||||
)
|
||||
@@ -24,41 +25,126 @@ from midas.lexer.midas import MidasLexer
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.parser.midas import MidasParser
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.utils import UniversalJSONDumper
|
||||
|
||||
|
||||
@click.group()
|
||||
def midas():
|
||||
click.echo("Welcome to Midas!")
|
||||
pass
|
||||
|
||||
|
||||
def print_diagnostic(lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||
"""Pretty-print a diagnostic, showing some context if possible
|
||||
|
||||
If the diagnostic concerns a specific part of one line, the line is shown
|
||||
with the affected part highlighted. The message is clearly printed under the
|
||||
line with an underline further indicating the target expression.
|
||||
|
||||
If multiple lines are concerned, no context is shown, only the
|
||||
diagnostic type, location and message
|
||||
|
||||
Args:
|
||||
lines (list[str]): source code lines
|
||||
diagnostic (Diagnostic): the diagnostic to print
|
||||
indent (int, optional): the number of spaces added before the target line to indent if from the location header. Defaults to 4.
|
||||
"""
|
||||
|
||||
loc: Location = diagnostic.location
|
||||
if loc.lineno != loc.end_lineno:
|
||||
print(diagnostic)
|
||||
return
|
||||
|
||||
start_offset: int = loc.col_offset
|
||||
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||
|
||||
line: str = lines[loc.lineno - 1]
|
||||
before: str = line[:start_offset]
|
||||
after: str = line[end_offset:]
|
||||
|
||||
color: int = {
|
||||
DiagnosticType.ERROR: Ansi.RED,
|
||||
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||
DiagnosticType.INFO: Ansi.CYAN,
|
||||
}.get(diagnostic.type, Ansi.WHITE)
|
||||
|
||||
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||
cursor: str = (
|
||||
" " * start_offset
|
||||
+ Ansi.FG(color)
|
||||
+ "~" * (end_offset - start_offset)
|
||||
+ "> "
|
||||
+ diagnostic.message
|
||||
+ Ansi.RESET
|
||||
)
|
||||
|
||||
indent_str: str = " " * indent
|
||||
print(diagnostic.location_str + ":")
|
||||
print(indent_str + before + subject + after)
|
||||
print(indent_str + cursor)
|
||||
print()
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-l", "--highlight", type=click.File("w"))
|
||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||
@click.option("-v", "--verbose", is_flag=True)
|
||||
@click.option("-j", "--show-judgements", is_flag=True)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def compile(highlight: Optional[TextIO], file: TextIO):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
def compile(
|
||||
highlight: Optional[TextIO],
|
||||
types: tuple[TextIO],
|
||||
verbose: bool,
|
||||
show_judgements: bool,
|
||||
file: TextIO,
|
||||
):
|
||||
logging.basicConfig(level=logging.DEBUG if verbose else logging.WARN)
|
||||
source: str = file.read()
|
||||
tree: ast.Module = ast.parse(source, filename=file.name)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
checker = Checker(resolver.locals, file_path=Path(file.name).resolve())
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
for diagnostic in diagnostics:
|
||||
print(diagnostic)
|
||||
source_path: Path = Path(file.name).resolve()
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
checker = TypeChecker()
|
||||
for types_file in types:
|
||||
checker.import_midas(Path(types_file.name).resolve())
|
||||
|
||||
checker.type_check_source(source, str(source_path))
|
||||
diagnostics: list[Diagnostic] = checker.diagnostics.copy()
|
||||
lines: list[str] = source.split("\n")
|
||||
files: dict[Optional[str], list[str]] = {None: []}
|
||||
|
||||
if show_judgements:
|
||||
for expr, type in checker.python_typer.judgements:
|
||||
print(f"Judged that {expr} at {expr.location} is of type {type}")
|
||||
diagnostics.append(
|
||||
Diagnostic(
|
||||
file_path=str(source_path),
|
||||
location=expr.location,
|
||||
type=DiagnosticType.INFO,
|
||||
message=f"Type: {type}",
|
||||
)
|
||||
)
|
||||
|
||||
for diagnostic in diagnostics:
|
||||
filename: Optional[str] = diagnostic.file_path
|
||||
if filename is not None and filename not in files:
|
||||
path: Path = Path(filename)
|
||||
if path.exists() and path.is_file():
|
||||
files[filename] = path.read_text().split("\n")
|
||||
else:
|
||||
files[filename] = []
|
||||
|
||||
lines: list[str] = files[filename]
|
||||
print_diagnostic(lines, diagnostic)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
json.dumps(
|
||||
UniversalJSONDumper.dump(
|
||||
checker.python_typer.global_env,
|
||||
[("Environment", "_children")],
|
||||
lambda obj: isinstance(obj, get_args(Type)),
|
||||
),
|
||||
indent=4,
|
||||
)
|
||||
)
|
||||
)
|
||||
if highlight is not None:
|
||||
highlighter = DiagnosticsHighlighter(source)
|
||||
highlighter.highlight(diagnostics)
|
||||
@@ -142,14 +228,6 @@ def highlight_midas(source: str, path: str) -> Highlighter:
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocatableToken:
|
||||
token: Token
|
||||
|
||||
@property
|
||||
def location(self) -> Location:
|
||||
return self.token.get_location()
|
||||
|
||||
for stmt in stmts:
|
||||
highlighter.highlight(stmt)
|
||||
for token in tokens:
|
||||
@@ -176,5 +254,21 @@ def highlight(output: TextIO, file: TextIO):
|
||||
highlighter.dump(output)
|
||||
|
||||
|
||||
@midas.command()
|
||||
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||
@click.argument("file", type=click.File("r"))
|
||||
def format(output: TextIO, file: TextIO):
|
||||
source: str = file.read()
|
||||
printer = MidasPrinter()
|
||||
lexer = MidasLexer(source, file=file.name)
|
||||
tokens: list[Token] = lexer.process()
|
||||
parser = MidasParser(tokens)
|
||||
stmts: list[m.Stmt] = parser.parse()
|
||||
for err in parser.errors:
|
||||
print(err.get_report())
|
||||
for stmt in stmts:
|
||||
output.write(printer.print(stmt) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
midas()
|
||||
|
||||
@@ -18,8 +18,6 @@ class MidasLexer(Lexer):
|
||||
self.add_token(TokenType.LEFT_BRACE)
|
||||
case "}":
|
||||
self.add_token(TokenType.RIGHT_BRACE)
|
||||
case "|":
|
||||
self.add_token(TokenType.PIPE)
|
||||
case "<":
|
||||
self.add_token(
|
||||
TokenType.LESS_EQUAL if self.match("=") else TokenType.LESS
|
||||
@@ -52,12 +50,14 @@ class MidasLexer(Lexer):
|
||||
# self.add_token(TokenType.PLUS)
|
||||
case "-":
|
||||
self.add_token(TokenType.MINUS)
|
||||
# case "*":
|
||||
# self.add_token(TokenType.STAR)
|
||||
case "*":
|
||||
self.add_token(TokenType.STAR)
|
||||
case "/" if self.match("/"):
|
||||
self.scan_comment()
|
||||
case "/" if self.match("*"):
|
||||
self.scan_comment_multiline()
|
||||
case "/":
|
||||
self.add_token(TokenType.SLASH)
|
||||
case "\n":
|
||||
self.add_token(TokenType.NEWLINE)
|
||||
case " " | "\r" | "\t":
|
||||
|
||||
@@ -23,13 +23,12 @@ class TokenType(Enum):
|
||||
AND = auto()
|
||||
QMARK = auto()
|
||||
DOT = auto()
|
||||
PIPE = auto()
|
||||
|
||||
# Operators
|
||||
# PLUS = auto()
|
||||
MINUS = auto()
|
||||
# STAR = auto()
|
||||
# SLASH = auto()
|
||||
STAR = auto()
|
||||
SLASH = auto()
|
||||
GREATER = auto()
|
||||
GREATER_EQUAL = auto()
|
||||
LESS = auto()
|
||||
@@ -47,10 +46,12 @@ class TokenType(Enum):
|
||||
|
||||
# Keywords
|
||||
TYPE = auto()
|
||||
OP = auto()
|
||||
PREDICATE = auto()
|
||||
EXTEND = auto()
|
||||
WHERE = auto()
|
||||
PROP = auto()
|
||||
DEF = auto()
|
||||
FUNC = auto()
|
||||
|
||||
# Misc
|
||||
COMMENT = auto()
|
||||
@@ -61,13 +62,15 @@ class TokenType(Enum):
|
||||
|
||||
KEYWORDS: dict[str, TokenType] = {
|
||||
"type": TokenType.TYPE,
|
||||
"op": TokenType.OP,
|
||||
"predicate": TokenType.PREDICATE,
|
||||
"extend": TokenType.EXTEND,
|
||||
"where": TokenType.WHERE,
|
||||
"true": TokenType.TRUE,
|
||||
"false": TokenType.FALSE,
|
||||
"none": TokenType.NONE,
|
||||
"prop": TokenType.PROP,
|
||||
"def": TokenType.DEF,
|
||||
"fn": TokenType.FUNC,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,24 +7,26 @@ from midas.ast.midas import (
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MemberKind,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
OpStmt,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
Stmt,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
UnionType,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
from midas.lexer.token import Token, TokenType
|
||||
from midas.lexer.token import KEYWORDS, Token, TokenType
|
||||
from midas.parser.base import Parser
|
||||
from midas.parser.errors import ParsingError
|
||||
|
||||
@@ -34,9 +36,10 @@ class MidasParser(Parser):
|
||||
|
||||
SYNC_BOUNDARY: set[TokenType] = {
|
||||
TokenType.TYPE,
|
||||
TokenType.OP,
|
||||
TokenType.EXTEND,
|
||||
TokenType.PREDICATE,
|
||||
TokenType.PROP,
|
||||
TokenType.FUNC,
|
||||
}
|
||||
|
||||
def parse(self) -> list[Stmt]:
|
||||
@@ -108,10 +111,8 @@ class MidasParser(Parser):
|
||||
TypeStmt: the parsed type declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params = self.type_stmt_params()
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
params: list[TypeParam] = self.type_params()
|
||||
|
||||
self.consume(TokenType.EQUAL, "Expected '=' before type definition")
|
||||
|
||||
@@ -124,24 +125,27 @@ class MidasParser(Parser):
|
||||
type=type,
|
||||
)
|
||||
|
||||
def type_stmt_params(self) -> list[TypeStmt.Param]:
|
||||
"""Parse a generic template expression
|
||||
def type_params(self) -> list[TypeParam]:
|
||||
"""Parse a list of type parameters
|
||||
|
||||
A template is written `[TypeExpr]`
|
||||
Type parameters are a comma-separated list of type variables wrapped in brackets.
|
||||
Each type variable is either a simple variable, or a bounded variable written `S <: T`
|
||||
|
||||
Returns:
|
||||
TemplateExpr: the parsed template expression
|
||||
list[TypeParam]: the list of type parameters, if any, or an empty list
|
||||
"""
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before template expression")
|
||||
params: list[TypeStmt.Param] = []
|
||||
if not self.match(TokenType.LEFT_BRACKET):
|
||||
return []
|
||||
|
||||
params: list[TypeParam] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type variable")
|
||||
name: Token = self.consume_identifier("Expected type variable")
|
||||
bound: Optional[Type] = None
|
||||
if self.match(TokenType.LESS):
|
||||
self.consume(TokenType.COLON, "Expected ':' after '<'")
|
||||
bound = self.type_expr()
|
||||
params.append(
|
||||
TypeStmt.Param(
|
||||
TypeParam(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
bound=bound,
|
||||
@@ -149,7 +153,7 @@ class MidasParser(Parser):
|
||||
)
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after template expression")
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after type parameters")
|
||||
return params
|
||||
|
||||
def type_expr(self) -> Type:
|
||||
@@ -161,18 +165,19 @@ class MidasParser(Parser):
|
||||
Returns:
|
||||
TypeExpr: the parsed type expression
|
||||
"""
|
||||
return self.union_type()
|
||||
|
||||
def union_type(self) -> Type:
|
||||
types: list[Type] = [self.constraint_type()]
|
||||
while self.match(TokenType.PIPE):
|
||||
types.append(self.constraint_type())
|
||||
if len(types) == 1:
|
||||
return types[0]
|
||||
return UnionType(
|
||||
location=Location.span(types[0].location, types[-1].location),
|
||||
types=types,
|
||||
)
|
||||
base: Type
|
||||
if self.match(TokenType.FUNC):
|
||||
base = self.function()
|
||||
else:
|
||||
base = self.constraint_type()
|
||||
if self.match(TokenType.AND):
|
||||
extension: ComplexType = self.complex_type()
|
||||
return ExtensionType(
|
||||
location=Location.span(base.location, extension.location),
|
||||
base=base,
|
||||
extension=extension,
|
||||
)
|
||||
return base
|
||||
|
||||
def constraint_type(self) -> Type:
|
||||
type: Type = self.base_type()
|
||||
@@ -199,55 +204,57 @@ class MidasParser(Parser):
|
||||
def generic_type(self) -> Type:
|
||||
type: Type = self.named_type()
|
||||
if self.check(TokenType.LEFT_BRACKET):
|
||||
params: list[Type] = self.type_params()
|
||||
args: list[Type] = self.type_args()
|
||||
return GenericType(
|
||||
location=Location.span(type.location, self.previous().get_location()),
|
||||
type=type,
|
||||
params=params,
|
||||
args=args,
|
||||
)
|
||||
return type
|
||||
|
||||
def type_params(self) -> list[Type]:
|
||||
params: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic parameters")
|
||||
def type_args(self) -> list[Type]:
|
||||
args: list[Type] = []
|
||||
self.consume(TokenType.LEFT_BRACKET, "Missing '[' before generic arguments")
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACKET):
|
||||
params.append(self.type_expr())
|
||||
args.append(self.type_expr())
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic parameters")
|
||||
return params
|
||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||
return args
|
||||
|
||||
def named_type(self) -> Type:
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected type name")
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
return NamedType(
|
||||
location=name.get_location(),
|
||||
name=name,
|
||||
)
|
||||
|
||||
def complex_type(self) -> Type:
|
||||
def complex_type(self) -> ComplexType:
|
||||
"""Parse a type definition body
|
||||
|
||||
A type definition body is a set of whitespace-separated
|
||||
property statements enclosed in curly braces
|
||||
|
||||
Returns:
|
||||
list[PropertyStmt]: the parsed type properties
|
||||
ComplexType: the parsed complex type
|
||||
"""
|
||||
left: Token = self.consume(
|
||||
TokenType.LEFT_BRACE, "Expected '{' to start type body"
|
||||
)
|
||||
properties: list[PropertyStmt] = []
|
||||
members: list[MemberStmt] = []
|
||||
# TODO: add keyword to differentiate properties and methods,
|
||||
# and allow multiple methods with the same name but not properties
|
||||
names: set[str] = set()
|
||||
while not self.check(TokenType.RIGHT_BRACE) and not self.is_at_end():
|
||||
prop: PropertyStmt = self.property_stmt()
|
||||
if prop.name.lexeme in names:
|
||||
raise self.error(prop.name, "Duplicate property")
|
||||
names.add(prop.name.lexeme)
|
||||
properties.append(prop)
|
||||
member: MemberStmt = self.member_stmt()
|
||||
# if member.name.lexeme in names:
|
||||
# raise self.error(member.name, "Duplicate property")
|
||||
# names.add(member.name.lexeme)
|
||||
members.append(member)
|
||||
right: Token = self.consume(TokenType.RIGHT_BRACE, "Unclosed type body")
|
||||
return ComplexType(
|
||||
location=left.location_to(right),
|
||||
properties=properties,
|
||||
members=members,
|
||||
)
|
||||
|
||||
def constraint(self) -> Expr:
|
||||
@@ -334,9 +341,7 @@ class MidasParser(Parser):
|
||||
"""
|
||||
expr: Expr = self.primary()
|
||||
while self.match(TokenType.DOT):
|
||||
name: Token = self.consume(
|
||||
TokenType.IDENTIFIER, "Expected property name after '.'"
|
||||
)
|
||||
name: Token = self.consume_identifier("Expected property name after '.'")
|
||||
location: Location = Location.span(expr.location, name.get_location())
|
||||
expr = GetExpr(location=location, expr=expr, name=name)
|
||||
return expr
|
||||
@@ -360,7 +365,7 @@ class MidasParser(Parser):
|
||||
if self.match(TokenType.NUMBER):
|
||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||
|
||||
if self.match(TokenType.IDENTIFIER):
|
||||
if self.match_identifier():
|
||||
return VariableExpr(location=token.get_location(), name=token)
|
||||
|
||||
if self.match(TokenType.UNDERSCORE):
|
||||
@@ -373,64 +378,70 @@ class MidasParser(Parser):
|
||||
|
||||
raise self.error(self.peek(), "Expected expression")
|
||||
|
||||
def property_stmt(self) -> PropertyStmt:
|
||||
"""Parse a property statement
|
||||
def consume_identifier(self, message: str = "Expected identifier") -> Token:
|
||||
if not self.match_identifier():
|
||||
raise self.error(self.peek(), message)
|
||||
return self.previous()
|
||||
|
||||
A type property statement is written `name: Type` or `name: Type where Condition`
|
||||
def match_identifier(self) -> bool:
|
||||
return self.match(TokenType.IDENTIFIER, *KEYWORDS.values())
|
||||
|
||||
def check_identifier(self) -> bool:
|
||||
for tt in [TokenType.IDENTIFIER, *KEYWORDS.values()]:
|
||||
if self.check(tt):
|
||||
return True
|
||||
return False
|
||||
|
||||
def member_stmt(self) -> MemberStmt:
|
||||
"""Parse a member statement
|
||||
|
||||
A type member statement is written `prop name: Type` or `def name: Type`
|
||||
|
||||
Returns:
|
||||
PropertyStmt: the parsed property statement
|
||||
MemberStmt: the parsed member statement
|
||||
"""
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected property name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after property name")
|
||||
kind: MemberKind
|
||||
if self.match(TokenType.PROP):
|
||||
kind = MemberKind.PROPERTY
|
||||
elif self.match(TokenType.DEF):
|
||||
kind = MemberKind.METHOD
|
||||
else:
|
||||
raise self.error(self.peek(), "Expected 'prop' or 'def'")
|
||||
|
||||
name: Token = self.consume_identifier("Expected member name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after member name")
|
||||
|
||||
type: Type = self.type_expr()
|
||||
return PropertyStmt(
|
||||
return MemberStmt(
|
||||
location=name.location_to(self.previous()),
|
||||
name=name,
|
||||
type=type,
|
||||
kind=kind,
|
||||
)
|
||||
|
||||
def extend_declaration(self) -> ExtendStmt:
|
||||
"""Parse an extension definition
|
||||
|
||||
An extension is written `extend Type { operations }`
|
||||
An extension is written `extend Type { operations }` or `extend[S <: T, U] Type { operations }`
|
||||
|
||||
Returns:
|
||||
ExtendStmt: the parsed extension statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
type: Type = self.type_expr()
|
||||
name: Token = self.consume_identifier("Expected type name")
|
||||
params: list[TypeParam] = self.type_params()
|
||||
|
||||
self.consume(TokenType.LEFT_BRACE, "Expected '{' to start extend body")
|
||||
operations: list[OpStmt] = []
|
||||
members: list[MemberStmt] = []
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_BRACE):
|
||||
operations.append(self.op_declaration())
|
||||
members.append(self.member_stmt())
|
||||
self.consume(TokenType.RIGHT_BRACE, "Unclosed extend body")
|
||||
location: Location = keyword.location_to(self.previous())
|
||||
return ExtendStmt(location=location, type=type, operations=operations)
|
||||
|
||||
def op_declaration(self) -> OpStmt:
|
||||
"""Parse an operation definition
|
||||
|
||||
An operation is written `op name(Type) -> Type`
|
||||
|
||||
Returns:
|
||||
OpStmt: the parsed operation statement
|
||||
"""
|
||||
keyword: Token = self.consume(TokenType.OP, "Expected 'op' keyword")
|
||||
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected operation name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before operand type")
|
||||
operand: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after operand type")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return OpStmt(
|
||||
location=keyword.location_to(self.previous()),
|
||||
return ExtendStmt(
|
||||
location=location,
|
||||
name=name,
|
||||
operand=operand,
|
||||
result=result,
|
||||
params=params,
|
||||
members=members,
|
||||
)
|
||||
|
||||
def predicate_declaration(self) -> PredicateStmt:
|
||||
@@ -442,9 +453,9 @@ class MidasParser(Parser):
|
||||
PredicateStmt: the parsed predicate declaration statement
|
||||
"""
|
||||
keyword: Token = self.previous()
|
||||
name: Token = self.consume(TokenType.IDENTIFIER, "Expected predicate name")
|
||||
name: Token = self.consume_identifier("Expected predicate name")
|
||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
||||
subject: Token = self.consume(TokenType.IDENTIFIER, "Expected subject name")
|
||||
subject: Token = self.consume_identifier("Expected subject name")
|
||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
||||
type: Type = self.type_expr()
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
||||
@@ -457,3 +468,72 @@ class MidasParser(Parser):
|
||||
type=type,
|
||||
condition=condition,
|
||||
)
|
||||
|
||||
def function(self) -> FunctionType:
|
||||
l_paren: Token = self.consume(
|
||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||
)
|
||||
pos_args: list[FunctionType.Argument] = []
|
||||
args: list[FunctionType.Argument] = []
|
||||
kw_args: list[FunctionType.Argument] = []
|
||||
|
||||
args_first_tokens: list[Token] = []
|
||||
|
||||
section: int = 0
|
||||
while not self.is_at_end() and not self.check(TokenType.RIGHT_PAREN):
|
||||
match section:
|
||||
case 0 if self.match(TokenType.SLASH):
|
||||
pos_args = args
|
||||
args = []
|
||||
args_first_tokens = []
|
||||
section = 1
|
||||
case 0 | 1 if self.match(TokenType.STAR):
|
||||
section = 2
|
||||
case _:
|
||||
# Record first token of mixed argument for errors if unnamed
|
||||
if section != 2:
|
||||
args_first_tokens.append(self.peek())
|
||||
|
||||
name: Optional[Token] = None
|
||||
if section == 2:
|
||||
name = self.consume_identifier("Expected keyword argument name")
|
||||
self.consume(
|
||||
TokenType.COLON, "Expected ':' after argument name"
|
||||
)
|
||||
elif self.check_identifier() and self.check_next(TokenType.COLON):
|
||||
name = self.advance()
|
||||
self.advance()
|
||||
|
||||
type: Type = self.type_expr()
|
||||
optional: bool = self.match(TokenType.QMARK)
|
||||
arg = FunctionType.Argument(
|
||||
location=None,
|
||||
name=name,
|
||||
type=type,
|
||||
required=not optional,
|
||||
)
|
||||
if section == 2:
|
||||
kw_args.append(arg)
|
||||
else:
|
||||
args.append(arg)
|
||||
|
||||
if not self.match(TokenType.COMMA):
|
||||
break
|
||||
|
||||
for arg, token in zip(args, args_first_tokens):
|
||||
if arg.name is None:
|
||||
# Not raised because we can keep parsing
|
||||
self.error(token, "Unnamed mixed argument")
|
||||
|
||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||
|
||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||
result: Type = self.type_expr()
|
||||
|
||||
return FunctionType(
|
||||
location=l_paren.location_to(self.previous()),
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
returns=result,
|
||||
)
|
||||
|
||||
@@ -17,11 +17,14 @@ from midas.ast.python import (
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
ListExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
SliceExpr,
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
@@ -87,6 +90,9 @@ class PythonParser:
|
||||
case ast.If():
|
||||
return self.parse_if(node)
|
||||
|
||||
case ast.Pass():
|
||||
return None
|
||||
|
||||
case _:
|
||||
print(f"Unsupported statement: {ast.unparse(node)}")
|
||||
return None
|
||||
@@ -311,6 +317,13 @@ class PythonParser:
|
||||
constraint=right_expr,
|
||||
)
|
||||
|
||||
case ast.Constant(value=None):
|
||||
return BaseType(
|
||||
location=loc,
|
||||
base="None",
|
||||
param=None,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(type_expr)
|
||||
|
||||
@@ -406,6 +419,27 @@ class PythonParser:
|
||||
case ast.Name(id=name):
|
||||
return VariableExpr(location=location, name=name)
|
||||
|
||||
case ast.List(elts=items):
|
||||
return ListExpr(
|
||||
location=location,
|
||||
items=[self.parse_expr(item) for item in items],
|
||||
)
|
||||
|
||||
case ast.Subscript(value=value, slice=index):
|
||||
return SubscriptExpr(
|
||||
location=location,
|
||||
object=self.parse_expr(value),
|
||||
index=self.parse_expr(index),
|
||||
)
|
||||
|
||||
case ast.Slice(lower=lower, upper=upper, step=step):
|
||||
return SliceExpr(
|
||||
location=location,
|
||||
lower=self.parse_expr(lower) if lower is not None else None,
|
||||
upper=self.parse_expr(upper) if upper is not None else None,
|
||||
step=self.parse_expr(step) if step is not None else None,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise UnsupportedSyntaxError(node)
|
||||
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from midas.checker.types import BaseType, Type, UnitType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.resolver.midas import MidasResolver
|
||||
|
||||
|
||||
def op(ctx: MidasResolver, t1: Type, operator: str, t2: Type, t3: Type):
|
||||
ctx.define_operation(
|
||||
left=t1,
|
||||
operator=operator,
|
||||
right=t2,
|
||||
result=t3,
|
||||
)
|
||||
|
||||
|
||||
def basic_op(ctx: MidasResolver, type: Type, op: str):
|
||||
ctx.define_operation(
|
||||
left=type,
|
||||
operator=op,
|
||||
right=type,
|
||||
result=type,
|
||||
)
|
||||
|
||||
|
||||
def define_builtins(ctx: MidasResolver):
|
||||
"""Define builtin types and operations"""
|
||||
unit = ctx.define_type("None", UnitType())
|
||||
bool = ctx.define_type("bool", BaseType(name="bool"))
|
||||
int = ctx.define_type("int", BaseType(name="int"))
|
||||
float = ctx.define_type("float", BaseType(name="float"))
|
||||
str = ctx.define_type("str", BaseType(name="str"))
|
||||
|
||||
basic_op(ctx, int, "__add__") # int + int = int
|
||||
basic_op(ctx, int, "__sub__") # int - int = int
|
||||
basic_op(ctx, int, "__mul__") # int * int = int
|
||||
basic_op(ctx, int, "__pow__") # int ** int = int
|
||||
basic_op(ctx, int, "__mod__") # int % int = int
|
||||
basic_op(ctx, int, "__and__") # int & int = int
|
||||
basic_op(ctx, int, "__or__") # int | int = int
|
||||
basic_op(ctx, int, "__xor__") # int ^ int = int
|
||||
op(ctx, int, "__lt__", int, bool) # int < int = bool
|
||||
op(ctx, int, "__gt__", int, bool) # int > int = bool
|
||||
op(ctx, int, "__le__", int, bool) # int <= int = bool
|
||||
op(ctx, int, "__ge__", int, bool) # int >= int = bool
|
||||
op(ctx, int, "__eq__", int, bool) # int == int = bool
|
||||
basic_op(ctx, float, "__add__") # float + float = float
|
||||
basic_op(ctx, float, "__sub__") # float - float = float
|
||||
basic_op(ctx, float, "__mul__") # float * float = float
|
||||
basic_op(ctx, float, "__truediv__") # float / float = float
|
||||
op(ctx, float, "__lt__", float, bool) # float < float = bool
|
||||
op(ctx, float, "__gt__", float, bool) # float > float = bool
|
||||
op(ctx, float, "__le__", float, bool) # float <= float = bool
|
||||
op(ctx, float, "__ge__", float, bool) # float >= float = bool
|
||||
op(ctx, float, "__eq__", float, bool) # float == float = bool
|
||||
basic_op(ctx, str, "__add__") # str + str = str
|
||||
op(ctx, str, "__eq__", str, bool) # str == str = bool
|
||||
|
||||
op(ctx, int, "__lt__", float, bool) # int < float = bool
|
||||
op(ctx, int, "__gt__", float, bool) # int > float = bool
|
||||
op(ctx, int, "__le__", float, bool) # int <= float = bool
|
||||
op(ctx, int, "__ge__", float, bool) # int >= float = bool
|
||||
op(ctx, int, "__eq__", float, bool) # int == float = bool
|
||||
|
||||
op(ctx, float, "__lt__", int, bool) # float < int = bool
|
||||
op(ctx, float, "__gt__", int, bool) # float > int = bool
|
||||
op(ctx, float, "__le__", int, bool) # float <= int = bool
|
||||
op(ctx, float, "__ge__", int, bool) # float >= int = bool
|
||||
op(ctx, float, "__eq__", int, bool) # float == int = bool
|
||||
@@ -1,166 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import midas.ast.midas as m
|
||||
from midas.checker.types import (
|
||||
Type,
|
||||
UnionType,
|
||||
UnknownType,
|
||||
)
|
||||
from midas.resolver.builtin import define_builtins
|
||||
|
||||
|
||||
class MidasResolver(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._types: dict[str, Type] = {}
|
||||
self._operations: dict[tuple[Type, str, Type], Type] = {}
|
||||
|
||||
define_builtins(self)
|
||||
|
||||
def get_type(self, name: str) -> Type:
|
||||
"""Get a type from its name
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
|
||||
Raises:
|
||||
NameError: if the type is not defined
|
||||
|
||||
Returns:
|
||||
Type: the type
|
||||
"""
|
||||
type: Optional[Type] = self._types.get(name)
|
||||
if type is None:
|
||||
raise NameError(f"Undefined type {name}")
|
||||
return type
|
||||
|
||||
def get_operation_result(
|
||||
self, left: Type, operator: str, right: Type
|
||||
) -> Optional[Type]:
|
||||
"""Get the resulting type of an operation
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
|
||||
Returns:
|
||||
Optional[Type]: the result type, or None if no matching operation was found
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
result: Optional[Type] = self._operations.get(operation)
|
||||
return result
|
||||
|
||||
def define_type(self, name: str, type: Type) -> Type:
|
||||
"""Define a type in the registry
|
||||
|
||||
Args:
|
||||
name (str): the name of the type
|
||||
type (Type): the type to define
|
||||
|
||||
Raises:
|
||||
ValueError: if a type is already defined with that name
|
||||
|
||||
Returns:
|
||||
Type: the defined type
|
||||
"""
|
||||
if name in self._types:
|
||||
raise ValueError(f"Type {name} already defined")
|
||||
self._types[name] = type
|
||||
return type
|
||||
|
||||
def define_operation(self, left: Type, operator: str, right: Type, result: Type):
|
||||
"""Define an operation in the registry
|
||||
|
||||
Args:
|
||||
left (Type): the type of the left operand
|
||||
operator (str): the operation name
|
||||
right (Type): the type of the right operand
|
||||
result (Type): the result type
|
||||
|
||||
Raises:
|
||||
ValueError: if an operation is already defined with these operands and name
|
||||
"""
|
||||
operation: tuple[Type, str, Type] = (left, operator, right)
|
||||
if operation in self._operations:
|
||||
raise ValueError(
|
||||
f"Operation {operator} already defined between {left} and {right}"
|
||||
)
|
||||
self._operations[operation] = result
|
||||
|
||||
def resolve(self, stmts: list[m.Stmt]):
|
||||
"""Process a sequence of statements
|
||||
|
||||
Args:
|
||||
stmts (list[m.Stmt]): the statements
|
||||
"""
|
||||
for stmt in stmts:
|
||||
stmt.accept(self)
|
||||
|
||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||
type: Type = stmt.type.accept(self)
|
||||
for param in stmt.params:
|
||||
if param.bound is not None:
|
||||
param.bound.accept(self)
|
||||
self.define_type(stmt.name.lexeme, type)
|
||||
|
||||
def visit_property_stmt(self, stmt: m.PropertyStmt) -> None: ...
|
||||
|
||||
def visit_extend_stmt(self, stmt: m.ExtendStmt) -> None:
|
||||
base: Type = stmt.type.accept(self)
|
||||
for op in stmt.operations:
|
||||
right: Type = op.operand.accept(self)
|
||||
result: Type = op.result.accept(self)
|
||||
self.define_operation(
|
||||
left=base,
|
||||
operator=op.name.lexeme,
|
||||
right=right,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def visit_op_stmt(self, stmt: m.OpStmt) -> None: ...
|
||||
|
||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None: ...
|
||||
|
||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None: ...
|
||||
|
||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None: ...
|
||||
|
||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None: ...
|
||||
|
||||
def visit_get_expr(self, expr: m.GetExpr) -> None: ...
|
||||
|
||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None: ...
|
||||
|
||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
||||
return expr.expr.accept(self)
|
||||
|
||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None: ...
|
||||
|
||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None: ...
|
||||
|
||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||
return self.get_type(type.name.lexeme)
|
||||
|
||||
def visit_generic_type(self, type: m.GenericType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
params: list[Type] = [param.accept(self) for param in type.params]
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||
type_: Type = type.type.accept(self)
|
||||
type.constraint.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
|
||||
def visit_union_type(self, type: m.UnionType) -> Type:
|
||||
types: list[Type] = [type_.accept(self) for type_ in type.types]
|
||||
return UnionType(alternatives=types)
|
||||
|
||||
def visit_complex_type(self, type: m.ComplexType) -> Type:
|
||||
for prop in type.properties:
|
||||
prop.accept(self)
|
||||
# TODO
|
||||
return UnknownType()
|
||||
@@ -19,16 +19,24 @@ Comparison ::= Unary (ComparisonOp Unary)*
|
||||
Equality ::= Comparison (EqualityOp Comparison)*
|
||||
Constraint ::= Equality ("&" Equality)*
|
||||
|
||||
SimpleType ::= Identifier "?"?
|
||||
Template ::= "[" Type "]"
|
||||
Type ::= Identifier Template? "?"?
|
||||
TemplateParam ::= Identifier ("<:" Type)?
|
||||
Template ::= "[" (TemplateParam ("," TemplateParam)*)? "]"
|
||||
|
||||
|
||||
TypeProperty ::= Identifier ":" Type
|
||||
ComplexType ::= "{" TypeProperty* "}"
|
||||
NamedType ::= Identifier
|
||||
TypeParams ::= "[" (Type ("," Type)*)? "]"
|
||||
GenericType ::= NamedType TypeParams?
|
||||
GroupedType ::= "(" Type ")"
|
||||
BaseType ::= GroupedType | ComplexType | GenericType
|
||||
ConstraintType ::= BaseType ("where" Constraint)?
|
||||
Type ::= ConstraintType
|
||||
|
||||
TypeProperty ::= Identifier ":" Type ("where" Constraints)?
|
||||
ComplexTypeBody ::= "{" TypeProperty* "}"
|
||||
OpDefinition ::= "op" Identifier "(" Type ")" "->" Type
|
||||
ExtendBody ::= "{" OpDefinition* "}"
|
||||
|
||||
TypeStatement ::= "type" Identifier Template? ("(" Type ")" ("where" Constraint)? | ComplexTypeBody)
|
||||
TypeStatement ::= "type" Identifier Template? "=" Type
|
||||
ExtendStatement ::= "extend" Type ExtendBody
|
||||
PredicateStatement ::= "predicate" Identifier "(" Identifier ":" Type ")" "=" Constraint
|
||||
|
||||
|
||||
@@ -43,28 +43,52 @@ svg.railroad .terminal rect {
|
||||
{[`constraint` 'equality'*"&"]}
|
||||
```
|
||||
|
||||
#let simple-type = ```
|
||||
{[`simple-type` 'identifier' <!, "?">]}
|
||||
#let template-param = ```
|
||||
{[`template-param` 'identifier' <!, ["<:" 'type']>]}
|
||||
```
|
||||
|
||||
#let template = ```
|
||||
{[`template` "[" 'type' "]"]}
|
||||
```
|
||||
|
||||
#let type = ```
|
||||
{[`type` 'identifier' <!, 'template'> <!, "?">]}
|
||||
{[`template` "[" <!, 'template-param'*","> "]"]}
|
||||
```
|
||||
|
||||
#let type-property = ```
|
||||
{[`type-property` 'identifier' ":" 'type' <!, ["where" 'constraint']>]}
|
||||
{[`type-property` 'identifier' ":" 'type']}
|
||||
```
|
||||
|
||||
#let type-body = ```
|
||||
{[`type-body` "{" <!, 'type-property'*!> "}"]}
|
||||
#let complex-type = ```
|
||||
{[`complex-type` "{" <!, 'type-property'*!> "}"]}
|
||||
```
|
||||
|
||||
#let named-type = ```
|
||||
{[`named-type` 'identifier']}
|
||||
```
|
||||
|
||||
#let type-params = ```
|
||||
{[`type-params` "[" <!, 'type'*","> "]"]}
|
||||
```
|
||||
|
||||
#let generic-type = ```
|
||||
{[`generic-type` 'named-type' <!, 'type-params'>]}
|
||||
```
|
||||
|
||||
#let grouped-type = ```
|
||||
{[`grouped-type` "(" 'type' ")"]}
|
||||
```
|
||||
|
||||
#let base-type = ```
|
||||
{[`base-type` <'grouped-type', 'complex-type', 'generic-type'>]}
|
||||
```
|
||||
|
||||
#let constraint-type = ```
|
||||
{[`constraint-type` 'base-type' <!, ["where" 'constraint']>]}
|
||||
```
|
||||
|
||||
#let type = ```
|
||||
{[`type` 'constraint-type']}
|
||||
```
|
||||
|
||||
#let type-statement = ```
|
||||
{[`type-statement` "type" 'identifier' <!, 'template'> <[["(" 'type' ")"] <!, ["where" 'constraint']>], 'type-body'>]}
|
||||
{[`type-statement` "type" 'identifier' <!, 'template'> "=" 'type']}
|
||||
```
|
||||
|
||||
#let op-definition = ```
|
||||
@@ -92,11 +116,17 @@ svg.railroad .terminal rect {
|
||||
comparison: comparison,
|
||||
equality: equality,
|
||||
constraint: constraint,
|
||||
simple-type: simple-type,
|
||||
template-param: template-param,
|
||||
template: template,
|
||||
type: type,
|
||||
type-property: type-property,
|
||||
type-body: type-body,
|
||||
complex-type: complex-type,
|
||||
named-type: named-type,
|
||||
type-params: type-params,
|
||||
generic-type: generic-type,
|
||||
grouped-type: grouped-type,
|
||||
base-type: base-type,
|
||||
constraint-type: constraint-type,
|
||||
type: type,
|
||||
type-statement: type-statement,
|
||||
op-definition: op-definition,
|
||||
extend-statement: extend-statement,
|
||||
@@ -107,10 +137,16 @@ svg.railroad .terminal rect {
|
||||
#let inline = (
|
||||
"grouping",
|
||||
"value",
|
||||
"template-param",
|
||||
"template",
|
||||
"simple-type",
|
||||
"type-property",
|
||||
"type-body",
|
||||
"complex-type",
|
||||
"type-params",
|
||||
"named-type",
|
||||
"grouped-type",
|
||||
"generic-type",
|
||||
"base-type",
|
||||
"constraint-type",
|
||||
"op-definition",
|
||||
"type-statement",
|
||||
"extend-statement",
|
||||
|
||||
@@ -29,7 +29,7 @@ class Tester(ABC):
|
||||
def _list_tests(self) -> list[Path]: ...
|
||||
|
||||
def run_all_tests(self) -> bool:
|
||||
paths: list[Path] = self._list_tests()
|
||||
paths: list[Path] = sorted(self._list_tests())
|
||||
return self.run_tests(paths)
|
||||
|
||||
def run_tests(self, tests: list[Path]) -> bool:
|
||||
@@ -40,7 +40,7 @@ class Tester(ABC):
|
||||
|
||||
print(rule)
|
||||
for i, test in enumerate(tests):
|
||||
print(f"Case {i+1}/{n}: {test.relative_to(self.CASES_DIR)}")
|
||||
print(f"Case {i+1}/{n}: {test.resolve().relative_to(self.CASES_DIR)}")
|
||||
success: bool = self._run_test(test)
|
||||
if success:
|
||||
successes += 1
|
||||
@@ -78,7 +78,7 @@ class Tester(ABC):
|
||||
def _exec_case(self, path: Path) -> CaseResult: ...
|
||||
|
||||
def update_all_tests(self):
|
||||
paths: list[Path] = self._list_tests()
|
||||
paths: list[Path] = sorted(self._list_tests())
|
||||
return self.update_tests(paths)
|
||||
|
||||
def update_tests(self, tests: list[Path]):
|
||||
@@ -141,3 +141,9 @@ class Tester(ABC):
|
||||
success = tester.run_tests(args.FILE)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
case None:
|
||||
print("No subcommand provided. Available subcommands: run, update")
|
||||
sys.exit(1)
|
||||
case _:
|
||||
print(f"Unknown subcommand '{args.subcommand}'")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
"diagnostics": [
|
||||
{
|
||||
"type": "Warning",
|
||||
"location": {
|
||||
"start": [
|
||||
6,
|
||||
4
|
||||
],
|
||||
"end": [
|
||||
13,
|
||||
5
|
||||
]
|
||||
},
|
||||
"message": "FrameType not yet supported"
|
||||
}
|
||||
],
|
||||
"judgments": []
|
||||
}
|
||||
@@ -12,7 +12,7 @@
|
||||
13
|
||||
]
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='str') to c of type BaseType(name='int')"
|
||||
"message": "Cannot assign str to variable 'c' of type int"
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
@@ -26,21 +26,166 @@
|
||||
9
|
||||
]
|
||||
},
|
||||
"message": "Undefined operation __add__ between BaseType(name='bool') and BaseType(name='bool')"
|
||||
"message": "Undefined operation __add__ between bool and bool"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L1:9",
|
||||
"to": "L1:10"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Error",
|
||||
"location": {
|
||||
"start": [
|
||||
11,
|
||||
0
|
||||
],
|
||||
"end": [
|
||||
11,
|
||||
12
|
||||
]
|
||||
"from": "L2:9",
|
||||
"to": "L2:10"
|
||||
},
|
||||
"message": "Cannot assign BaseType(name='int') to f of type BaseType(name='float')"
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 4
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:4",
|
||||
"to": "L4:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:8",
|
||||
"to": "L4:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:4",
|
||||
"to": "L4:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:4",
|
||||
"to": "L6:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "invalid"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:4",
|
||||
"to": "L8:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": true
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:4",
|
||||
"to": "L9:5"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:8",
|
||||
"to": "L9:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L9:4",
|
||||
"to": "L9:9"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "d"
|
||||
}
|
||||
},
|
||||
"type": {}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:11",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,14 @@
|
||||
type Meter(float)
|
||||
type Second(float)
|
||||
type MeterPerSecond(float)
|
||||
type Meter = float
|
||||
type Second = float
|
||||
type MeterPerSecond = float
|
||||
|
||||
extend Meter {
|
||||
op __add__(Meter) -> Meter
|
||||
op __sub__(Meter) -> Meter
|
||||
op __truediv__(Second) -> MeterPerSecond
|
||||
def __add__: fn(Meter, /) -> Meter
|
||||
def __sub__: fn(Meter, /) -> Meter
|
||||
def __truediv__: fn(Second, /) -> MeterPerSecond
|
||||
}
|
||||
|
||||
extend Second {
|
||||
op __add__(Second) -> Second
|
||||
op __sub__(Second) -> Second
|
||||
def __add__: fn(Second, /) -> Second
|
||||
def __sub__: fn(Second, /) -> Second
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# type: ignore
|
||||
# ruff: disable [F821]
|
||||
|
||||
midas.using("04_custom_types.midas")
|
||||
|
||||
distance: Meter = cast(Meter, 123.45)
|
||||
time: Second = cast(Second, 6.7)
|
||||
speed = distance / time
|
||||
|
||||
@@ -1,3 +1,109 @@
|
||||
{
|
||||
"diagnostics": []
|
||||
"diagnostics": [],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L4:18",
|
||||
"to": "L4:37"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CastExpr",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Meter",
|
||||
"param": null
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 123.45
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:15",
|
||||
"to": "L5:32"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CastExpr",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Second",
|
||||
"param": null
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 6.7
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:8",
|
||||
"to": "L6:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "distance"
|
||||
},
|
||||
"type": {
|
||||
"name": "Meter",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:19",
|
||||
"to": "L6:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "time"
|
||||
},
|
||||
"type": {
|
||||
"name": "Second",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:8",
|
||||
"to": "L6:23"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "distance"
|
||||
},
|
||||
"operator": "/",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "time"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "MeterPerSecond",
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -42,5 +42,409 @@
|
||||
},
|
||||
"message": "Mixed return types: [BaseType(name='int'), BaseType(name='str')]"
|
||||
}
|
||||
],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:11",
|
||||
"to": "L2:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:15",
|
||||
"to": "L2:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:11",
|
||||
"to": "L2:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:7",
|
||||
"to": "L5:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:11",
|
||||
"to": "L5:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L5:7",
|
||||
"to": "L5:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CompareExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "<",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:15",
|
||||
"to": "L6:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:19",
|
||||
"to": "L6:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:15",
|
||||
"to": "L6:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:15",
|
||||
"to": "L8:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:19",
|
||||
"to": "L8:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:15",
|
||||
"to": "L8:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:7",
|
||||
"to": "L15:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:11",
|
||||
"to": "L15:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L15:7",
|
||||
"to": "L15:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CompareExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": ">",
|
||||
"right": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:15",
|
||||
"to": "L16:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:19",
|
||||
"to": "L16:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L16:15",
|
||||
"to": "L16:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 10
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L18:15",
|
||||
"to": "L18:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:7",
|
||||
"to": "L22:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:11",
|
||||
"to": "L22:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L22:7",
|
||||
"to": "L22:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CompareExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"operator": "<",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:15",
|
||||
"to": "L23:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:19",
|
||||
"to": "L23:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L23:15",
|
||||
"to": "L23:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"operator": "-",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L25:15",
|
||||
"to": "L25:21"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "oops"
|
||||
},
|
||||
"type": {
|
||||
"name": "str"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
12
tests/cases/checker/06_subtyping.py
Normal file
12
tests/cases/checker/06_subtyping.py
Normal file
@@ -0,0 +1,12 @@
|
||||
v1: int = 3
|
||||
v2: float = 4
|
||||
|
||||
|
||||
def maximum(a: float, b: float):
|
||||
if b > a:
|
||||
return b
|
||||
return a
|
||||
|
||||
|
||||
v3 = maximum(v1, v2)
|
||||
v3 = v2 + v1
|
||||
239
tests/cases/checker/06_subtyping.py.ref.json
Normal file
239
tests/cases/checker/06_subtyping.py.ref.json
Normal file
@@ -0,0 +1,239 @@
|
||||
{
|
||||
"diagnostics": [],
|
||||
"judgments": [
|
||||
{
|
||||
"location": {
|
||||
"from": "L1:10",
|
||||
"to": "L1:11"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 3
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L2:12",
|
||||
"to": "L2:13"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 4
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:7",
|
||||
"to": "L6:8"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:11",
|
||||
"to": "L6:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L6:7",
|
||||
"to": "L6:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CompareExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"operator": ">",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "bool"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L7:15",
|
||||
"to": "L7:16"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "b"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L8:11",
|
||||
"to": "L8:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "a"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"type": {
|
||||
"pos_args": [],
|
||||
"args": [
|
||||
{
|
||||
"pos": 0,
|
||||
"name": "a",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"pos": 1,
|
||||
"name": "b",
|
||||
"type": {
|
||||
"name": "float"
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"kw_args": [],
|
||||
"returns": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:13",
|
||||
"to": "L11:15"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:17",
|
||||
"to": "L11:19"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L11:5",
|
||||
"to": "L11:20"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "CallExpr",
|
||||
"callee": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "maximum"
|
||||
},
|
||||
"arguments": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
}
|
||||
],
|
||||
"keywords": {}
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
"to": "L12:7"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:10",
|
||||
"to": "L12:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
},
|
||||
"type": {
|
||||
"name": "int"
|
||||
}
|
||||
},
|
||||
{
|
||||
"location": {
|
||||
"from": "L12:5",
|
||||
"to": "L12:12"
|
||||
},
|
||||
"expr": {
|
||||
"_type": "BinaryExpr",
|
||||
"left": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v2"
|
||||
},
|
||||
"operator": "+",
|
||||
"right": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "v1"
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"name": "float"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
// Simple custom type derived from float
|
||||
type Custom(float)
|
||||
type Custom = float
|
||||
|
||||
// Simple custom types with constraints
|
||||
type Latitude(float) where (-90 <= _ <= 90)
|
||||
type Longitude(float) where (-180 <= _ <= 180)
|
||||
type Latitude = float where (-90 <= _ <= 90)
|
||||
type Longitude = float where (-180 <= _ <= 180)
|
||||
|
||||
// Generic custom type (a Difference of T is derived from T, e.g. a difference of floats is a float
|
||||
type Difference[T](T)
|
||||
type Difference[T] = T
|
||||
|
||||
// Complex custom type, containing two values accessible through properties
|
||||
type GeoLocation {
|
||||
lat: Latitude
|
||||
lon: Longitude
|
||||
type GeoLocation = {
|
||||
prop lat: Latitude
|
||||
prop lon: Longitude
|
||||
}
|
||||
|
||||
// Define operations on our custom type
|
||||
@@ -19,23 +19,23 @@ extend GeoLocation {
|
||||
// This type is compatible with the `-` operation with another GeoLocation
|
||||
// i.e. you can subtract a GeoLocation from another GeoLocation, resulting
|
||||
// in a Difference of GeoLocations
|
||||
op __sub__(GeoLocation) -> Difference[GeoLocation]
|
||||
def __sub__: fn(GeoLocation, /) -> Difference[GeoLocation]
|
||||
}
|
||||
|
||||
// For complex generics, you need to specify how the genericity the properties
|
||||
// are handled
|
||||
type Difference[GeoLocation] {
|
||||
lat: Difference[Latitude]
|
||||
lon: Difference[Longitude]
|
||||
type Difference[GeoLocation] = {
|
||||
prop lat: Difference[Latitude]
|
||||
prop lon: Difference[Longitude]
|
||||
}
|
||||
|
||||
// Simple operation defined on our custom types
|
||||
extend Latitude {
|
||||
op __sub__(Latitude) -> Difference[Latitude]
|
||||
def __sub__: fn(Latitude, /) -> Difference[Latitude]
|
||||
}
|
||||
|
||||
extend Longitude {
|
||||
op __sub__(Longitude) -> Difference[Longitude]
|
||||
def __sub__: fn(Longitude, /) -> Difference[Longitude]
|
||||
}
|
||||
|
||||
// Predefined custom predicates that can be referenced in other definitions
|
||||
@@ -44,14 +44,14 @@ predicate StrictlyPositive(v: float) = v > 0
|
||||
predicate Equatorial(loc: GeoLocation) = (-10 <= loc.lat <= 10)
|
||||
predicate Arctic(loc: GeoLocation) = (loc.lat >= 66)
|
||||
|
||||
type Person {
|
||||
name: str
|
||||
type Person = {
|
||||
prop name: str
|
||||
|
||||
// Property with an inline constraint
|
||||
age: int? where (0 <= _ < 150)
|
||||
prop age: Optional[int where (0 <= _ < 150)]
|
||||
|
||||
// Property referencing a predicate
|
||||
height: float where StrictlyPositive
|
||||
prop height: float where StrictlyPositive
|
||||
|
||||
home: GeoLocation
|
||||
prop home: GeoLocation
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,6 @@
|
||||
# ruff: disable[F821]
|
||||
from __future__ import annotations
|
||||
|
||||
import midas
|
||||
|
||||
midas.using("02_custom_types.midas")
|
||||
|
||||
df: Frame[
|
||||
location: GeoLocation
|
||||
]
|
||||
|
||||
@@ -1,26 +1,5 @@
|
||||
{
|
||||
"stmts": [
|
||||
{
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": {
|
||||
"_type": "CallExpr",
|
||||
"callee": {
|
||||
"_type": "GetExpr",
|
||||
"object": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "midas"
|
||||
},
|
||||
"name": "using"
|
||||
},
|
||||
"arguments": [
|
||||
{
|
||||
"_type": "LiteralExpr",
|
||||
"value": "02_custom_types.midas"
|
||||
}
|
||||
],
|
||||
"keywords": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "df",
|
||||
@@ -39,6 +18,80 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "GetExpr",
|
||||
"object": {
|
||||
"_type": "SubscriptExpr",
|
||||
"object": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "df"
|
||||
},
|
||||
"index": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "location"
|
||||
}
|
||||
},
|
||||
"name": "lat"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lon",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Column",
|
||||
"param": {
|
||||
"_type": "BaseType",
|
||||
"base": "GeoLocation",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lon"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "GetExpr",
|
||||
"object": {
|
||||
"_type": "SubscriptExpr",
|
||||
"object": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "df"
|
||||
},
|
||||
"index": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": "location"
|
||||
}
|
||||
},
|
||||
"name": "lon"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "ExpressionStmt",
|
||||
"expr": {
|
||||
@@ -54,6 +107,64 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat1",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat1"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "SubscriptExpr",
|
||||
"object": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat"
|
||||
},
|
||||
"index": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat2",
|
||||
"type": {
|
||||
"_type": "BaseType",
|
||||
"base": "Latitude",
|
||||
"param": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "AssignStmt",
|
||||
"targets": [
|
||||
{
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat2"
|
||||
}
|
||||
],
|
||||
"value": {
|
||||
"_type": "SubscriptExpr",
|
||||
"object": {
|
||||
"_type": "VariableExpr",
|
||||
"name": "lat"
|
||||
},
|
||||
"index": {
|
||||
"_type": "LiteralExpr",
|
||||
"value": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_type": "TypeAssign",
|
||||
"name": "lat_diff",
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.checker import Checker
|
||||
from midas.checker.checker import TypeChecker
|
||||
from midas.checker.diagnostic import Diagnostic
|
||||
from midas.parser.python import PythonParser
|
||||
from midas.resolver.resolver import Resolver
|
||||
from midas.checker.types import Type
|
||||
from tests.base import Tester
|
||||
from tests.serializer.python import PythonAstJsonSerializer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaseResult:
|
||||
diagnostics: list[dict] = field(default_factory=list)
|
||||
judgments: list = field(default_factory=list)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(asdict(self), indent=2)
|
||||
@@ -33,15 +33,16 @@ class CheckerTester(Tester):
|
||||
if not path.is_file():
|
||||
raise TypeError(f"Test '{path}' is not a file")
|
||||
|
||||
source: str = path.read_text()
|
||||
tree: ast.Module = ast.parse(source, filename=path)
|
||||
parser = PythonParser()
|
||||
stmts: list[p.Stmt] = parser.parse_module(tree)
|
||||
resolver = Resolver()
|
||||
resolver.resolve(*stmts)
|
||||
result: CaseResult = CaseResult()
|
||||
checker = Checker(resolver.locals, file_path=path)
|
||||
diagnostics: list[Diagnostic] = checker.check(stmts)
|
||||
|
||||
checker = TypeChecker()
|
||||
types_path: Path = path.with_suffix(".midas")
|
||||
if types_path.exists():
|
||||
checker.import_midas(types_path)
|
||||
|
||||
checker.type_check(path)
|
||||
|
||||
diagnostics: list[Diagnostic] = checker.diagnostics
|
||||
for diagnostic in diagnostics:
|
||||
result.diagnostics.append(
|
||||
{
|
||||
@@ -60,6 +61,21 @@ class CheckerTester(Tester):
|
||||
}
|
||||
)
|
||||
|
||||
judgements: list[tuple[p.Expr, Type]] = checker.python_typer.judgements
|
||||
serializer = PythonAstJsonSerializer()
|
||||
for expr, type in judgements:
|
||||
loc = expr.location
|
||||
result.judgments.append(
|
||||
{
|
||||
"location": {
|
||||
"from": f"L{loc.lineno}:{loc.col_offset}",
|
||||
"to": f"L{loc.end_lineno}:{loc.end_col_offset}",
|
||||
},
|
||||
"expr": expr.accept(serializer),
|
||||
"type": asdict(type),
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -2,79 +2,76 @@ from typing import Optional, Sequence
|
||||
|
||||
from midas.ast.midas import (
|
||||
BinaryExpr,
|
||||
ComplexTypeStmt,
|
||||
ComplexType,
|
||||
ConstraintType,
|
||||
Expr,
|
||||
ExtendStmt,
|
||||
ExtensionType,
|
||||
FunctionType,
|
||||
GenericType,
|
||||
GetExpr,
|
||||
GroupingExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
OpStmt,
|
||||
MemberStmt,
|
||||
NamedType,
|
||||
PredicateStmt,
|
||||
PropertyStmt,
|
||||
SimpleTypeExpr,
|
||||
SimpleTypeStmt,
|
||||
Stmt,
|
||||
TemplateExpr,
|
||||
TypeExpr,
|
||||
Type,
|
||||
TypeParam,
|
||||
TypeStmt,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
WildcardExpr,
|
||||
)
|
||||
|
||||
|
||||
class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
class MidasAstJsonSerializer(
|
||||
Stmt.Visitor[dict], Expr.Visitor[dict], Type.Visitor[dict]
|
||||
):
|
||||
"""An AST serializer which produces a JSON-compatible structure"""
|
||||
|
||||
def serialize(self, stmts: list[Stmt]) -> list[dict]:
|
||||
return [stmt.accept(self) for stmt in stmts]
|
||||
|
||||
def _serialize_optional(self, element: Optional[Stmt | Expr]) -> Optional[dict]:
|
||||
def _serialize_optional(
|
||||
self, element: Optional[Stmt | Expr | Type]
|
||||
) -> Optional[dict]:
|
||||
if element is None:
|
||||
return None
|
||||
return element.accept(self)
|
||||
|
||||
def _serialize_list(self, elements: Sequence[Stmt | Expr]) -> list[dict]:
|
||||
def _serialize_list(self, elements: Sequence[Stmt | Expr | Type]) -> list[dict]:
|
||||
return [element.accept(self) for element in elements]
|
||||
|
||||
def visit_simple_type_stmt(self, stmt: SimpleTypeStmt) -> dict:
|
||||
def visit_type_stmt(self, stmt: TypeStmt) -> dict:
|
||||
return {
|
||||
"_type": "SimpleTypeStmt",
|
||||
"_type": "TypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"template": self._serialize_optional(stmt.template),
|
||||
"base": stmt.base.accept(self),
|
||||
"constraint": self._serialize_optional(stmt.constraint),
|
||||
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||
"type": stmt.type.accept(self),
|
||||
}
|
||||
|
||||
def visit_complex_type_stmt(self, stmt: ComplexTypeStmt) -> dict:
|
||||
def _serialize_type_param(self, param: TypeParam) -> dict:
|
||||
return {
|
||||
"_type": "ComplexTypeStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"template": self._serialize_optional(stmt.template),
|
||||
"properties": self._serialize_list(stmt.properties),
|
||||
"name": param.name.lexeme,
|
||||
"bound": self._serialize_optional(param.bound),
|
||||
}
|
||||
|
||||
def visit_property_stmt(self, stmt: PropertyStmt) -> dict:
|
||||
def visit_member_stmt(self, stmt: MemberStmt) -> dict:
|
||||
return {
|
||||
"_type": "PropertyStmt",
|
||||
"_type": "MemberStmt",
|
||||
"kind": stmt.kind.name,
|
||||
"name": stmt.name.lexeme,
|
||||
"type": stmt.type.accept(self),
|
||||
"constraint": self._serialize_optional(stmt.constraint),
|
||||
}
|
||||
|
||||
def visit_extend_stmt(self, stmt: ExtendStmt) -> dict:
|
||||
return {
|
||||
"_type": "ExtendStmt",
|
||||
"type": stmt.type.accept(self),
|
||||
"operations": self._serialize_list(stmt.operations),
|
||||
}
|
||||
|
||||
def visit_op_stmt(self, stmt: OpStmt) -> dict:
|
||||
return {
|
||||
"_type": "OpStmt",
|
||||
"name": stmt.name.lexeme,
|
||||
"operand": stmt.operand.accept(self),
|
||||
"result": stmt.result.accept(self),
|
||||
"params": [self._serialize_type_param(param) for param in stmt.params],
|
||||
"members": self._serialize_list(stmt.members),
|
||||
}
|
||||
|
||||
def visit_predicate_stmt(self, stmt: PredicateStmt) -> dict:
|
||||
@@ -86,13 +83,6 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
"condition": stmt.condition.accept(self),
|
||||
}
|
||||
|
||||
def visit_simple_type_expr(self, expr: SimpleTypeExpr) -> dict:
|
||||
return {
|
||||
"_type": "SimpleTypeExpr",
|
||||
"name": expr.name.lexeme,
|
||||
"optional": expr.optional,
|
||||
}
|
||||
|
||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||
return {
|
||||
"_type": "LogicalExpr",
|
||||
@@ -144,16 +134,51 @@ class MidasAstJsonSerializer(Stmt.Visitor[dict], Expr.Visitor[dict]):
|
||||
def visit_wildcard_expr(self, expr: WildcardExpr) -> dict:
|
||||
return {"_type": "WildcardExpr"}
|
||||
|
||||
def visit_template_expr(self, expr: TemplateExpr) -> dict:
|
||||
def visit_named_type(self, type: NamedType) -> dict:
|
||||
return {
|
||||
"_type": "TemplateExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"_type": "NamedType",
|
||||
"name": type.name.lexeme,
|
||||
}
|
||||
|
||||
def visit_type_expr(self, expr: TypeExpr) -> dict:
|
||||
def visit_generic_type(self, type: GenericType) -> dict:
|
||||
return {
|
||||
"_type": "TypeExpr",
|
||||
"name": expr.name.lexeme,
|
||||
"template": self._serialize_optional(expr.template),
|
||||
"optional": expr.optional,
|
||||
"_type": "GenericType",
|
||||
"type": type.type.accept(self),
|
||||
"args": self._serialize_list(type.args),
|
||||
}
|
||||
|
||||
def visit_constraint_type(self, type: ConstraintType) -> dict:
|
||||
return {
|
||||
"_type": "ConstraintType",
|
||||
"type": type.type.accept(self),
|
||||
"constraint": type.constraint.accept(self),
|
||||
}
|
||||
|
||||
def visit_complex_type(self, type: ComplexType) -> dict:
|
||||
return {
|
||||
"_type": "ComplexType",
|
||||
"members": self._serialize_list(type.members),
|
||||
}
|
||||
|
||||
def visit_function_type(self, type: FunctionType) -> dict:
|
||||
return {
|
||||
"_type": "FunctionType",
|
||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
||||
"returns": type.returns.accept(self),
|
||||
}
|
||||
|
||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||
return {
|
||||
"name": arg.name,
|
||||
"type": arg.type.accept(self),
|
||||
"required": arg.required,
|
||||
}
|
||||
|
||||
def visit_extension_type(self, type: ExtensionType) -> dict:
|
||||
return {
|
||||
"_type": "ExtensionType",
|
||||
"base": type.base.accept(self),
|
||||
"extension": type.extension.accept(self),
|
||||
}
|
||||
|
||||
@@ -16,12 +16,15 @@ from midas.ast.python import (
|
||||
Function,
|
||||
GetExpr,
|
||||
IfStmt,
|
||||
ListExpr,
|
||||
LiteralExpr,
|
||||
LogicalExpr,
|
||||
MidasType,
|
||||
ReturnStmt,
|
||||
SetExpr,
|
||||
SliceExpr,
|
||||
Stmt,
|
||||
SubscriptExpr,
|
||||
TernaryExpr,
|
||||
TypeAssign,
|
||||
UnaryExpr,
|
||||
VariableExpr,
|
||||
@@ -231,17 +234,38 @@ class PythonAstJsonSerializer(
|
||||
"right": expr.right.accept(self),
|
||||
}
|
||||
|
||||
def visit_set_expr(self, expr: SetExpr) -> dict:
|
||||
return {
|
||||
"_type": "SetExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"name": expr.name,
|
||||
"value": expr.value.accept(self),
|
||||
}
|
||||
|
||||
def visit_cast_expr(self, expr: CastExpr) -> dict:
|
||||
return {
|
||||
"_type": "CastExpr",
|
||||
"type": expr.type.accept(self),
|
||||
"expr": expr.expr.accept(self),
|
||||
}
|
||||
|
||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||
return {
|
||||
"_type": "TernaryExpr",
|
||||
"test": expr.test.accept(self),
|
||||
"if_true": expr.if_true.accept(self),
|
||||
"if_false": expr.if_false.accept(self),
|
||||
}
|
||||
|
||||
def visit_list_expr(self, expr: ListExpr) -> dict:
|
||||
return {
|
||||
"_type": "ListExpr",
|
||||
"items": [item.accept(self) for item in expr.items],
|
||||
}
|
||||
|
||||
def visit_subscript_expr(self, expr: SubscriptExpr) -> dict:
|
||||
return {
|
||||
"_type": "SubscriptExpr",
|
||||
"object": expr.object.accept(self),
|
||||
"index": expr.index.accept(self),
|
||||
}
|
||||
|
||||
def visit_slice_expr(self, expr: SliceExpr) -> dict:
|
||||
return {
|
||||
"_type": "SliceExpr",
|
||||
"lower": self._serialize_optional(expr.lower),
|
||||
"upper": self._serialize_optional(expr.upper),
|
||||
"step": self._serialize_optional(expr.step),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user