Compare commits
117 Commits
11422d4364
...
feat/dataf
| Author | SHA1 | Date | |
|---|---|---|---|
|
16239db479
|
|||
|
dc2134c87d
|
|||
|
89f3c945e4
|
|||
|
45f7d1be2b
|
|||
|
27f3fa7d1e
|
|||
|
78eba39ae3
|
|||
|
3b78b37306
|
|||
|
9e14b30bc9
|
|||
|
a6a1075f91
|
|||
|
11be47fce3
|
|||
|
2eeede9826
|
|||
|
f796f4c6fa
|
|||
|
c333735580
|
|||
|
2416102494
|
|||
|
eb4971686a
|
|||
|
9f59366289
|
|||
|
fd0b410d74
|
|||
|
5b0c5c01ad
|
|||
|
43e40396a1
|
|||
|
0d265ef24c
|
|||
|
88c56c9d15
|
|||
|
d1c217a335
|
|||
|
5b3e87afcb
|
|||
|
894d5a7196
|
|||
|
eb809c6341
|
|||
|
bd68d1003f
|
|||
|
72c9236650
|
|||
|
90051c7981
|
|||
|
dd1e2e693c
|
|||
|
78e10e0895
|
|||
|
c81e4a9560
|
|||
|
6d0cf1a055
|
|||
|
cc5e7af143
|
|||
|
3bdbc80079
|
|||
|
c1b5284f72
|
|||
|
5e9ccd4e13
|
|||
|
cf083fc0c3
|
|||
|
a80da5db2c
|
|||
| f7c43837b5 | |||
|
32ed62a6f1
|
|||
|
66f39acec0
|
|||
|
6c04e2fee4
|
|||
| 2bb2e0a684 | |||
|
5630320d21
|
|||
|
9f05ba3224
|
|||
|
5fbe965919
|
|||
| 252a5abdfd | |||
|
55fba6a088
|
|||
|
70ce263ea2
|
|||
|
e1d5eac8b8
|
|||
|
82666a4918
|
|||
|
45f84a2f23
|
|||
|
dedfcb4dbb
|
|||
|
d9ea6365ea
|
|||
| 9c7a93412c | |||
|
d6b8fbfb60
|
|||
|
b290c59ac4
|
|||
|
093f2bc477
|
|||
|
7c771c4070
|
|||
|
a50a207385
|
|||
|
7e5ea5e414
|
|||
|
0ba0266bae
|
|||
|
216c80f08c
|
|||
|
f75d7722a1
|
|||
|
2f29c47274
|
|||
|
80af2b9048
|
|||
|
577454ee7e
|
|||
|
878693383e
|
|||
|
0b91de75a8
|
|||
| 739871c101 | |||
|
4395e9339b
|
|||
|
29e601128d
|
|||
|
b591f5508f
|
|||
|
41d0c84bbe
|
|||
| cccf2f8f9f | |||
|
3f48c2138f
|
|||
|
e4ab27673d
|
|||
|
b02ecc6326
|
|||
|
9e83079910
|
|||
|
ec468dd982
|
|||
|
3edc25d778
|
|||
|
451e54b009
|
|||
|
0dc14f67aa
|
|||
|
ff79f25628
|
|||
| 12782dda1e | |||
|
48a20b4aa0
|
|||
|
9467187313
|
|||
|
cd8f14153d
|
|||
| 6eea0c02e0 | |||
|
3205e7b961
|
|||
|
0aba134290
|
|||
|
1f0bcab2ca
|
|||
|
db8d88ef35
|
|||
|
7695d50537
|
|||
|
8461d05fa6
|
|||
|
43d2118db7
|
|||
|
6a87b5396f
|
|||
|
e6a581ba6e
|
|||
|
2a7aac69ed
|
|||
|
eb5bf19c61
|
|||
|
657406ea01
|
|||
|
2974386110
|
|||
|
92ca6b6732
|
|||
|
6aacdb98b7
|
|||
|
1b100b6ceb
|
|||
|
6b4c7d27bc
|
|||
|
2523d638f7
|
|||
|
5fc7461e29
|
|||
|
c5154bde81
|
|||
|
d07e8ac0ca
|
|||
|
3380995082
|
|||
|
7efc44c496
|
|||
|
ca94443699
|
|||
|
c513a85cf2
|
|||
|
2a106c5d07
|
|||
| 9672dfd588 | |||
|
7639ccc94d
|
99
README.md
99
README.md
@@ -1,4 +1,4 @@
|
|||||||
# Midas
|
<h1>Midas</h1>
|
||||||
|
|
||||||
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
|
*Midas* is a type system to _Maintain Integrity of Data with Annotated Structures_. In Greek mythology, [Midas](https://en.wikipedia.org/wiki/Midas) was a Phrygian king who was blessed with the gift of turning everything he touched into gold.
|
||||||
|
|
||||||
@@ -6,6 +6,25 @@
|
|||||||
|
|
||||||
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
This framework is being developed as part of a Bachelor's Thesis by Louis Heredero at HEI Sion.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Table of Contents</strong></summary>
|
||||||
|
|
||||||
|
- [Requirements](#requirements)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [Commands](#commands)
|
||||||
|
- [Type Checking](#type-checking)
|
||||||
|
- [Compiling](#compiling)
|
||||||
|
- [Formatting](#formatting)
|
||||||
|
- [Highlighting](#highlighting)
|
||||||
|
- [Dumping the AST](#dumping-the-ast)
|
||||||
|
- [Dumping the Registry](#dumping-the-registry)
|
||||||
|
- [Generating Stubs](#generating-stubs)
|
||||||
|
- [Showing Type Judgements](#showing-type-judgements)
|
||||||
|
- [Validating Definitions](#validating-definitions)
|
||||||
|
- [Tests](#tests)
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- Python 3.11+
|
- Python 3.11+
|
||||||
@@ -32,10 +51,26 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede
|
|||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
### Compiling
|
<!--
|
||||||
|
check
|
||||||
|
compile
|
||||||
|
format
|
||||||
|
highlight
|
||||||
|
parse
|
||||||
|
dump_registry
|
||||||
|
types
|
||||||
|
validate
|
||||||
|
-->
|
||||||
|
|
||||||
> [!NOTE]
|
### Type Checking
|
||||||
> In the current state of the project, the `compile` command doesn't generate any runnable code, it only runs the parsers and type checker on the provided files
|
|
||||||
|
```shell
|
||||||
|
midas check -t types.midas source.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This command parses the given files and run the type checkers against the Midas definitions and Python program. Diagnostics are then printed showing warnings and errors.
|
||||||
|
|
||||||
|
### Compiling
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
midas compile -t types.midas source.py
|
midas compile -t types.midas source.py
|
||||||
@@ -43,14 +78,22 @@ midas compile -t types.midas source.py
|
|||||||
|
|
||||||
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
With the `compile` command, you can process a source Python file, with any number of custom type definition files (`-t FILE` option), and the type checker will verify the coherence of your program and generate the runnable code with valid syntax and runtime assertions.
|
||||||
|
|
||||||
The optional `-l FILE` option lets you produce a highlighted version of the source code showing diagnostics from the type checker (see [Highlighting](#highlighting))
|
### Formatting
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas format types.midas
|
||||||
|
midas format types.midas -o formatted.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
This command parses the given Midas file and outputs a pretty printed file from the AST.
|
||||||
|
|
||||||
### Highlighting
|
### Highlighting
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
midas utils highlight source.py
|
midas highlight source.py
|
||||||
# or
|
midas highlight source.py -o highlighted.html
|
||||||
midas utils highlight types.midas
|
midas highlight types.midas
|
||||||
|
midas highlight types.midas -o highlighted.html
|
||||||
```
|
```
|
||||||
|
|
||||||
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
The `highlight` command takes in a source file (Python or Midas), runs the appropriate parser and outputs an HTML file containing the source code with added highlighting. This highlighting takes the form of hoverable annotations showing some of the parsed structures (e.g. a function definition, an assignment, a generic type, etc.)
|
||||||
@@ -60,14 +103,43 @@ The optional `-o FILE` option can be used to specify an output path. By default,
|
|||||||
### Dumping the AST
|
### Dumping the AST
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
midas utils dump-ast source.py
|
midas parse source.py
|
||||||
# or
|
midas parse types.midas
|
||||||
midas utils dump-ast types.midas
|
|
||||||
```
|
```
|
||||||
|
|
||||||
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `-p` flags lets you toggle the custom AST parsing. Without `-p`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
For debugging purposes, you can output the AST parsed from a Python or Midas file. For Python files, the `--raw` flags lets you toggle the custom AST parsing. With `--raw`, the raw AST is returned, as produced by the builtin `ast` module. This flag has no effect on Midas files.
|
||||||
|
|
||||||
The optional `-o FILE` option can be used to specify an output path. By default, the file is printed in stdout (equivalent to `-o -`).
|
### Dumping the Registry
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas dump-registry -t types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
This command processes the given Midas definitions and dumps the contents of the types registry.
|
||||||
|
|
||||||
|
### Generating Stubs
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas stubs types.midas -o stubs.pyi
|
||||||
|
```
|
||||||
|
|
||||||
|
This command generate Python stubs from a Midas definition file
|
||||||
|
|
||||||
|
### Showing Type Judgements
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas types -t types.midas source.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This command type checks the given Python source file and logs all typing judgements made by the type checker.
|
||||||
|
|
||||||
|
### Validating Definitions
|
||||||
|
|
||||||
|
```shell
|
||||||
|
midas validate types.midas
|
||||||
|
```
|
||||||
|
|
||||||
|
This command lets you validate a Midas definition file by running the parser and type checker, verifying syntax and references.
|
||||||
|
|
||||||
## Tests
|
## Tests
|
||||||
|
|
||||||
@@ -77,6 +149,7 @@ Several snapshot tests are available to assert the good behaviour of the parsers
|
|||||||
uv run -m tests.midas run -a
|
uv run -m tests.midas run -a
|
||||||
uv run -m tests.python run -a
|
uv run -m tests.python run -a
|
||||||
uv run -m tests.checker run -a
|
uv run -m tests.checker run -a
|
||||||
|
uv run -m tests.generator run -a
|
||||||
```
|
```
|
||||||
|
|
||||||
**Available subcommands:**
|
**Available subcommands:**
|
||||||
|
|||||||
15
examples/02_demonstration/demo.midas
Normal file
15
examples/02_demonstration/demo.midas
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
|
||||||
|
predicate is_ratio = in_range(0, 1)
|
||||||
|
|
||||||
|
type Currency = float
|
||||||
|
type Price[T <: Currency] = T where _ >= 0
|
||||||
|
|
||||||
|
extend Price[T <: Currency] {
|
||||||
|
def __add__: fn(Price[T], /) -> Price[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
type EUR = Currency
|
||||||
|
type USD = Currency
|
||||||
|
type CHF = Currency
|
||||||
|
|
||||||
|
type Discount = float where is_ratio(_)
|
||||||
35
examples/02_demonstration/demo.py
Normal file
35
examples/02_demonstration/demo.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
|
||||||
|
|
||||||
|
from midas.typing import cast, unsafe_cast
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Currency)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
|
||||||
|
return cast(Price[T], (1.0 - discount) * amount)
|
||||||
|
|
||||||
|
|
||||||
|
a1 = cast(Price[EUR], 3.2)
|
||||||
|
a2 = cast(Price[USD], 10.4)
|
||||||
|
r1 = cast(Discount, 0.2)
|
||||||
|
|
||||||
|
print(apply_discount(a1, r1))
|
||||||
|
print(apply_discount(a2, r1))
|
||||||
|
|
||||||
|
a3 = a1 + a1
|
||||||
|
a4 = a1 + a2 # cannot add euros and dollars
|
||||||
|
a3 = a2 # cannot change variable type
|
||||||
|
|
||||||
|
dyn_price = float(input("Price (CHF): "))
|
||||||
|
dyn_discount = float(input("Discount (0.0-1.0): "))
|
||||||
|
discounted = apply_discount(
|
||||||
|
cast(Price[CHF], dyn_price),
|
||||||
|
cast(Discount, dyn_discount),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Discounted: CHF {discounted}")
|
||||||
|
|
||||||
|
large_data = [i * 10 for i in range(100)]
|
||||||
|
prices = unsafe_cast(list[Price[EUR]], large_data)
|
||||||
14
examples/02_demonstration/demo_stubs.pyi
Normal file
14
examples/02_demonstration/demo_stubs.pyi
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
class Currency(float): ...
|
||||||
|
|
||||||
|
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
|
||||||
|
|
||||||
|
class Price(Currency, Generic[_T0]):
|
||||||
|
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
|
||||||
|
|
||||||
|
class EUR(Currency): ...
|
||||||
|
class USD(Currency): ...
|
||||||
|
class CHF(Currency): ...
|
||||||
|
class Discount(float): ...
|
||||||
33
gen/midas.py
33
gen/midas.py
@@ -26,6 +26,14 @@ class MemberKind(Enum):
|
|||||||
METHOD = auto()
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamSpec:
|
||||||
|
l_paren: Token
|
||||||
|
pos: list[FunctionType.Argument]
|
||||||
|
mixed: list[FunctionType.Argument]
|
||||||
|
kw: list[FunctionType.Argument]
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|
||||||
|
|
||||||
@@ -50,9 +58,8 @@ class ExtendStmt:
|
|||||||
|
|
||||||
class PredicateStmt:
|
class PredicateStmt:
|
||||||
name: Token
|
name: Token
|
||||||
subject: Token
|
params: list[ParamSpec]
|
||||||
type: Type
|
body: Expr
|
||||||
condition: Expr
|
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
@@ -78,6 +85,12 @@ class UnaryExpr:
|
|||||||
right: Expr
|
right: Expr
|
||||||
|
|
||||||
|
|
||||||
|
class CallExpr:
|
||||||
|
callee: Expr
|
||||||
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
|
|
||||||
class GetExpr:
|
class GetExpr:
|
||||||
expr: Expr
|
expr: Expr
|
||||||
name: Token
|
name: Token
|
||||||
@@ -128,9 +141,7 @@ class ExtensionType:
|
|||||||
|
|
||||||
|
|
||||||
class FunctionType:
|
class FunctionType:
|
||||||
pos_args: list[Argument]
|
params: ParamSpec
|
||||||
args: list[Argument]
|
|
||||||
kw_args: list[Argument]
|
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -141,4 +152,14 @@ class FunctionType:
|
|||||||
required: bool
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FrameType:
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
###<
|
###<
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from midas.ast.location import Location
|
|||||||
###> MidasType | Type annotations | node
|
###> MidasType | Type annotations | node
|
||||||
class BaseType:
|
class BaseType:
|
||||||
base: str
|
base: str
|
||||||
param: Optional[MidasType]
|
args: tuple[MidasType, ...]
|
||||||
|
|
||||||
|
|
||||||
class ConstraintType:
|
class ConstraintType:
|
||||||
@@ -145,6 +145,7 @@ class LogicalExpr:
|
|||||||
class CastExpr:
|
class CastExpr:
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
unsafe: bool
|
||||||
|
|
||||||
|
|
||||||
class TernaryExpr:
|
class TernaryExpr:
|
||||||
@@ -173,6 +174,10 @@ class SliceExpr:
|
|||||||
step: Optional[Expr]
|
step: Optional[Expr]
|
||||||
|
|
||||||
|
|
||||||
|
class TupleExpr:
|
||||||
|
items: tuple[Expr, ...]
|
||||||
|
|
||||||
|
|
||||||
class RawExpr:
|
class RawExpr:
|
||||||
expr: ast.expr
|
expr: ast.expr
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,14 @@ class MemberKind(Enum):
|
|||||||
METHOD = auto()
|
METHOD = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ParamSpec:
|
||||||
|
l_paren: Token
|
||||||
|
pos: list[FunctionType.Argument]
|
||||||
|
mixed: list[FunctionType.Argument]
|
||||||
|
kw: list[FunctionType.Argument]
|
||||||
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Statements #
|
# Statements #
|
||||||
##############
|
##############
|
||||||
@@ -86,9 +94,8 @@ class ExtendStmt(Stmt):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PredicateStmt(Stmt):
|
class PredicateStmt(Stmt):
|
||||||
name: Token
|
name: Token
|
||||||
subject: Token
|
params: list[ParamSpec]
|
||||||
type: Type
|
body: Expr
|
||||||
condition: Expr
|
|
||||||
|
|
||||||
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
def accept(self, visitor: Stmt.Visitor[T]) -> T:
|
||||||
return visitor.visit_predicate_stmt(self)
|
return visitor.visit_predicate_stmt(self)
|
||||||
@@ -116,6 +123,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
def visit_unary_expr(self, expr: UnaryExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_call_expr(self, expr: CallExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
def visit_get_expr(self, expr: GetExpr) -> T: ...
|
||||||
|
|
||||||
@@ -161,6 +171,16 @@ class UnaryExpr(Expr):
|
|||||||
return visitor.visit_unary_expr(self)
|
return visitor.visit_unary_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CallExpr(Expr):
|
||||||
|
callee: Expr
|
||||||
|
arguments: list[Expr]
|
||||||
|
keywords: dict[str, Expr]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_call_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GetExpr(Expr):
|
class GetExpr(Expr):
|
||||||
expr: Expr
|
expr: Expr
|
||||||
@@ -233,6 +253,9 @@ class Type(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_function_type(self, type: FunctionType) -> T: ...
|
def visit_function_type(self, type: FunctionType) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_frame_type(self, type: FrameType) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NamedType(Type):
|
class NamedType(Type):
|
||||||
@@ -279,9 +302,7 @@ class ExtensionType(Type):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class FunctionType(Type):
|
class FunctionType(Type):
|
||||||
pos_args: list[Argument]
|
params: ParamSpec
|
||||||
args: list[Argument]
|
|
||||||
kw_args: list[Argument]
|
|
||||||
returns: Type
|
returns: Type
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -293,3 +314,17 @@ class FunctionType(Type):
|
|||||||
|
|
||||||
def accept(self, visitor: Type.Visitor[T]) -> T:
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
return visitor.visit_function_type(self)
|
return visitor.visit_function_type(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FrameType(Type):
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
location: Optional[Location] = None
|
||||||
|
name: Token
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def accept(self, visitor: Type.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_frame_type(self)
|
||||||
|
|||||||
@@ -150,13 +150,17 @@ class MidasAstPrinter(
|
|||||||
self._write_line("PredicateStmt")
|
self._write_line("PredicateStmt")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f'name: "{stmt.name.lexeme}"')
|
self._write_line(f'name: "{stmt.name.lexeme}"')
|
||||||
self._write_line(f'subject: "{stmt.subject.lexeme}"')
|
self._write_line("params")
|
||||||
self._write_line("type")
|
with self._child_level():
|
||||||
|
for i, spec in enumerate(stmt.params):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(stmt.params) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._visit_param_spec(spec)
|
||||||
|
|
||||||
|
self._write_line("body", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
stmt.type.accept(self)
|
stmt.body.accept(self)
|
||||||
self._write_line("condition", last=True)
|
|
||||||
with self._child_level(single=True):
|
|
||||||
stmt.condition.accept(self)
|
|
||||||
|
|
||||||
# Expressions
|
# Expressions
|
||||||
|
|
||||||
@@ -195,6 +199,29 @@ class MidasAstPrinter(
|
|||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.right.accept(self)
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||||
|
self._write_line("CallExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("callee")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
expr.callee.accept(self)
|
||||||
|
self._write_line("arguments")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(expr.arguments):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.arguments) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
arg.accept(self)
|
||||||
|
self._write_line("keywords", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, (name, arg) in enumerate(expr.keywords.items()):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.keywords) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._write_line(name)
|
||||||
|
with self._child_level(single=True):
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr):
|
def visit_get_expr(self, expr: m.GetExpr):
|
||||||
self._write_line("GetExpr")
|
self._write_line("GetExpr")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -276,34 +303,41 @@ class MidasAstPrinter(
|
|||||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
self._write_line("FunctionType")
|
self._write_line("FunctionType")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line("pos_args")
|
self._write_line("params")
|
||||||
with self._child_level():
|
with self._child_level(single=True):
|
||||||
for i, arg in enumerate(type.pos_args):
|
self._visit_param_spec(type.params)
|
||||||
self._idx = i
|
|
||||||
if i == len(type.pos_args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("args")
|
|
||||||
with self._child_level():
|
|
||||||
for i, arg in enumerate(type.args):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("kw_args")
|
|
||||||
with self._child_level():
|
|
||||||
for i, arg in enumerate(type.kw_args):
|
|
||||||
self._idx = i
|
|
||||||
if i == len(type.kw_args) - 1:
|
|
||||||
self._mark_last()
|
|
||||||
self._print_function_arg(arg)
|
|
||||||
|
|
||||||
self._write_line("returns", last=True)
|
self._write_line("returns", last=True)
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
type.returns.accept(self)
|
type.returns.accept(self)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||||
|
self._write_line("ParamSpec")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("pos")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.pos):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.pos) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("mixed")
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.mixed):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.mixed) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
|
self._write_line("kw", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(spec.kw):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(spec.kw) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_function_arg(arg)
|
||||||
|
|
||||||
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
def _print_function_arg(self, arg: m.FunctionType.Argument) -> None:
|
||||||
self._write_line("Argument")
|
self._write_line("Argument")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
@@ -316,6 +350,25 @@ class MidasAstPrinter(
|
|||||||
arg.type.accept(self)
|
arg.type.accept(self)
|
||||||
self._write_line(f"required: {arg.required}", last=True)
|
self._write_line(f"required: {arg.required}", last=True)
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||||
|
self._write_line("FrameType")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
self._write_line("columns")
|
||||||
|
with self._child_level():
|
||||||
|
for i, column in enumerate(type.columns):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(type.columns) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
self._print_frame_column(column)
|
||||||
|
|
||||||
|
def _print_frame_column(self, column: m.FrameType.Column) -> None:
|
||||||
|
self._write_line("Column")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line(f'name: "{column.name.lexeme}"')
|
||||||
|
self._write_line("type")
|
||||||
|
with self._child_level(single=True):
|
||||||
|
column.type.accept(self)
|
||||||
|
|
||||||
|
|
||||||
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
|
||||||
def __init__(self, indent: int = 4):
|
def __init__(self, indent: int = 4):
|
||||||
@@ -367,10 +420,9 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt):
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
subject: str = stmt.subject.lexeme
|
sig: str = "".join(self._visit_param_spec(spec) for spec in stmt.params)
|
||||||
type: str = stmt.type.accept(self)
|
body: str = stmt.body.accept(self)
|
||||||
condition: str = stmt.condition.accept(self)
|
return self.indented(f"predicate {name}{sig} = {body}")
|
||||||
return self.indented(f"predicate {name}({subject}: {type}) = {condition}")
|
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr):
|
def visit_logical_expr(self, expr: m.LogicalExpr):
|
||||||
left: str = expr.left.accept(self)
|
left: str = expr.left.accept(self)
|
||||||
@@ -389,6 +441,12 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
right: str = expr.right.accept(self)
|
right: str = expr.right.accept(self)
|
||||||
return f"{operator}{right}"
|
return f"{operator}{right}"
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> str:
|
||||||
|
args: list[str] = [arg.accept(self) for arg in expr.arguments] + [
|
||||||
|
f"{name}={arg.accept(self)}" for name, arg in expr.keywords.items()
|
||||||
|
]
|
||||||
|
return f"{expr.callee.accept(self)}({', '.join(args)})"
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr):
|
def visit_get_expr(self, expr: m.GetExpr):
|
||||||
expr_: str = expr.expr.accept(self)
|
expr_: str = expr.expr.accept(self)
|
||||||
name: str = expr.name.lexeme
|
name: str = expr.name.lexeme
|
||||||
@@ -436,9 +494,13 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
return f"{type.base.accept(self)} & {type.extension.accept(self)}"
|
||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> str:
|
def visit_function_type(self, type: m.FunctionType) -> str:
|
||||||
pos_args: list[str] = [self._print_arg(arg) for arg in type.pos_args]
|
spec: str = self._visit_param_spec(type.params)
|
||||||
mixed_args: list[str] = [self._print_arg(arg) for arg in type.args]
|
return f"fn {spec} -> {type.returns.accept(self)}"
|
||||||
kw_args: list[str] = [self._print_arg(arg) for arg in type.kw_args]
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> str:
|
||||||
|
pos_args: list[str] = [self._print_arg(arg) for arg in spec.pos]
|
||||||
|
mixed_args: list[str] = [self._print_arg(arg) for arg in spec.mixed]
|
||||||
|
kw_args: list[str] = [self._print_arg(arg) for arg in spec.kw]
|
||||||
args: list[str] = pos_args
|
args: list[str] = pos_args
|
||||||
|
|
||||||
if len(pos_args) != 0:
|
if len(pos_args) != 0:
|
||||||
@@ -447,8 +509,7 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
if len(kw_args) != 0:
|
if len(kw_args) != 0:
|
||||||
args.append("*")
|
args.append("*")
|
||||||
args += kw_args
|
args += kw_args
|
||||||
|
return f"({', '.join(args)})"
|
||||||
return f"fn ({', '.join(args)}) -> {type.returns.accept(self)}"
|
|
||||||
|
|
||||||
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
def _print_arg(self, arg: m.FunctionType.Argument) -> str:
|
||||||
res: str = ""
|
res: str = ""
|
||||||
@@ -460,6 +521,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
|
|||||||
res += "?"
|
res += "?"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> str:
|
||||||
|
res: str = self.indented("Frame[")
|
||||||
|
if len(type.columns) != 0:
|
||||||
|
res += "\n"
|
||||||
|
self.level += 1
|
||||||
|
columns: list[str] = []
|
||||||
|
for column in type.columns:
|
||||||
|
columns.append(self.indented(self._print_frame_column(column)))
|
||||||
|
res += ",\n".join(columns)
|
||||||
|
self.level -= 1
|
||||||
|
res += "\n"
|
||||||
|
res += "]"
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _print_frame_column(self, column: m.FrameType.Column) -> str:
|
||||||
|
return f"{column.name.lexeme}: {column.type.accept(self)}"
|
||||||
|
|
||||||
|
|
||||||
class PythonAstPrinter(
|
class PythonAstPrinter(
|
||||||
AstPrinter,
|
AstPrinter,
|
||||||
@@ -471,7 +549,13 @@ class PythonAstPrinter(
|
|||||||
self._write_line("BaseType")
|
self._write_line("BaseType")
|
||||||
with self._child_level():
|
with self._child_level():
|
||||||
self._write_line(f"base: {node.base}")
|
self._write_line(f"base: {node.base}")
|
||||||
self._write_optional_child("param", node.param, last=True)
|
self._write_line("args:", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, arg in enumerate(node.args):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(node.args) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||||
self._write_line("ConstraintType")
|
self._write_line("ConstraintType")
|
||||||
@@ -715,9 +799,10 @@ class PythonAstPrinter(
|
|||||||
self._write_line("type")
|
self._write_line("type")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.type.accept(self)
|
expr.type.accept(self)
|
||||||
self._write_line("expr", last=True)
|
self._write_line("expr")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
expr.expr.accept(self)
|
expr.expr.accept(self)
|
||||||
|
self._write_line(f"unsafe: {expr.unsafe}", last=True)
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
|
||||||
self._write_line("TernaryExpr")
|
self._write_line("TernaryExpr")
|
||||||
@@ -783,6 +868,17 @@ class PythonAstPrinter(
|
|||||||
self._write_optional_child("upper", expr.upper)
|
self._write_optional_child("upper", expr.upper)
|
||||||
self._write_optional_child("step", expr.step, last=True)
|
self._write_optional_child("step", expr.step, last=True)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||||
|
self._write_line("TupleExpr")
|
||||||
|
with self._child_level():
|
||||||
|
self._write_line("items", last=True)
|
||||||
|
with self._child_level():
|
||||||
|
for i, item in enumerate(expr.items):
|
||||||
|
self._idx = i
|
||||||
|
if i == len(expr.items) - 1:
|
||||||
|
self._mark_last()
|
||||||
|
item.accept(self)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||||
self._write_line("RawExpr")
|
self._write_line("RawExpr")
|
||||||
with self._child_level(single=True):
|
with self._child_level(single=True):
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class MidasType(ABC):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BaseType(MidasType):
|
class BaseType(MidasType):
|
||||||
base: str
|
base: str
|
||||||
param: Optional[MidasType]
|
args: tuple[MidasType, ...]
|
||||||
|
|
||||||
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
def accept(self, visitor: MidasType.Visitor[T]) -> T:
|
||||||
return visitor.visit_base_type(self)
|
return visitor.visit_base_type(self)
|
||||||
@@ -268,6 +268,9 @@ class Expr(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
def visit_slice_expr(self, expr: SliceExpr) -> T: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def visit_tuple_expr(self, expr: TupleExpr) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
def visit_raw_expr(self, expr: RawExpr) -> T: ...
|
||||||
|
|
||||||
@@ -350,6 +353,7 @@ class LogicalExpr(Expr):
|
|||||||
class CastExpr(Expr):
|
class CastExpr(Expr):
|
||||||
type: MidasType
|
type: MidasType
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
unsafe: bool
|
||||||
|
|
||||||
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
return visitor.visit_cast_expr(self)
|
return visitor.visit_cast_expr(self)
|
||||||
@@ -401,6 +405,14 @@ class SliceExpr(Expr):
|
|||||||
return visitor.visit_slice_expr(self)
|
return visitor.visit_slice_expr(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TupleExpr(Expr):
|
||||||
|
items: tuple[Expr, ...]
|
||||||
|
|
||||||
|
def accept(self, visitor: Expr.Visitor[T]) -> T:
|
||||||
|
return visitor.visit_tuple_expr(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RawExpr(Expr):
|
class RawExpr(Expr):
|
||||||
expr: ast.expr
|
expr: ast.expr
|
||||||
|
|||||||
@@ -179,3 +179,99 @@ extend dict[K, V] {
|
|||||||
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extend str {
|
||||||
|
def capitalize: fn() -> str
|
||||||
|
def casefold: fn() -> str
|
||||||
|
def center: fn(width: int, fillchar: str?, /) -> str
|
||||||
|
def count: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def count: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def count: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def count: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def encode: fn(encoding: str?, errors: str?) -> bytes
|
||||||
|
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
|
||||||
|
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
|
||||||
|
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
|
||||||
|
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
|
||||||
|
def expandtabs: fn(tabsize: int?) -> str
|
||||||
|
def find: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def find: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def find: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def find: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
// def format: fn(*args: object, **kwargs: object) -> str
|
||||||
|
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
|
||||||
|
def index: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def index: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def index: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def index: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def isalnum: fn() -> bool
|
||||||
|
def isalpha: fn() -> bool
|
||||||
|
def isascii: fn() -> bool
|
||||||
|
def isdecimal: fn() -> bool
|
||||||
|
def isdigit: fn() -> bool
|
||||||
|
def isidentifier: fn() -> bool
|
||||||
|
def islower: fn() -> bool
|
||||||
|
def isnumeric: fn() -> bool
|
||||||
|
def isprintable: fn() -> bool
|
||||||
|
def isspace: fn() -> bool
|
||||||
|
def istitle: fn() -> bool
|
||||||
|
def isupper: fn() -> bool
|
||||||
|
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
|
||||||
|
def ljust: fn(width: int, fillchar: str?, /) -> str
|
||||||
|
def lower: fn() -> str
|
||||||
|
def lstrip: fn(chars: None?, /) -> str
|
||||||
|
def lstrip: fn(chars: str, /) -> str
|
||||||
|
def partition: fn(sep: str, /) -> tuple[str, str, str]
|
||||||
|
|
||||||
|
def replace: fn(old: str, new: str, count: int?, /) -> str
|
||||||
|
|
||||||
|
def removeprefix: fn(prefix: str, /) -> str
|
||||||
|
def removesuffix: fn(suffix: str, /) -> str
|
||||||
|
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def rfind: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def rfind: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def rfind: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: int, end: None?, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: None, end: int, /) -> int
|
||||||
|
def rindex: fn(sub: str, start: int, end: int, /) -> int
|
||||||
|
def rjust: fn(width: int, fillchar: str?, /) -> str
|
||||||
|
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
|
||||||
|
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||||
|
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
|
||||||
|
def rstrip: fn(chars: None?, /) -> str
|
||||||
|
def rstrip: fn(chars: str, /) -> str
|
||||||
|
def split: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||||
|
def split: fn(sep: str, maxsplit: int?) -> list[str]
|
||||||
|
def splitlines: fn(keepends: bool?) -> list[str]
|
||||||
|
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
|
||||||
|
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
|
||||||
|
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
|
||||||
|
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
|
||||||
|
def strip: fn(chars: None?, /) -> str
|
||||||
|
def strip: fn(chars: str, /) -> str
|
||||||
|
def swapcase: fn() -> str
|
||||||
|
def title: fn() -> str
|
||||||
|
// def translate: fn(table: _TranslateTable, /) -> str
|
||||||
|
def upper: fn() -> str
|
||||||
|
def zfill: fn(width: int, /) -> str
|
||||||
|
def __add__: fn(value: str, /) -> str
|
||||||
|
// Incompatible with Sequence.__contains__
|
||||||
|
def __contains__: fn(key: str, /) -> bool
|
||||||
|
def __eq__: fn(value: object, /) -> bool
|
||||||
|
def __ge__: fn(value: str, /) -> bool
|
||||||
|
def __getitem__: fn(key: slice, /) -> str
|
||||||
|
def __getitem__: fn(key: int, /) -> str
|
||||||
|
def __gt__: fn(value: str, /) -> bool
|
||||||
|
def __hash__: fn() -> int
|
||||||
|
// def __iter__: fn() -> Iterator[str]
|
||||||
|
def __le__: fn(value: str, /) -> bool
|
||||||
|
def __len__: fn() -> int
|
||||||
|
def __lt__: fn(value: str, /) -> bool
|
||||||
|
def __mod__: fn(value: Any, /) -> str
|
||||||
|
def __mul__: fn(value: int, /) -> str
|
||||||
|
def __ne__: fn(value: object, /) -> bool
|
||||||
|
def __rmul__: fn(value: int, /) -> str
|
||||||
|
def __getnewargs__: fn() -> tuple[str]
|
||||||
|
def __format__: fn(format_spec: str, /) -> str
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||||
|
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||||
"float": {"int"},
|
"float": {"int"},
|
||||||
"int": {"bool"},
|
"int": {"bool"},
|
||||||
}
|
}
|
||||||
@@ -25,12 +26,15 @@ def define_builtins(reg: TypesRegistry):
|
|||||||
any = reg.define_type("Any", TopType())
|
any = reg.define_type("Any", TopType())
|
||||||
unit = reg.define_type("None", UnitType())
|
unit = reg.define_type("None", UnitType())
|
||||||
object = reg.define_type("object", BaseType(name="object"))
|
object = reg.define_type("object", BaseType(name="object"))
|
||||||
|
bytes = reg.define_type("bytes", BaseType(name="bytes"))
|
||||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||||
int = reg.define_type("int", BaseType(name="int"))
|
int = reg.define_type("int", BaseType(name="int"))
|
||||||
float = reg.define_type("float", BaseType(name="float"))
|
float = reg.define_type("float", BaseType(name="float"))
|
||||||
str = reg.define_type("str", BaseType(name="str"))
|
str = reg.define_type("str", BaseType(name="str"))
|
||||||
slice = reg.define_type("slice", BaseType(name="slice"))
|
slice = reg.define_type("slice", BaseType(name="slice"))
|
||||||
|
|
||||||
|
tuple = reg.define_type("tuple", BaseType(name="tuple"))
|
||||||
|
|
||||||
list = reg.define_type(
|
list = reg.define_type(
|
||||||
"list",
|
"list",
|
||||||
GenericType(
|
GenericType(
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ class DiagnosticType(StrEnum):
|
|||||||
ERROR = "Error"
|
ERROR = "Error"
|
||||||
WARNING = "Warning"
|
WARNING = "Warning"
|
||||||
INFO = "Info"
|
INFO = "Info"
|
||||||
|
DEBUG = "Debug"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
172
midas/checker/evaluator.py
Normal file
172
midas/checker/evaluator.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.preamble import Preamble
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import Function, Predicate
|
||||||
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class PartialPredicate(Predicate):
|
||||||
|
scope: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class Evaluator(m.Expr.Visitor[Any]):
|
||||||
|
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.reporter: Optional[FileReporter] = reporter
|
||||||
|
self.preamble: Preamble = Preamble(self.types)
|
||||||
|
self.scopes: list[dict[str, Any]] = [{}]
|
||||||
|
|
||||||
|
def evaluate(self, expr: m.Expr) -> Any:
|
||||||
|
value: Any = expr.accept(self)
|
||||||
|
if self.reporter is not None:
|
||||||
|
self.reporter.debug(expr.location, f"Value: {value}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get_value(self, name: str) -> Any:
|
||||||
|
scope: dict[str, Any] = self.scopes[-1]
|
||||||
|
return scope[name]
|
||||||
|
|
||||||
|
def set_value(self, name: str, value: Any, force_declare: bool = False):
|
||||||
|
if not force_declare:
|
||||||
|
for scope in reversed(self.scopes):
|
||||||
|
if name in scope:
|
||||||
|
scope[name] = value
|
||||||
|
return
|
||||||
|
self.scopes[-1][name] = value
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
|
||||||
|
def left():
|
||||||
|
return self.evaluate(expr.left)
|
||||||
|
|
||||||
|
def right():
|
||||||
|
return self.evaluate(expr.right)
|
||||||
|
|
||||||
|
match expr.operator.type:
|
||||||
|
case TokenType.AND:
|
||||||
|
return left() and right()
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
|
||||||
|
left: Any = self.evaluate(expr.left)
|
||||||
|
right: Any = self.evaluate(expr.right)
|
||||||
|
match expr.operator.type:
|
||||||
|
case TokenType.MINUS:
|
||||||
|
return left - right
|
||||||
|
case TokenType.STAR:
|
||||||
|
return left * right
|
||||||
|
case TokenType.SLASH:
|
||||||
|
return left / right
|
||||||
|
case TokenType.GREATER:
|
||||||
|
return left > right
|
||||||
|
case TokenType.GREATER_EQUAL:
|
||||||
|
return left >= right
|
||||||
|
case TokenType.LESS:
|
||||||
|
return left < right
|
||||||
|
case TokenType.LESS_EQUAL:
|
||||||
|
return left <= right
|
||||||
|
case TokenType.EQUAL_EQUAL:
|
||||||
|
return left == right
|
||||||
|
case TokenType.BANG_EQUAL:
|
||||||
|
return left != right
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
|
||||||
|
right: Any = self.evaluate(expr.right)
|
||||||
|
match expr.operator.type:
|
||||||
|
case TokenType.MINUS:
|
||||||
|
return -right
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> Any:
|
||||||
|
callee: Any = self.evaluate(expr.callee)
|
||||||
|
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
name: self.evaluate(arg) for name, arg in expr.keywords.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
match callee:
|
||||||
|
case Predicate():
|
||||||
|
return self._evaluate_predicate(callee, args, kwargs)
|
||||||
|
case _ if callable(callee):
|
||||||
|
return callee(*args, **kwargs)
|
||||||
|
case _:
|
||||||
|
return NotImplementedError
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> Any:
|
||||||
|
obj: Any = self.evaluate(expr.expr)
|
||||||
|
return getattr(obj, expr.name.lexeme)
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
|
||||||
|
name: str = expr.name.lexeme
|
||||||
|
for scope in reversed(self.scopes):
|
||||||
|
if name in scope:
|
||||||
|
return scope[name]
|
||||||
|
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is not None:
|
||||||
|
if predicate.alias:
|
||||||
|
return self.evaluate(predicate.body)
|
||||||
|
return predicate
|
||||||
|
|
||||||
|
glob: Optional[Callable] = self.preamble.get_py_func(name)
|
||||||
|
if glob is not None:
|
||||||
|
return glob
|
||||||
|
raise NameError(f"Unknown variable '{name}'")
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
|
||||||
|
return self.evaluate(expr.expr)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
|
||||||
|
return expr.value
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
|
||||||
|
return self.get_value("_")
|
||||||
|
|
||||||
|
def _evaluate_predicate(
|
||||||
|
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
|
res: Any = None
|
||||||
|
if isinstance(predicate, PartialPredicate):
|
||||||
|
self.scopes.append(predicate.scope)
|
||||||
|
else:
|
||||||
|
self.scopes.append({})
|
||||||
|
match predicate.type:
|
||||||
|
case Function(returns=Function() as inner):
|
||||||
|
self._map_args(predicate.type, args, kwargs)
|
||||||
|
res = PartialPredicate(
|
||||||
|
type=inner,
|
||||||
|
body=predicate.body,
|
||||||
|
alias=False,
|
||||||
|
scope=self.scopes[-1],
|
||||||
|
)
|
||||||
|
|
||||||
|
case Function():
|
||||||
|
self._map_args(predicate.type, args, kwargs)
|
||||||
|
res = self.evaluate(predicate.body)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
self.scopes.pop()
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
|
||||||
|
positional: list[Function.Argument] = function.pos_args + function.args
|
||||||
|
keywords: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.args + function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
param: Function.Argument = positional[i]
|
||||||
|
self.set_value(param.name, arg)
|
||||||
|
|
||||||
|
for name, arg in kwargs.items():
|
||||||
|
param: Function.Argument = keywords[name]
|
||||||
|
self.set_value(param.name, arg)
|
||||||
198
midas/checker/frame_methods.py
Normal file
198
midas/checker/frame_methods.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||||
|
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import (
|
||||||
|
ColumnType,
|
||||||
|
DataFrameType,
|
||||||
|
Function,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def frame_method(*names: str):
|
||||||
|
def wrapper(func):
|
||||||
|
names_: tuple[str, ...] = names
|
||||||
|
if len(names_) == 0:
|
||||||
|
names_ = (func.__name__,)
|
||||||
|
setattr(func, "__method_names__", names_)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
frame: DataFrameType
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
|
||||||
|
class _MethodRegistryMeta(type):
|
||||||
|
_methods: dict[str, Callable[..., Type]] = {}
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
bases: tuple[type, ...],
|
||||||
|
namespace: dict[str, Any],
|
||||||
|
):
|
||||||
|
new_class = super().__new__(cls, name, bases, namespace)
|
||||||
|
new_class._methods = {}
|
||||||
|
for attr in namespace.values():
|
||||||
|
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||||
|
for name in attr.__method_names__: # type: ignore
|
||||||
|
new_class._methods[name] = attr # type: ignore
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
|
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||||
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
|
self.typer: PythonTyper = typer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reporter(self) -> FileReporter:
|
||||||
|
return self.typer.reporter
|
||||||
|
|
||||||
|
@property
|
||||||
|
def types(self) -> TypesRegistry:
|
||||||
|
return self.typer.types
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
call: Call,
|
||||||
|
) -> Type:
|
||||||
|
func: Optional[Callable[..., Type]] = self._methods.get(method)
|
||||||
|
if func is None:
|
||||||
|
self.reporter.warning(call.location, f"Unknown method {method}")
|
||||||
|
return UnknownType()
|
||||||
|
return func(self, call)
|
||||||
|
|
||||||
|
@frame_method("add", "__add__")
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
call: Call,
|
||||||
|
) -> Type:
|
||||||
|
# TODO: support add with scalar, sequence, Series, dict
|
||||||
|
# TODO: check operation exists on inner column types
|
||||||
|
|
||||||
|
new_columns: list[DataFrameType.Column] = []
|
||||||
|
|
||||||
|
by_name: dict[str, DataFrameType.Column] = {}
|
||||||
|
frame2: Optional[DataFrameType] = None
|
||||||
|
if len(call.positional) != 0:
|
||||||
|
other: Type = call.positional[0][1]
|
||||||
|
unfolded_other: Type = unfold_type(other)
|
||||||
|
if isinstance(unfolded_other, DataFrameType):
|
||||||
|
frame2 = unfolded_other
|
||||||
|
by_name = {
|
||||||
|
col.name: col for col in frame2.columns if col.name is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
in_frame1: set[str] = set()
|
||||||
|
for column in call.frame.columns:
|
||||||
|
if column.name is not None:
|
||||||
|
in_frame1.add(column.name)
|
||||||
|
|
||||||
|
col_type1: Type = column.type
|
||||||
|
col_type: Type = ColumnType(type=UnknownType())
|
||||||
|
if column.name in by_name:
|
||||||
|
column2 = by_name[column.name]
|
||||||
|
col_type2: Type = column2.type
|
||||||
|
if self.types.are_equivalent(col_type2, col_type1):
|
||||||
|
col_type = col_type1
|
||||||
|
|
||||||
|
new_column = DataFrameType.Column(
|
||||||
|
index=column.index,
|
||||||
|
name=column.name,
|
||||||
|
type=col_type,
|
||||||
|
)
|
||||||
|
new_columns.append(new_column)
|
||||||
|
|
||||||
|
if frame2 is not None:
|
||||||
|
for column in frame2.columns:
|
||||||
|
if column.name in in_frame1:
|
||||||
|
continue
|
||||||
|
new_columns.append(
|
||||||
|
DataFrameType.Column(
|
||||||
|
index=len(new_columns),
|
||||||
|
name=column.name,
|
||||||
|
type=ColumnType(type=UnknownType()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
signature = Function(
|
||||||
|
args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="other",
|
||||||
|
type=DataFrameType(columns=[]),
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
returns=DataFrameType(columns=new_columns),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.typer._get_call_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=signature,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
or UnknownType()
|
||||||
|
)
|
||||||
|
|
||||||
|
@frame_method()
|
||||||
|
def mean(self, call: Call) -> Type:
|
||||||
|
with_axis = Function(
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="axis",
|
||||||
|
type=self.types.get_type("int"),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
returns=ColumnType(type=TopType()),
|
||||||
|
)
|
||||||
|
without_axis = Function(
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="axis",
|
||||||
|
type=self.types.get_type("None"),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
returns=TopType(),
|
||||||
|
)
|
||||||
|
overload = OverloadedFunction(
|
||||||
|
overloads=[
|
||||||
|
with_axis,
|
||||||
|
without_axis,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
self.typer._get_call_result(
|
||||||
|
location=call.location,
|
||||||
|
callee=overload,
|
||||||
|
positional=call.positional,
|
||||||
|
keywords=call.keywords,
|
||||||
|
)
|
||||||
|
or UnknownType()
|
||||||
|
)
|
||||||
154
midas/checker/frames.py
Normal file
154
midas/checker/frames.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.ast.location import Location
|
||||||
|
from midas.checker.frame_methods import Call, MethodRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||||
|
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameManager:
|
||||||
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
|
self.typer: PythonTyper = typer
|
||||||
|
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
|
||||||
|
|
||||||
|
def assign(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
value_type: Type,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
return self.assign_column(reporter, location, frame, name, value_type)
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(idx, str) for idx in indices
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def assign_column(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
name: str,
|
||||||
|
type: Type,
|
||||||
|
) -> Type:
|
||||||
|
if not isinstance(type, ColumnType):
|
||||||
|
reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
|
||||||
|
)
|
||||||
|
return frame
|
||||||
|
return self._set_column(frame, name, type)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
reporter: FileReporter,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
index: p.Expr,
|
||||||
|
) -> Type:
|
||||||
|
match index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
return column
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
|
||||||
|
isinstance(index.value, str) for index in indices
|
||||||
|
):
|
||||||
|
names: list[str] = [cast(str, index.value) for index in indices]
|
||||||
|
columns: list[ColumnType] = []
|
||||||
|
for name in names:
|
||||||
|
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
reporter.error(location, f"Unknown column '{name}' on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
columns.append(column)
|
||||||
|
return TupleType(items=tuple(columns))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
reporter.error(location, f"Invalid index type {index} on {frame}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_column(
|
||||||
|
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||||
|
) -> DataFrameType:
|
||||||
|
new_columns: list[DataFrameType.Column] = []
|
||||||
|
index: int = len(frame.columns)
|
||||||
|
replace: bool = False
|
||||||
|
for i, col in enumerate(frame.columns):
|
||||||
|
if col.name == name:
|
||||||
|
index = i
|
||||||
|
replace = True
|
||||||
|
# TODO: check column type here to prevent changing it
|
||||||
|
new_columns.append(col)
|
||||||
|
|
||||||
|
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||||
|
index=index,
|
||||||
|
name=name,
|
||||||
|
type=column,
|
||||||
|
)
|
||||||
|
if replace:
|
||||||
|
new_columns[index] = new_col
|
||||||
|
else:
|
||||||
|
new_columns.append(new_col)
|
||||||
|
|
||||||
|
return DataFrameType(columns=new_columns)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||||
|
) -> DataFrameType:
|
||||||
|
for name, col in zip(names, columns):
|
||||||
|
frame = cls._set_column(frame, name, col)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||||
|
for col in frame.columns:
|
||||||
|
if col.name == name:
|
||||||
|
return col.type
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str]
|
||||||
|
) -> list[Optional[ColumnType]]:
|
||||||
|
return [cls._get_column(frame, name) for name in names]
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
location: Location,
|
||||||
|
frame: DataFrameType,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Type:
|
||||||
|
call: Call = Call(
|
||||||
|
location=location,
|
||||||
|
frame=frame,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
)
|
||||||
|
return self.method_resolver.call(method, call)
|
||||||
@@ -1,27 +1,67 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import midas.ast.midas as m
|
import midas.ast.midas as m
|
||||||
|
from midas.ast.location import Location
|
||||||
from midas.checker.builtins import define_builtins
|
from midas.checker.builtins import define_builtins
|
||||||
|
from midas.checker.environment import Environment
|
||||||
|
from midas.checker.operators import MIDAS_BINARY_METHODS, MIDAS_UNARY_METHODS
|
||||||
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
Predicate,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
unfold_type,
|
||||||
)
|
)
|
||||||
|
from midas.checker.variance import VarianceInferrer
|
||||||
from midas.lexer.midas import MidasLexer
|
from midas.lexer.midas import MidasLexer
|
||||||
from midas.lexer.token import Token
|
from midas.lexer.token import Token
|
||||||
from midas.parser.midas import MidasParser
|
from midas.parser.midas import MidasParser
|
||||||
|
|
||||||
|
|
||||||
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type]):
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TypedParamSpec:
|
||||||
|
pos: list[Function.Argument]
|
||||||
|
mixed: list[Function.Argument]
|
||||||
|
kw: list[Function.Argument]
|
||||||
|
|
||||||
|
|
||||||
|
TypedExpr = tuple[m.Expr, Type]
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class MappedArgument:
|
||||||
|
expr: m.Expr
|
||||||
|
type: Type
|
||||||
|
argument: Function.Argument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class OverloadCandidate:
|
||||||
|
function: Function
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
|
||||||
|
|
||||||
|
class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type]):
|
||||||
"""A resolver which evaluates Midas type definitions and build a registry"""
|
"""A resolver which evaluates Midas type definitions and build a registry"""
|
||||||
|
|
||||||
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
|
||||||
@@ -31,12 +71,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
self._local_variables: dict[str, TypeVar] = {}
|
self._local_variables: dict[str, TypeVar] = {}
|
||||||
|
|
||||||
|
self._predicate_params: dict[str, Type] = {}
|
||||||
|
|
||||||
self._current_name: Optional[str] = None
|
self._current_name: Optional[str] = None
|
||||||
|
|
||||||
define_builtins(self.types)
|
define_builtins(self.types)
|
||||||
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
builtins_path: Path = (Path(__file__).parent / "builtins.midas").resolve()
|
||||||
self.process(builtins_path.read_text(), str(builtins_path))
|
self.process(builtins_path.read_text(), str(builtins_path))
|
||||||
|
|
||||||
|
self._bool: Type = self.get_type("bool")
|
||||||
|
|
||||||
|
self._preamble: Environment = Preamble(self.types)
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]):
|
def process(self, source: str, path: Optional[str]):
|
||||||
self.reporter = self.reporter.for_file(path)
|
self.reporter = self.reporter.for_file(path)
|
||||||
lexer: MidasLexer = MidasLexer(source)
|
lexer: MidasLexer = MidasLexer(source)
|
||||||
@@ -47,6 +93,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self.reporter.error(error.token.get_location(), error.message)
|
self.reporter.error(error.token.get_location(), error.message)
|
||||||
self.resolve(stmts)
|
self.resolve(stmts)
|
||||||
|
|
||||||
|
def type_of(self, expr: m.Expr) -> Type:
|
||||||
|
type: Type = expr.accept(self)
|
||||||
|
return type
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
|
|
||||||
@@ -63,6 +113,19 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return self._local_variables[name]
|
return self._local_variables[name]
|
||||||
return self.types.get_type(name)
|
return self.types.get_type(name)
|
||||||
|
|
||||||
|
def get_variable(self, name: str) -> Type:
|
||||||
|
if name in self._predicate_params:
|
||||||
|
return self._predicate_params[name]
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is not None:
|
||||||
|
return predicate.type
|
||||||
|
|
||||||
|
global_: Optional[Type] = self._preamble.get(name)
|
||||||
|
if global_ is not None:
|
||||||
|
return global_
|
||||||
|
|
||||||
|
raise NameError(f"Unknown variable '{name}'")
|
||||||
|
|
||||||
def resolve(self, stmts: list[m.Stmt]):
|
def resolve(self, stmts: list[m.Stmt]):
|
||||||
"""Process a sequence of statements
|
"""Process a sequence of statements
|
||||||
|
|
||||||
@@ -72,6 +135,16 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
stmt.accept(self)
|
stmt.accept(self)
|
||||||
|
|
||||||
|
for name, type in self.types._types.items():
|
||||||
|
if isinstance(type, GenericType):
|
||||||
|
inferrer = VarianceInferrer(self.types)
|
||||||
|
self.types._types[name] = inferrer.infer(type)
|
||||||
|
|
||||||
|
def assert_bool(self, expr: m.Expr):
|
||||||
|
type: Type = self.type_of(expr)
|
||||||
|
if not self.types.is_subtype(type, self._bool):
|
||||||
|
self.reporter.error(expr.location, f"Must be a boolean but is {type}")
|
||||||
|
|
||||||
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
def visit_type_stmt(self, stmt: m.TypeStmt) -> None:
|
||||||
name: str = stmt.name.lexeme
|
name: str = stmt.name.lexeme
|
||||||
self._current_name = name
|
self._current_name = name
|
||||||
@@ -102,35 +175,167 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
base_name,
|
base_name,
|
||||||
member.name.lexeme,
|
member.name.lexeme,
|
||||||
member_type,
|
member_type,
|
||||||
member.kind == m.MemberKind.METHOD,
|
member.kind,
|
||||||
)
|
)
|
||||||
|
|
||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
self.reporter.warning(stmt.location, "PredicateStmt not yet supported")
|
for spec in stmt.params:
|
||||||
|
for param in spec.mixed:
|
||||||
|
assert param.name is not None
|
||||||
|
self._predicate_params[param.name.lexeme] = param.type.accept(self)
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
type: Type = self.type_of(stmt.body)
|
||||||
self.reporter.warning(expr.location, "LogicalExpr not yet supported")
|
params: list[TypedParamSpec] = [
|
||||||
|
self._visit_param_spec(spec) for spec in stmt.params
|
||||||
|
]
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: m.BinaryExpr) -> None:
|
if not self._is_valid_predicate(type):
|
||||||
self.reporter.warning(expr.location, "BinaryExpr not yet supported")
|
self.reporter.error(
|
||||||
|
stmt.body.location,
|
||||||
|
f"Predicate function body must evaluate to a boolean, got {type}",
|
||||||
|
)
|
||||||
|
if len(params) != 0:
|
||||||
|
type = self._bool
|
||||||
|
for spec in reversed(params):
|
||||||
|
type = Function(
|
||||||
|
pos_args=spec.pos,
|
||||||
|
args=spec.mixed,
|
||||||
|
kw_args=spec.kw,
|
||||||
|
returns=type,
|
||||||
|
)
|
||||||
|
self._predicate_params = {}
|
||||||
|
self.types.define_predicate(
|
||||||
|
stmt.name.lexeme,
|
||||||
|
Predicate(
|
||||||
|
type=type,
|
||||||
|
body=stmt.body,
|
||||||
|
alias=len(params) == 0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: m.UnaryExpr) -> None:
|
def _is_valid_predicate(self, body: Type) -> bool:
|
||||||
self.reporter.warning(expr.location, "UnaryExpr not yet supported")
|
match body:
|
||||||
|
case Function(returns=returns):
|
||||||
|
return self._is_valid_predicate(returns)
|
||||||
|
case _ if self.types.is_subtype(body, self._bool):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "GetExpr not yet supported")
|
self.assert_bool(expr.left)
|
||||||
|
self.assert_bool(expr.right)
|
||||||
|
return self._bool
|
||||||
|
|
||||||
def visit_variable_expr(self, expr: m.VariableExpr) -> None:
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "VariableExpr not yet supported")
|
method: Optional[str] = MIDAS_BINARY_METHODS.get(expr.operator.type)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_grouping_expr(self, expr: m.GroupingExpr) -> None:
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
|
def _visit_binary_expr(
|
||||||
|
self, location: Location, left_expr: m.Expr, right_expr: m.Expr, method: str
|
||||||
|
) -> Type:
|
||||||
|
left: Type = self.type_of(left_expr)
|
||||||
|
right: Type = self.type_of(right_expr)
|
||||||
|
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
location,
|
||||||
|
operation,
|
||||||
|
[(right_expr, right)],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> Type:
|
||||||
|
method: Optional[str] = MIDAS_UNARY_METHODS.get(expr.operator.type)
|
||||||
|
if method is None:
|
||||||
|
self.logger.warning(f"Unsupported operator {expr.operator.lexeme}")
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operator {expr.operator.lexeme}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
operand: Type = self.type_of(expr.right)
|
||||||
|
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
||||||
|
if operation is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Undefined operation {method} for {operand}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
result: Optional[Type] = self._get_call_result(
|
||||||
|
expr.location,
|
||||||
|
operation,
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
return result or UnknownType()
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> Type:
|
||||||
|
callee: Type = expr.callee.accept(self)
|
||||||
|
positional: list[TypedExpr] = [
|
||||||
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
|
]
|
||||||
|
keywords: dict[str, TypedExpr] = {
|
||||||
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
self._get_call_result(
|
||||||
|
expr.location,
|
||||||
|
callee,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
or UnknownType()
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> Type:
|
||||||
|
object: Type = expr.expr.accept(self)
|
||||||
|
member: Optional[Type] = self.types.lookup_member(object, expr.name.lexeme)
|
||||||
|
if member is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown member '{expr.name.lexeme}' of {object}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return member
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> Type:
|
||||||
|
return self.get_variable(expr.name.lexeme)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Type:
|
||||||
return expr.expr.accept(self)
|
return expr.expr.accept(self)
|
||||||
|
|
||||||
def visit_literal_expr(self, expr: m.LiteralExpr) -> None:
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "LiteralExpr not yet supported")
|
match expr.value:
|
||||||
|
case bool(): # Must be before int
|
||||||
|
return self.types.get_type("bool")
|
||||||
|
case int():
|
||||||
|
return self.types.get_type("int")
|
||||||
|
case float():
|
||||||
|
return self.types.get_type("float")
|
||||||
|
case str():
|
||||||
|
return self.types.get_type("str")
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> None:
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Type:
|
||||||
self.reporter.warning(expr.location, "WildcardExpr not yet supported")
|
return self.get_variable("_")
|
||||||
|
|
||||||
def visit_named_type(self, type: m.NamedType) -> Type:
|
def visit_named_type(self, type: m.NamedType) -> Type:
|
||||||
name: str = type.name.lexeme
|
name: str = type.name.lexeme
|
||||||
@@ -153,10 +358,10 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
def visit_constraint_type(self, type: m.ConstraintType) -> Type:
|
||||||
type_: Type = type.type.accept(self)
|
return ConstraintType(
|
||||||
type.constraint.accept(self)
|
type=type.type.accept(self),
|
||||||
# TODO
|
constraint=type.constraint,
|
||||||
return UnknownType()
|
)
|
||||||
|
|
||||||
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
def visit_complex_type(self, type: m.ComplexType) -> ComplexType:
|
||||||
return ComplexType(
|
return ComplexType(
|
||||||
@@ -172,8 +377,17 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
)
|
)
|
||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> Type:
|
def visit_function_type(self, type: m.FunctionType) -> Type:
|
||||||
n_pos_args: int = len(type.pos_args)
|
params: TypedParamSpec = self._visit_param_spec(type.params)
|
||||||
n_args: int = len(type.args)
|
return Function(
|
||||||
|
pos_args=params.pos,
|
||||||
|
args=params.mixed,
|
||||||
|
kw_args=params.kw,
|
||||||
|
returns=type.returns.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> TypedParamSpec:
|
||||||
|
n_pos: int = len(spec.pos)
|
||||||
|
n_mixed: int = len(spec.mixed)
|
||||||
|
|
||||||
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
def process_arg(arg: m.FunctionType.Argument, i: int) -> Function.Argument:
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -183,14 +397,22 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Function(
|
return TypedParamSpec(
|
||||||
pos_args=[process_arg(arg, i) for i, arg in enumerate(type.pos_args)],
|
pos=[process_arg(arg, i) for i, arg in enumerate(spec.pos)],
|
||||||
args=[process_arg(arg, i + n_pos_args) for i, arg in enumerate(type.args)],
|
mixed=[process_arg(arg, i + n_pos) for i, arg in enumerate(spec.mixed)],
|
||||||
kw_args=[
|
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
|
||||||
process_arg(arg, i + n_pos_args + n_args)
|
)
|
||||||
for i, arg in enumerate(type.kw_args)
|
|
||||||
],
|
def visit_frame_type(self, type: m.FrameType) -> Type:
|
||||||
returns=type.returns.accept(self),
|
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
|
||||||
|
return DataFrameType.Column(
|
||||||
|
index=i,
|
||||||
|
name=col.name.lexeme,
|
||||||
|
type=ColumnType(type=col.type.accept(self)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return DataFrameType(
|
||||||
|
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _resolve_type_params(self, params: list[m.TypeParam]):
|
def _resolve_type_params(self, params: list[m.TypeParam]):
|
||||||
@@ -204,3 +426,343 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[Type
|
|||||||
self._local_variables[name] = var
|
self._local_variables[name] = var
|
||||||
vars.append(var)
|
vars.append(var)
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
|
def _get_call_result(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
callee: Type,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Type]:
|
||||||
|
"""Get the result type of a function call
|
||||||
|
|
||||||
|
If the function has overloads, the function will try to resolve the
|
||||||
|
appropriate signature.
|
||||||
|
Argument types are matched to the defined parameters.
|
||||||
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location (Location): the call location
|
||||||
|
callee (Type): the called function
|
||||||
|
positional (list[TypedExpr]): the list positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type: the return type of the call, or `None` if either
|
||||||
|
the call is invalid or no overload matched the arguments uniquely
|
||||||
|
"""
|
||||||
|
match callee:
|
||||||
|
case Function() as function:
|
||||||
|
valid: bool
|
||||||
|
mapped: list[MappedArgument]
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function, location, positional, keywords
|
||||||
|
)
|
||||||
|
valid = valid and self._are_arguments_valid(mapped, report_errors)
|
||||||
|
if not valid:
|
||||||
|
return None
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
function = self._match_overload(
|
||||||
|
overloads, location, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
if function is None:
|
||||||
|
return None
|
||||||
|
return function.returns
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._get_call_result(
|
||||||
|
location, body, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(location, f"{callee} is not callable")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _are_arguments_valid(
|
||||||
|
self,
|
||||||
|
arguments: list[MappedArgument],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the passed argument types correspond to their matched parameter definitions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments (list[MappedArgument]): the list of argument/parameter pairs
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all arguments fit the matching parameter definitions, False otherwise
|
||||||
|
"""
|
||||||
|
valid: bool = True
|
||||||
|
for arg in arguments:
|
||||||
|
if not self.types.is_subtype(arg.type, arg.argument.type):
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg.expr.location,
|
||||||
|
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def _match_overload(
|
||||||
|
self,
|
||||||
|
overloads: list[Type],
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> Optional[Function]:
|
||||||
|
"""Try and resolve the appropriate overload for the given arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overloads (list[Type]): the list of possible overloads
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keywords arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Function]: the resolved function signature if it can be
|
||||||
|
determined unambiguously, or `None`.
|
||||||
|
"""
|
||||||
|
candidates: list[OverloadCandidate] = []
|
||||||
|
for overload in overloads:
|
||||||
|
function: Type = unfold_type(overload)
|
||||||
|
if not isinstance(function, Function):
|
||||||
|
if report_errors:
|
||||||
|
self.logger.error(
|
||||||
|
f"Overload is not a function: {overload} is {function}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
valid, mapped = self.map_call_arguments(
|
||||||
|
function=function,
|
||||||
|
location=location,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
report_errors=False,
|
||||||
|
)
|
||||||
|
if valid and self._are_arguments_valid(mapped, report_errors=False):
|
||||||
|
candidates.append(
|
||||||
|
OverloadCandidate(
|
||||||
|
function=function,
|
||||||
|
mapped=mapped,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_types: str = ", ".join(str(type) for _, type in positional)
|
||||||
|
kw_types: str = ", ".join(
|
||||||
|
f"{name}: {type}" for name, (_, type) in keywords.items()
|
||||||
|
)
|
||||||
|
for_args: str = f"for arguments pos=[{pos_types}] and kw={{{kw_types}}}"
|
||||||
|
|
||||||
|
n_candidates: int = len(candidates)
|
||||||
|
|
||||||
|
# Exactly 1 match -> return it
|
||||||
|
if n_candidates == 1:
|
||||||
|
return candidates[0].function
|
||||||
|
|
||||||
|
# No match -> invalid call
|
||||||
|
if n_candidates == 0:
|
||||||
|
overloads_str: str = ", ".join(map(str, overloads))
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"No matching overload in [{overloads_str}] {for_args}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Multiple matches -> see if one <: all others (more specific)
|
||||||
|
for i1, c1 in enumerate(candidates):
|
||||||
|
mapped1: list[MappedArgument] = c1.mapped
|
||||||
|
best_match: bool = True
|
||||||
|
for i2, c2 in enumerate(candidates):
|
||||||
|
if i1 == i2:
|
||||||
|
continue
|
||||||
|
mapped2: list[MappedArgument] = c2.mapped
|
||||||
|
if not self._are_mapped_subtypes(mapped1, mapped2):
|
||||||
|
best_match = False
|
||||||
|
break
|
||||||
|
self.logger.debug(f"{c1.function} is a full overload of {c2.function}")
|
||||||
|
if best_match:
|
||||||
|
return c1.function
|
||||||
|
|
||||||
|
candidates_str: str = ", ".join(
|
||||||
|
str(candidate.function) for candidate in candidates
|
||||||
|
)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Multiple matching overloads {for_args}: {candidates_str}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def map_call_arguments(
|
||||||
|
self,
|
||||||
|
function: Function,
|
||||||
|
location: Location,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
report_errors: bool = True,
|
||||||
|
) -> tuple[bool, list[MappedArgument]]:
|
||||||
|
"""Map call arguments to a function's parameters as defined in its signature
|
||||||
|
|
||||||
|
This method maps positional-only, keyword-only and mixed parameter definitions
|
||||||
|
with the arguments passed at the call site
|
||||||
|
|
||||||
|
Any mismatched, missing or unexpected argument is reported as a diagnostic,
|
||||||
|
unless `report_errors` is set to `False`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function (Function): the function definition
|
||||||
|
location (Location): the call location
|
||||||
|
positional (list[TypedExpr]): the list of positional arguments
|
||||||
|
keywords (dict[str, TypedExpr]): the map of keyword arguments
|
||||||
|
report_errors (bool, optional): whether type errors should be reported as diagnostics. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, list[MappedArgument]]: a boolean reporting whether
|
||||||
|
the call is valid and the list of mapped arguments
|
||||||
|
"""
|
||||||
|
set_args: set[str] = set()
|
||||||
|
|
||||||
|
required_positional: list[str] = [
|
||||||
|
arg.name for arg in function.pos_args + function.args if arg.required
|
||||||
|
]
|
||||||
|
required_keyword: list[str] = [
|
||||||
|
arg.name for arg in function.kw_args if arg.required
|
||||||
|
]
|
||||||
|
|
||||||
|
mapped: list[MappedArgument] = []
|
||||||
|
|
||||||
|
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||||
|
mixed_params: list[Function.Argument] = list(function.args)
|
||||||
|
kw_params: dict[str, Function.Argument] = {
|
||||||
|
arg.name: arg for arg in function.kw_args
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_call: bool = True
|
||||||
|
|
||||||
|
# TODO: handle *args and **kwargs sinks
|
||||||
|
for arg in positional:
|
||||||
|
param: Function.Argument
|
||||||
|
if len(pos_params) != 0:
|
||||||
|
param = pos_params.pop(0)
|
||||||
|
elif len(mixed_params) != 0:
|
||||||
|
param = mixed_params.pop(0)
|
||||||
|
else:
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, "Too many positional arguments"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
break
|
||||||
|
name: str = param.name
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||||
|
for name, arg in keywords.items():
|
||||||
|
param: Function.Argument
|
||||||
|
if name not in kw_params:
|
||||||
|
if report_errors:
|
||||||
|
if name in set_args:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Multiple values for argument '{name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.reporter.error(
|
||||||
|
arg[0].location, f"Unknown keyword argument '{name}'"
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
continue
|
||||||
|
param = kw_params.pop(name)
|
||||||
|
if name in required_positional:
|
||||||
|
required_positional.remove(name)
|
||||||
|
if name in required_keyword:
|
||||||
|
required_keyword.remove(name)
|
||||||
|
set_args.add(name)
|
||||||
|
mapped.append(
|
||||||
|
MappedArgument(
|
||||||
|
expr=arg[0],
|
||||||
|
type=arg[1],
|
||||||
|
argument=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def join_args(args: list[str]) -> str:
|
||||||
|
args = list(map(lambda a: f"'{a}'", args))
|
||||||
|
if len(args) == 0:
|
||||||
|
return ""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return ", ".join(args[:-1]) + " and " + args[-1]
|
||||||
|
|
||||||
|
if len(required_positional) != 0:
|
||||||
|
plural: str = "" if len(required_positional) == 1 else "s"
|
||||||
|
args: str = join_args(required_positional)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required positional argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
if len(required_keyword) != 0:
|
||||||
|
plural: str = "" if len(required_keyword) == 1 else "s"
|
||||||
|
args: str = join_args(required_keyword)
|
||||||
|
if report_errors:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Missing required keyword argument{plural}: {args}",
|
||||||
|
)
|
||||||
|
valid_call = False
|
||||||
|
|
||||||
|
return valid_call, mapped
|
||||||
|
|
||||||
|
def _are_mapped_subtypes(
|
||||||
|
self, mapped1: list[MappedArgument], mapped2: list[MappedArgument]
|
||||||
|
) -> bool:
|
||||||
|
"""Check whether the given argument mappings are subtype/supertype of one another
|
||||||
|
|
||||||
|
This function checks whether the argument mappings `mapped1` are subtypes
|
||||||
|
of `mapped2`. If any of the parameter type in `mapped1` is not a subtype
|
||||||
|
of the corresponding parameter in `mapped2`, `False` is returned.
|
||||||
|
|
||||||
|
This is used to check whether a given overload is
|
||||||
|
a more specific function/ a subtype of another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mapped1 (list[MappedArgument]): the first argument mappings (subtype)
|
||||||
|
mapped2 (list[MappedArgument]): the second argument mappings (supertype)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: `True` if `mapped1` is a subtype of `mapped2`, `False` otherwise
|
||||||
|
"""
|
||||||
|
by_expr: dict[m.Expr, Type] = {}
|
||||||
|
for arg in mapped1:
|
||||||
|
by_expr[arg.expr] = arg.argument.type
|
||||||
|
|
||||||
|
for arg in mapped2:
|
||||||
|
type2: Type = arg.argument.type
|
||||||
|
type1: Type = by_expr[arg.expr]
|
||||||
|
if not self.types.is_subtype(type1, type2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import ast
|
import ast
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
|
PY_OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
||||||
ast.Add: "__add__",
|
ast.Add: "__add__",
|
||||||
ast.Sub: "__sub__",
|
ast.Sub: "__sub__",
|
||||||
ast.Mult: "__mul__",
|
ast.Mult: "__mul__",
|
||||||
@@ -17,9 +19,9 @@ OPERATOR_METHODS: dict[Type[ast.operator], str] = {
|
|||||||
ast.FloorDiv: "__floordiv__",
|
ast.FloorDiv: "__floordiv__",
|
||||||
}
|
}
|
||||||
|
|
||||||
COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
PY_COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
||||||
ast.Eq: "__eq__",
|
ast.Eq: "__eq__",
|
||||||
# ast.NotEq: "__noteq__",
|
ast.NotEq: "__eq__",
|
||||||
ast.Lt: "__lt__",
|
ast.Lt: "__lt__",
|
||||||
ast.LtE: "__le__",
|
ast.LtE: "__le__",
|
||||||
ast.Gt: "__gt__",
|
ast.Gt: "__gt__",
|
||||||
@@ -30,9 +32,40 @@ COMPARATOR_METHODS: dict[Type[ast.cmpop], str] = {
|
|||||||
# ast.NotIn: "__notin__",
|
# ast.NotIn: "__notin__",
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
PY_UNARY_METHODS: dict[Type[ast.unaryop], str] = {
|
||||||
ast.Invert: "__invert__",
|
ast.Invert: "__invert__",
|
||||||
# ast.Not: "",
|
# ast.Not: "",
|
||||||
ast.UAdd: "__pos__",
|
ast.UAdd: "__pos__",
|
||||||
ast.USub: "__neg__",
|
ast.USub: "__neg__",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MIDAS_BINARY_METHODS: dict[TokenType, str] = {
|
||||||
|
# TokenType.PLUS: "__add__",
|
||||||
|
TokenType.MINUS: "__sub__",
|
||||||
|
TokenType.STAR: "__mul__",
|
||||||
|
TokenType.SLASH: "__truediv__",
|
||||||
|
# TokenType.MODULO: "__mod__",
|
||||||
|
# TokenType.POW: "__pow__",
|
||||||
|
# ast.BitOr: "__or__",
|
||||||
|
# ast.BitXor: "__xor__",
|
||||||
|
# ast.BitAnd: "__and__",
|
||||||
|
# ast.FloorDiv: "__floordiv__",
|
||||||
|
TokenType.EQUAL_EQUAL: "__eq__",
|
||||||
|
TokenType.BANG_EQUAL: "__eq__",
|
||||||
|
TokenType.LESS: "__lt__",
|
||||||
|
TokenType.LESS_EQUAL: "__le__",
|
||||||
|
TokenType.GREATER: "__gt__",
|
||||||
|
TokenType.GREATER_EQUAL: "__ge__",
|
||||||
|
# ast.Is: "__is__",
|
||||||
|
# ast.IsNot: "__isnot__",
|
||||||
|
# ast.In: "__in__",
|
||||||
|
# ast.NotIn: "__notin__",
|
||||||
|
}
|
||||||
|
|
||||||
|
MIDAS_UNARY_METHODS: dict[TokenType, str] = {
|
||||||
|
# ast.Invert: "__invert__",
|
||||||
|
# ast.Not: "",
|
||||||
|
# TokenType.PLUS: "__pos__",
|
||||||
|
TokenType.MINUS: "__neg__",
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
@@ -16,23 +17,26 @@ class Preamble(Environment):
|
|||||||
def __init__(self, types: TypesRegistry) -> None:
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._types: TypesRegistry = types
|
self._types: TypesRegistry = types
|
||||||
|
self._python_funcs: dict[str, Callable[..., Any]] = {}
|
||||||
|
|
||||||
self._def_type_constructor("object")
|
self._def_type_constructor("object", object)
|
||||||
self._def_type_constructor("float")
|
self._def_type_constructor("float", float)
|
||||||
self._def_type_constructor("int")
|
self._def_type_constructor("int", int)
|
||||||
self._def_type_constructor("bool")
|
self._def_type_constructor("bool", bool)
|
||||||
self._def_type_constructor("str")
|
self._def_type_constructor("str", str)
|
||||||
self._def_function(
|
self._def_function(
|
||||||
name="list",
|
name="list",
|
||||||
pos=[Param("object", TopType())],
|
pos=[Param("object", TopType())],
|
||||||
returns=self._list_of(TopType()),
|
returns=self._list_of(TopType()),
|
||||||
|
py_function=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: use sink
|
# TODO: use sink
|
||||||
self._def_function(
|
self._def_function(
|
||||||
name="print",
|
name="print",
|
||||||
pos=[Param("object", TopType())],
|
pos=[Param("object", TopType(), required=False)],
|
||||||
returns=UnitType(),
|
returns=UnitType(),
|
||||||
|
py_function=print,
|
||||||
)
|
)
|
||||||
|
|
||||||
map_in = TypeVar(name="T", bound=None)
|
map_in = TypeVar(name="T", bound=None)
|
||||||
@@ -52,17 +56,32 @@ class Preamble(Environment):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
|
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
|
||||||
|
type_vars=[map_in, map_out],
|
||||||
|
py_function=map,
|
||||||
|
)
|
||||||
|
self._def_function(
|
||||||
|
name="input",
|
||||||
|
pos=[Param("prompt", TopType(), required=False)],
|
||||||
|
returns=self._types.get_type("str"),
|
||||||
|
)
|
||||||
|
self._def_function(
|
||||||
|
name="len",
|
||||||
|
pos=[Param("object", TopType())],
|
||||||
|
returns=self._types.get_type("int"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _list_of(self, item_type: Type) -> Type:
|
def _list_of(self, item_type: Type) -> Type:
|
||||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||||
|
|
||||||
def _def_type_constructor(self, name: str):
|
def _def_type_constructor(
|
||||||
|
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||||
|
):
|
||||||
# TODO: more specific arg types
|
# TODO: more specific arg types
|
||||||
self._def_function(
|
self._def_function(
|
||||||
name=name,
|
name=name,
|
||||||
pos=[Param("object", TopType(), required=False)],
|
pos=[Param("object", TopType(), required=False)],
|
||||||
returns=self._types.get_type(name),
|
returns=self._types.get_type(name),
|
||||||
|
py_function=py_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_function(
|
def _make_function(
|
||||||
@@ -109,6 +128,7 @@ class Preamble(Environment):
|
|||||||
kw: list[Param] = [],
|
kw: list[Param] = [],
|
||||||
returns: Type = UnitType(),
|
returns: Type = UnitType(),
|
||||||
type_vars: list[TypeVar] = [],
|
type_vars: list[TypeVar] = [],
|
||||||
|
py_function: Optional[Callable[..., Any]] = None,
|
||||||
):
|
):
|
||||||
function: Type = self._make_function(
|
function: Type = self._make_function(
|
||||||
name=name,
|
name=name,
|
||||||
@@ -119,3 +139,8 @@ class Preamble(Environment):
|
|||||||
type_vars=type_vars,
|
type_vars=type_vars,
|
||||||
)
|
)
|
||||||
self.define(name, function)
|
self.define(name, function)
|
||||||
|
if py_function is not None:
|
||||||
|
self._python_funcs[name] = py_function
|
||||||
|
|
||||||
|
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||||
|
return self._python_funcs.get(name)
|
||||||
|
|||||||
@@ -1,25 +1,42 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.operators import COMPARATOR_METHODS, OPERATOR_METHODS, UNARY_METHODS
|
from midas.checker.evaluator import Evaluator
|
||||||
|
from midas.checker.frames import FrameManager
|
||||||
|
from midas.checker.operators import (
|
||||||
|
PY_COMPARATOR_METHODS,
|
||||||
|
PY_OPERATOR_METHODS,
|
||||||
|
PY_UNARY_METHODS,
|
||||||
|
)
|
||||||
from midas.checker.preamble import Preamble
|
from midas.checker.preamble import Preamble
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.reporter import FileReporter, Reporter
|
from midas.checker.reporter import FileReporter, Reporter
|
||||||
from midas.checker.resolver import Resolver
|
from midas.checker.resolver import Resolver
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
Function,
|
Function,
|
||||||
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
Variance,
|
||||||
unfold_type,
|
unfold_type,
|
||||||
)
|
)
|
||||||
|
from midas.checker.unifier import Unifier
|
||||||
from midas.parser.python import PythonParser
|
from midas.parser.python import PythonParser
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
@@ -30,6 +47,10 @@ class ReturnException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UndefinedMethodException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class MappedArgument:
|
class MappedArgument:
|
||||||
expr: p.Expr
|
expr: p.Expr
|
||||||
@@ -58,10 +79,12 @@ class PythonTyper(
|
|||||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||||
self.reporter: FileReporter = reporter.for_file(None)
|
self.reporter: FileReporter = reporter.for_file(None)
|
||||||
self.types: TypesRegistry = types
|
self.types: TypesRegistry = types
|
||||||
|
self.frame_mgr: FrameManager = FrameManager(self)
|
||||||
self.global_env: Environment = Preamble(self.types)
|
self.global_env: Environment = Preamble(self.types)
|
||||||
self.env: Environment = self.global_env
|
self.env: Environment = self.global_env
|
||||||
self.locals: dict[p.Expr, int] = {}
|
self.locals: dict[p.Expr, int] = {}
|
||||||
self.judgements: list[tuple[p.Expr, Type]] = []
|
self.judgements: list[tuple[p.Expr, Type]] = []
|
||||||
|
self.evaluated_casts: list[p.CastExpr] = []
|
||||||
|
|
||||||
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
def process(self, source: str, path: Optional[str]) -> TypedAST:
|
||||||
self.reporter = self.reporter.for_file(path)
|
self.reporter = self.reporter.for_file(path)
|
||||||
@@ -75,10 +98,15 @@ class PythonTyper(
|
|||||||
self.env = self.global_env
|
self.env = self.global_env
|
||||||
self.locals = resolver.locals
|
self.locals = resolver.locals
|
||||||
self.judgements = []
|
self.judgements = []
|
||||||
|
self.evaluated_casts = []
|
||||||
|
|
||||||
self.check(stmts)
|
self.check(stmts)
|
||||||
|
|
||||||
return TypedAST(stmts=stmts, judgements=self.judgements)
|
return TypedAST(
|
||||||
|
stmts=stmts,
|
||||||
|
judgements=self.judgements,
|
||||||
|
evaluated_casts=self.evaluated_casts,
|
||||||
|
)
|
||||||
|
|
||||||
def judge(self, expr: p.Expr, type: Type):
|
def judge(self, expr: p.Expr, type: Type):
|
||||||
"""Record a typing judgement
|
"""Record a typing judgement
|
||||||
@@ -171,6 +199,36 @@ class PythonTyper(
|
|||||||
return self.env.get_at(distance, name)
|
return self.env.get_at(distance, name)
|
||||||
return self.global_env.get(name)
|
return self.global_env.get(name)
|
||||||
|
|
||||||
|
def call_method(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
obj: Type,
|
||||||
|
method_name: str,
|
||||||
|
positional: list[TypedExpr],
|
||||||
|
keywords: dict[str, TypedExpr],
|
||||||
|
) -> Optional[Type]:
|
||||||
|
unfolded: Type = unfold_type(obj)
|
||||||
|
match unfolded:
|
||||||
|
case DataFrameType():
|
||||||
|
return self.frame_mgr.call(
|
||||||
|
method=method_name,
|
||||||
|
location=location,
|
||||||
|
frame=unfolded,
|
||||||
|
positional=positional,
|
||||||
|
keywords=keywords,
|
||||||
|
)
|
||||||
|
|
||||||
|
method: Optional[Type] = self.types.lookup_member(obj, method_name)
|
||||||
|
if method is None:
|
||||||
|
raise UndefinedMethodException
|
||||||
|
|
||||||
|
return self._get_call_result(
|
||||||
|
location,
|
||||||
|
method,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
return self.types.is_subtype(type1, type2)
|
return self.types.is_subtype(type1, type2)
|
||||||
|
|
||||||
@@ -222,7 +280,8 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
for arg in pos_args + args + kw_args:
|
all_args: list[Function.Argument] = pos_args + args + kw_args
|
||||||
|
for arg in all_args:
|
||||||
env.define(arg.name, arg.type)
|
env.define(arg.name, arg.type)
|
||||||
|
|
||||||
returns_hint: Optional[Type] = None
|
returns_hint: Optional[Type] = None
|
||||||
@@ -263,12 +322,25 @@ class PythonTyper(
|
|||||||
returns = inferred_return
|
returns = inferred_return
|
||||||
|
|
||||||
# TODO: handle *args and **kwargs sinks
|
# TODO: handle *args and **kwargs sinks
|
||||||
function: Function = Function(
|
function: Type = Function(
|
||||||
pos_args=pos_args,
|
pos_args=pos_args,
|
||||||
args=args,
|
args=args,
|
||||||
kw_args=kw_args,
|
kw_args=kw_args,
|
||||||
returns=returns,
|
returns=returns,
|
||||||
)
|
)
|
||||||
|
generic_params: list[TypeVar] = []
|
||||||
|
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
|
||||||
|
for type in all_types:
|
||||||
|
if isinstance(type, TypeVar):
|
||||||
|
if type not in generic_params:
|
||||||
|
generic_params.append(type)
|
||||||
|
|
||||||
|
if len(generic_params) != 0:
|
||||||
|
function = GenericType(
|
||||||
|
name=stmt.name,
|
||||||
|
params=generic_params,
|
||||||
|
body=function,
|
||||||
|
)
|
||||||
self.env.define(stmt.name, function)
|
self.env.define(stmt.name, function)
|
||||||
|
|
||||||
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
|
||||||
@@ -286,9 +358,15 @@ class PythonTyper(
|
|||||||
case p.VariableExpr():
|
case p.VariableExpr():
|
||||||
self._assign_var(location, target, value_type)
|
self._assign_var(location, target, value_type)
|
||||||
|
|
||||||
|
# Allow any kind of object because we disallow creating new attributes
|
||||||
case p.GetExpr(object=object, name=name):
|
case p.GetExpr(object=object, name=name):
|
||||||
self._assign_attr(location, object, name, value_type)
|
self._assign_attr(location, object, name, value_type)
|
||||||
|
|
||||||
|
# Only support variable expressions because modifying
|
||||||
|
# the underlying value would require reference types
|
||||||
|
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
|
||||||
|
self._assign_sub(location, var, index, value_type)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if not isinstance(target, p.VariableExpr):
|
if not isinstance(target, p.VariableExpr):
|
||||||
self.logger.warning(f"Unsupported assignment to {target}")
|
self.logger.warning(f"Unsupported assignment to {target}")
|
||||||
@@ -327,6 +405,30 @@ class PythonTyper(
|
|||||||
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _assign_sub(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
var: p.VariableExpr,
|
||||||
|
index: p.Expr,
|
||||||
|
value_type: Type,
|
||||||
|
):
|
||||||
|
var_type: Type = self.type_of(var)
|
||||||
|
unfolded_type: Type = unfold_type(var_type)
|
||||||
|
# TODO: what happens if type is an alias of a dataframe type
|
||||||
|
match unfolded_type:
|
||||||
|
case DataFrameType() as frame:
|
||||||
|
new_type: Type = self.frame_mgr.assign(
|
||||||
|
self.reporter, location, frame, index, value_type
|
||||||
|
)
|
||||||
|
self.env.assign(var.name, new_type)
|
||||||
|
case UnknownType():
|
||||||
|
return
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Cannot assign {value_type} to index {index} of {var_type}",
|
||||||
|
)
|
||||||
|
|
||||||
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
|
||||||
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
|
||||||
self.env.return_types.append(type)
|
self.env.return_types.append(type)
|
||||||
@@ -340,8 +442,10 @@ class PythonTyper(
|
|||||||
# print(m) # <- m is still defined
|
# print(m) # <- m is still defined
|
||||||
test_type: Type = self.type_of(stmt.test)
|
test_type: Type = self.type_of(stmt.test)
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
if (
|
||||||
if test_type != self.types.get_type("bool"):
|
not self.types.is_subtype(test_type, self.types.get_type("bool"))
|
||||||
|
and test_type != UnknownType()
|
||||||
|
):
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
stmt.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
)
|
)
|
||||||
@@ -357,13 +461,16 @@ class PythonTyper(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
def visit_for_stmt(self, stmt: p.ForStmt) -> None:
|
||||||
item_type: Optional[Type] = self._get_iterator_type(stmt.iterator)
|
item_type: Type = UnknownType()
|
||||||
if item_type is None:
|
iterator_type: Type = self.type_of(stmt.iterator)
|
||||||
iterator_type: Type = self.compute_type(stmt.iterator)
|
if iterator_type != UnknownType():
|
||||||
self.reporter.error(
|
maybe_item_type = self._get_iterator_type(stmt.iterator, iterator_type)
|
||||||
stmt.iterator.location, f"{iterator_type} is not iterable"
|
if maybe_item_type is None:
|
||||||
)
|
self.reporter.error(
|
||||||
item_type = UnknownType()
|
stmt.iterator.location, f"{iterator_type} is not iterable"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
item_type = maybe_item_type
|
||||||
|
|
||||||
self._assign(stmt.location, stmt.target, item_type)
|
self._assign(stmt.location, stmt.target, item_type)
|
||||||
self.judge(stmt.target, item_type)
|
self.judge(stmt.target, item_type)
|
||||||
@@ -376,7 +483,7 @@ class PythonTyper(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
def visit_binary_expr(self, expr: p.BinaryExpr) -> Type:
|
||||||
method: Optional[str] = OPERATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_OPERATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -387,7 +494,7 @@ class PythonTyper(
|
|||||||
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
return self._visit_binary_expr(expr.location, expr.left, expr.right, method)
|
||||||
|
|
||||||
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
def visit_compare_expr(self, expr: p.CompareExpr) -> Type:
|
||||||
method: Optional[str] = COMPARATOR_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_COMPARATOR_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -403,24 +510,20 @@ class PythonTyper(
|
|||||||
left: Type = self.type_of(left_expr)
|
left: Type = self.type_of(left_expr)
|
||||||
right: Type = self.type_of(right_expr)
|
right: Type = self.type_of(right_expr)
|
||||||
|
|
||||||
operation: Optional[Type] = self.types.lookup_member(left, method)
|
result: Optional[Type]
|
||||||
if operation is None:
|
try:
|
||||||
|
result = self.call_method(location, left, method, [(right_expr, right)], {})
|
||||||
|
except UndefinedMethodException:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
location,
|
location,
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
result: Optional[Type] = self._get_call_result(
|
|
||||||
location,
|
|
||||||
operation,
|
|
||||||
[(right_expr, right)],
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
return result or UnknownType()
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
|
||||||
method: Optional[str] = UNARY_METHODS.get(expr.operator.__class__)
|
method: Optional[str] = PY_UNARY_METHODS.get(expr.operator.__class__)
|
||||||
if method is None:
|
if method is None:
|
||||||
self.logger.warning(f"Unsupported operator {expr.operator}")
|
self.logger.warning(f"Unsupported operator {expr.operator}")
|
||||||
self.reporter.warning(
|
self.reporter.warning(
|
||||||
@@ -429,30 +532,45 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
operand: Type = self.type_of(expr.right)
|
operand: Type = self.type_of(expr.right)
|
||||||
operation: Optional[Type] = self.types.lookup_member(operand, method)
|
|
||||||
if operation is None:
|
result: Optional[Type]
|
||||||
|
try:
|
||||||
|
result = self.call_method(expr.location, operand, method, [], {})
|
||||||
|
except UndefinedMethodException:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Undefined operation {method} for {operand}",
|
f"Undefined operation {method} for {operand}",
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
result: Optional[Type] = self._get_call_result(
|
|
||||||
expr.location,
|
|
||||||
operation,
|
|
||||||
[],
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
return result or UnknownType()
|
return result or UnknownType()
|
||||||
|
|
||||||
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
def visit_call_expr(self, expr: p.CallExpr) -> Type:
|
||||||
callee: Type = self.type_of(expr.callee)
|
match expr.callee:
|
||||||
|
case p.VariableExpr(name="TypeVar"):
|
||||||
|
return self.define_typevar(expr) or UnknownType()
|
||||||
|
|
||||||
positional: list[TypedExpr] = [
|
positional: list[TypedExpr] = [
|
||||||
(arg, self.type_of(arg)) for arg in expr.arguments
|
(arg, self.type_of(arg)) for arg in expr.arguments
|
||||||
]
|
]
|
||||||
keywords: dict[str, TypedExpr] = {
|
keywords: dict[str, TypedExpr] = {
|
||||||
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match expr.callee:
|
||||||
|
case p.GetExpr(object=obj, name=method):
|
||||||
|
obj_type: Type = self.type_of(obj)
|
||||||
|
unfolded: Type = unfold_type(obj_type)
|
||||||
|
if isinstance(unfolded, DataFrameType):
|
||||||
|
return self.frame_mgr.call(
|
||||||
|
method,
|
||||||
|
expr.location,
|
||||||
|
unfolded,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
|
||||||
|
callee: Type = self.type_of(expr.callee)
|
||||||
return (
|
return (
|
||||||
self._get_call_result(
|
self._get_call_result(
|
||||||
location=expr.location,
|
location=expr.location,
|
||||||
@@ -467,7 +585,7 @@ class PythonTyper(
|
|||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||||
if member is None:
|
if member is None:
|
||||||
self.reporter.error(
|
self.reporter.warning(
|
||||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
@@ -484,6 +602,8 @@ class PythonTyper(
|
|||||||
return self.types.get_type("float")
|
return self.types.get_type("float")
|
||||||
case str():
|
case str():
|
||||||
return self.types.get_type("str")
|
return self.types.get_type("str")
|
||||||
|
case None:
|
||||||
|
return self.types.get_type("None")
|
||||||
case _:
|
case _:
|
||||||
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
self.reporter.warning(expr.location, f"Unknown literal {expr}")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
@@ -511,13 +631,25 @@ class PythonTyper(
|
|||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
|
||||||
return self.resolve_type_expr(expr.type)
|
subject_type: Type = self.type_of(expr.expr)
|
||||||
|
target_type: Type = self.resolve_type_expr(expr.type)
|
||||||
|
is_lit, lit_value = self._get_literal(expr.expr)
|
||||||
|
if is_lit:
|
||||||
|
evaluated: bool = self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, target_type, lit_value
|
||||||
|
)
|
||||||
|
if evaluated:
|
||||||
|
self.evaluated_casts.append(expr)
|
||||||
|
return target_type
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
|
||||||
test_type: Type = self.type_of(expr.test)
|
test_type: Type = self.type_of(expr.test)
|
||||||
|
|
||||||
# TODO Allow subtypes or any type
|
# TODO Allow subtypes or any type
|
||||||
if test_type != self.types.get_type("bool"):
|
if (
|
||||||
|
not self.is_subtype(test_type, self.types.get_type("bool"))
|
||||||
|
and test_type != UnknownType()
|
||||||
|
):
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
expr.test.location, f"If test must be a boolean, got {test_type}"
|
expr.test.location, f"If test must be a boolean, got {test_type}"
|
||||||
)
|
)
|
||||||
@@ -546,9 +678,9 @@ class PythonTyper(
|
|||||||
if len(item_types) == 1:
|
if len(item_types) == 1:
|
||||||
item_type: Type = item_types[0]
|
item_type: Type = item_types[0]
|
||||||
return self.types.apply_generic(list_type, [item_type])
|
return self.types.apply_generic(list_type, [item_type])
|
||||||
self.reporter.error(
|
self.reporter.warning(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Heterogeneous list items: {item_types}",
|
f"Heterogeneous list items: [{', '.join(map(str, item_types))}]",
|
||||||
)
|
)
|
||||||
return self.types.apply_generic(list_type, [UnknownType()])
|
return self.types.apply_generic(list_type, [UnknownType()])
|
||||||
|
|
||||||
@@ -578,22 +710,29 @@ class PythonTyper(
|
|||||||
if len(key_types) == 1:
|
if len(key_types) == 1:
|
||||||
key_type = key_types[0]
|
key_type = key_types[0]
|
||||||
else:
|
else:
|
||||||
self.reporter.error(
|
self.reporter.warning(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Heterogeneous dict keys: {key_types}",
|
f"Heterogeneous dict keys: [{', '.join(map(str, key_types))}]",
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(value_types) == 1:
|
if len(value_types) == 1:
|
||||||
value_type = value_types[0]
|
value_type = value_types[0]
|
||||||
else:
|
else:
|
||||||
self.reporter.error(
|
self.reporter.warning(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Heterogeneous dict values: {value_types}",
|
f"Heterogeneous dict values: [{', '.join(map(str, value_types))}]",
|
||||||
)
|
)
|
||||||
return self.types.apply_generic(dict_type, [key_type, value_type])
|
return self.types.apply_generic(dict_type, [key_type, value_type])
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
|
||||||
object: Type = self.type_of(expr.object)
|
object: Type = self.type_of(expr.object)
|
||||||
|
unfolded: Type = unfold_type(object)
|
||||||
|
match unfolded:
|
||||||
|
case TupleType():
|
||||||
|
return self._visit_tuple_subscript(unfolded, expr)
|
||||||
|
case DataFrameType():
|
||||||
|
return self._visit_frame_subscript(unfolded, expr)
|
||||||
|
|
||||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||||
if operation is None:
|
if operation is None:
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
@@ -611,6 +750,11 @@ class PythonTyper(
|
|||||||
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
def visit_slice_expr(self, expr: p.SliceExpr) -> Type:
|
||||||
return self.types.get_type("slice")
|
return self.types.get_type("slice")
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> Type:
|
||||||
|
return TupleType(
|
||||||
|
items=tuple(self.type_of(item) for item in expr.items),
|
||||||
|
)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> Type:
|
def visit_raw_expr(self, expr: p.RawExpr) -> Type:
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
@@ -622,22 +766,35 @@ class PythonTyper(
|
|||||||
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
|
self.reporter.warning(node.location, f"Unknown type '{node.base}'")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
if node.param is not None:
|
if len(node.args) != 0:
|
||||||
param: Type = self.resolve_type_expr(node.param)
|
args: list[Type] = [self.resolve_type_expr(arg) for arg in node.args]
|
||||||
return self.types.apply_generic(base, [param])
|
return self.types.apply_generic(base, args)
|
||||||
return base
|
return base
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
|
def visit_constraint_type(self, node: p.ConstraintType) -> Type:
|
||||||
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
self.reporter.warning(node.location, "ConstraintType not yet supported")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
def visit_frame_column(self, node: p.FrameColumn) -> Type:
|
def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
|
||||||
self.reporter.warning(node.location, "FrameColumn not yet supported")
|
return ColumnType(
|
||||||
return UnknownType()
|
type=(
|
||||||
|
self.resolve_type_expr(node.type)
|
||||||
|
if node.type is not None
|
||||||
|
else UnknownType()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def visit_frame_type(self, node: p.FrameType) -> Type:
|
def visit_frame_type(self, node: p.FrameType) -> Type:
|
||||||
self.reporter.warning(node.location, "FrameType not yet supported")
|
return DataFrameType(
|
||||||
return UnknownType()
|
columns=[
|
||||||
|
DataFrameType.Column(
|
||||||
|
index=i,
|
||||||
|
name=column.name,
|
||||||
|
type=self.visit_frame_column(column),
|
||||||
|
)
|
||||||
|
for i, column in enumerate(node.columns)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _get_call_result(
|
def _get_call_result(
|
||||||
self,
|
self,
|
||||||
@@ -652,7 +809,7 @@ class PythonTyper(
|
|||||||
If the function has overloads, the function will try to resolve the
|
If the function has overloads, the function will try to resolve the
|
||||||
appropriate signature.
|
appropriate signature.
|
||||||
Argument types are matched to the defined parameters.
|
Argument types are matched to the defined parameters.
|
||||||
The function doesn't take the raw expression as a parameter to accomodate
|
The function doesn't take the raw expression as a parameter to accommodate
|
||||||
for desugared calls such as for operators.
|
for desugared calls such as for operators.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -694,9 +851,39 @@ class PythonTyper(
|
|||||||
case UnknownType():
|
case UnknownType():
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
|
case AliasType(type=base):
|
||||||
|
return self._get_call_result(
|
||||||
|
location, base, positional, keywords, report_errors
|
||||||
|
)
|
||||||
|
|
||||||
|
case GenericType():
|
||||||
|
unifier: Unifier = Unifier(self.types)
|
||||||
|
pos: list[Type] = [a[1] for a in positional]
|
||||||
|
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
|
||||||
|
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
|
||||||
|
if unified is None:
|
||||||
|
if report_errors:
|
||||||
|
pos_str: str = ", ".join(str(t) for t in pos)
|
||||||
|
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
|
||||||
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return self._get_call_result(
|
||||||
|
location,
|
||||||
|
unified,
|
||||||
|
positional,
|
||||||
|
keywords,
|
||||||
|
report_errors,
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
if report_errors:
|
if report_errors:
|
||||||
self.reporter.error(location, f"{callee} is not callable")
|
self.reporter.error(
|
||||||
|
location,
|
||||||
|
f"{callee} ({callee.__class__.__name__}) is not callable",
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _are_arguments_valid(
|
def _are_arguments_valid(
|
||||||
@@ -743,7 +930,7 @@ class PythonTyper(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Function]: the resolved function signature if it can be
|
Optional[Function]: the resolved function signature if it can be
|
||||||
determined unambigously, or `None`.
|
determined unambiguously, or `None`.
|
||||||
"""
|
"""
|
||||||
candidates: list[OverloadCandidate] = []
|
candidates: list[OverloadCandidate] = []
|
||||||
for overload in overloads:
|
for overload in overloads:
|
||||||
@@ -979,9 +1166,8 @@ class PythonTyper(
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_iterator_type(self, expr: p.Expr) -> Optional[Type]:
|
def _get_iterator_type(self, expr: p.Expr, type: Type) -> Optional[Type]:
|
||||||
# TODO: lookup __iter__
|
# TODO: lookup __iter__
|
||||||
type: Type = self.type_of(expr)
|
|
||||||
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
|
getitem: Optional[Type] = self.types.lookup_member(type, "__getitem__")
|
||||||
if getitem is None:
|
if getitem is None:
|
||||||
return None
|
return None
|
||||||
@@ -996,3 +1182,173 @@ class PythonTyper(
|
|||||||
report_errors=False,
|
report_errors=False,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
|
||||||
|
def is_kw_true(name: str) -> bool:
|
||||||
|
match call.keywords.get(name):
|
||||||
|
case p.LiteralExpr(value=True):
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
|
|
||||||
|
match call:
|
||||||
|
case p.CallExpr(
|
||||||
|
arguments=[p.LiteralExpr(value=str() as name)],
|
||||||
|
):
|
||||||
|
bound: Optional[Type] = None
|
||||||
|
variance: Variance = Variance.INVARIANT
|
||||||
|
if "bound" in call.keywords:
|
||||||
|
bound_type: p.MidasType = self._parse_type_from_expr(
|
||||||
|
call.keywords["bound"]
|
||||||
|
)
|
||||||
|
bound = self.resolve_type_expr(bound_type)
|
||||||
|
|
||||||
|
if is_kw_true("covariant"):
|
||||||
|
variance = Variance.COVARIANT
|
||||||
|
|
||||||
|
if is_kw_true("contravariant"):
|
||||||
|
if variance == Variance.COVARIANT:
|
||||||
|
self.reporter.warning(
|
||||||
|
call.keywords["contravariant"].location,
|
||||||
|
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
|
||||||
|
)
|
||||||
|
variance = Variance.INVARIANT
|
||||||
|
else:
|
||||||
|
variance = Variance.CONTRAVARIANT
|
||||||
|
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
|
||||||
|
self.types.define_type(name, var)
|
||||||
|
return var
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
call.location, "Invalid usage of 'TypeVar', skipping"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
|
||||||
|
location: Location = expr.location
|
||||||
|
parser = PythonParser()
|
||||||
|
match expr:
|
||||||
|
case p.LiteralExpr(value=str() as value):
|
||||||
|
node: ast.Expression = ast.parse(value, mode="eval")
|
||||||
|
return parser._parse_type(node.body)
|
||||||
|
case p.VariableExpr(name=name):
|
||||||
|
return p.BaseType(location=location, base=name, args=())
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]:
|
||||||
|
match expr:
|
||||||
|
case p.LiteralExpr(value=value):
|
||||||
|
return True, value
|
||||||
|
|
||||||
|
case p.ListExpr(items=items):
|
||||||
|
values: list[Any] = []
|
||||||
|
for item in items:
|
||||||
|
is_lit, value = self._get_literal(item)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
values.append(value)
|
||||||
|
return True, values
|
||||||
|
|
||||||
|
case p.DictExpr(keys=keys, values=values):
|
||||||
|
pairs: list[tuple[Any, Any]] = []
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
key_val = None
|
||||||
|
if key is not None:
|
||||||
|
is_lit, key_val = self._get_literal(key)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
is_lit, value_val = self._get_literal(value)
|
||||||
|
if not is_lit:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
if key is None:
|
||||||
|
# TODO: check that value is always a dict
|
||||||
|
assert isinstance(value_val, dict)
|
||||||
|
pairs.extend(value_val.items())
|
||||||
|
else:
|
||||||
|
pairs.append((key_val, value_val))
|
||||||
|
return True, dict(pairs)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def _evaluate_cast_statically(
|
||||||
|
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
|
||||||
|
) -> bool:
|
||||||
|
match target_type:
|
||||||
|
case AliasType(type=base):
|
||||||
|
return self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, base, lit_value
|
||||||
|
)
|
||||||
|
|
||||||
|
case AppliedType(body=body):
|
||||||
|
return self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, body, lit_value
|
||||||
|
)
|
||||||
|
|
||||||
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
|
evaluated: bool = True
|
||||||
|
if not self._evaluate_cast_statically(
|
||||||
|
expr, subject_type, base, lit_value
|
||||||
|
):
|
||||||
|
evaluated = False
|
||||||
|
|
||||||
|
evaluator = Evaluator(self.types)
|
||||||
|
evaluator.set_value("_", lit_value)
|
||||||
|
res = evaluator.evaluate(constraint)
|
||||||
|
if not res:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
constraint_str: str = printer.print(constraint)
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
|
||||||
|
)
|
||||||
|
evaluated = False
|
||||||
|
return evaluated
|
||||||
|
|
||||||
|
case BaseType():
|
||||||
|
# TODO: do we want to allow cast(float, int)? would require runtime conversion
|
||||||
|
if not self.types.is_subtype(
|
||||||
|
subject_type, target_type
|
||||||
|
) or not self.types.is_subtype(target_type, subject_type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
case DataFrameType() | ColumnType():
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.info(
|
||||||
|
expr.location, f"Cannot evaluate cast to {target_type} statically"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _visit_tuple_subscript(self, tup: TupleType, expr: p.SubscriptExpr) -> Type:
|
||||||
|
match expr.index:
|
||||||
|
case p.LiteralExpr(value=int() as index):
|
||||||
|
if index < 0 or index >= len(tup.items):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Index {index} out of range for tuple {tup}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return tup.items[index]
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Invalid index type {expr.index} on {tup}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
def _visit_frame_subscript(
|
||||||
|
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||||
|
) -> Type:
|
||||||
|
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
||||||
|
|||||||
@@ -1,29 +1,44 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.ast.midas import MemberKind
|
||||||
from midas.checker.builtins import BUILTIN_SUBTYPES
|
from midas.checker.builtins import BUILTIN_SUBTYPES
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
|
Predicate,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnknownType,
|
UnknownType,
|
||||||
|
Variance,
|
||||||
substitute_typevars,
|
substitute_typevars,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Member:
|
||||||
|
kind: MemberKind
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
|
||||||
class TypesRegistry:
|
class TypesRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
self.logger: logging.Logger = logging.getLogger("TypesRegistry")
|
||||||
self._types: dict[str, Type] = {}
|
self._types: dict[str, Type] = {}
|
||||||
self._members: dict[str, dict[str, Type]] = {}
|
self._members: dict[str, dict[str, Member]] = {}
|
||||||
|
self._predicates: dict[str, Predicate] = {}
|
||||||
|
|
||||||
def get_type(self, name: str) -> Type:
|
def get_type(self, name: str) -> Type:
|
||||||
"""Get a type from its name
|
"""Get a type from its name
|
||||||
@@ -60,26 +75,43 @@ class TypesRegistry:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
def define_member(
|
def define_member(
|
||||||
self, type_name: str, member_name: str, member_type: Type, is_method: bool
|
self,
|
||||||
|
type_name: str,
|
||||||
|
member_name: str,
|
||||||
|
member_type: Type,
|
||||||
|
kind: MemberKind,
|
||||||
):
|
):
|
||||||
members: dict[str, Type] = self._members.setdefault(type_name, {})
|
members: dict[str, Member] = self._members.setdefault(type_name, {})
|
||||||
if member_name in members:
|
if member_name in members:
|
||||||
if not is_method:
|
current: Member = members[member_name]
|
||||||
|
if current.kind != kind:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"Member '{member_name}' already defined for type {type_name}"
|
f"Member '{member_name}' is already defined as a {current.kind},"
|
||||||
|
+ f" cannot define a {kind} with the same name"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
current: Type = members[member_name]
|
if kind != MemberKind.METHOD:
|
||||||
|
self.logger.error(
|
||||||
|
f"Member '{member_name}' already defined for type {type_name},"
|
||||||
|
+ " only methods can be overloaded"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
combined: Type
|
combined: Type
|
||||||
match current:
|
match current.type:
|
||||||
case OverloadedFunction(overloads=overloads):
|
case OverloadedFunction(overloads=overloads):
|
||||||
combined = OverloadedFunction(overloads=overloads + [member_type])
|
combined = OverloadedFunction(overloads=overloads + [member_type])
|
||||||
case _:
|
case _:
|
||||||
combined = OverloadedFunction(overloads=[current, member_type])
|
combined = OverloadedFunction(overloads=[current.type, member_type])
|
||||||
members[member_name] = combined
|
members[member_name] = Member(kind=current.kind, type=combined)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
members[member_name] = member_type
|
members[member_name] = Member(kind=kind, type=member_type)
|
||||||
|
|
||||||
|
def define_predicate(self, name: str, predicate: Predicate):
|
||||||
|
if name in self._predicates:
|
||||||
|
raise ValueError(f"Predicate {name} already defined")
|
||||||
|
self._predicates[name] = predicate
|
||||||
|
|
||||||
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
def is_subtype(self, type1: Type, type2: Type) -> bool:
|
||||||
"""Check whether `type1` is a subtype of `type2`
|
"""Check whether `type1` is a subtype of `type2`
|
||||||
@@ -101,6 +133,19 @@ class TypesRegistry:
|
|||||||
case (_, TopType()):
|
case (_, TopType()):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
case (_, UnknownType()):
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (TypeVar(bound=bound), _):
|
||||||
|
if bound is None:
|
||||||
|
return False
|
||||||
|
return self.is_subtype(bound, type2)
|
||||||
|
|
||||||
|
case (_, TypeVar(bound=bound)):
|
||||||
|
if bound is None:
|
||||||
|
return True
|
||||||
|
return self.is_subtype(type1, bound)
|
||||||
|
|
||||||
case (AliasType(type=base1), _):
|
case (AliasType(type=base1), _):
|
||||||
return self.is_subtype(base1, type2)
|
return self.is_subtype(base1, type2)
|
||||||
|
|
||||||
@@ -115,16 +160,57 @@ class TypesRegistry:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
|
||||||
|
# TODO: check order?
|
||||||
|
by_name1: dict[str, DataFrameType.Column] = {
|
||||||
|
col.name: col for col in columns1 if col.name is not None
|
||||||
|
}
|
||||||
|
for col2 in columns2:
|
||||||
|
if col2.name not in by_name1:
|
||||||
|
return False
|
||||||
|
if not self.is_subtype(by_name1[col2.name].type, col2.type):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
case (ColumnType(type=inner1), ColumnType(type=inner2)):
|
||||||
|
# TODO: invariant, replace ColumnType with simple GenericType
|
||||||
|
if not self.are_equivalent(inner1, inner2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
case (Function(), Function()):
|
case (Function(), Function()):
|
||||||
return self.is_func_subtype(type1, type2)
|
return self.is_func_subtype(type1, type2)
|
||||||
|
|
||||||
case (TypeVar(bound=bound), _):
|
case (ConstraintType(type=base1), _):
|
||||||
if bound is None:
|
return self.is_subtype(base1, type2)
|
||||||
return False
|
|
||||||
return self.is_subtype(bound, type2)
|
case (
|
||||||
|
AppliedType(name=name1, args=args1),
|
||||||
|
AppliedType(name=name2, args=args2),
|
||||||
|
) if (
|
||||||
|
name1 == name2
|
||||||
|
):
|
||||||
|
generic: Type = self.get_type(name1)
|
||||||
|
assert isinstance(generic, GenericType)
|
||||||
|
for param, arg1, arg2 in zip(generic.params, args1, args2):
|
||||||
|
variance: Variance = param.variance
|
||||||
|
if variance in {Variance.INVARIANT, Variance.COVARIANT}:
|
||||||
|
if not self.is_subtype(arg1, arg2):
|
||||||
|
return False
|
||||||
|
if variance in {Variance.INVARIANT, Variance.CONTRAVARIANT}:
|
||||||
|
if not self.is_subtype(arg2, arg1):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# TODO: verify legitimacy
|
||||||
|
case (AppliedType(body=body), _):
|
||||||
|
return self.is_subtype(body, type2)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def are_equivalent(self, type1: Type, type2: Type) -> bool:
|
||||||
|
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
|
||||||
|
|
||||||
# TODO: verify the logic in here
|
# TODO: verify the logic in here
|
||||||
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
|
||||||
"""Check whether a function is a subtype of another
|
"""Check whether a function is a subtype of another
|
||||||
@@ -261,6 +347,9 @@ class TypesRegistry:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case BaseType(name="tuple"):
|
||||||
|
return TupleType(items=tuple(args))
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"{type} is not a generic type")
|
raise ValueError(f"{type} is not a generic type")
|
||||||
|
|
||||||
@@ -297,13 +386,13 @@ class TypesRegistry:
|
|||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name]
|
return self._members[name][member_name].type
|
||||||
return None
|
return None
|
||||||
|
|
||||||
case AliasType(name=name, type=base):
|
case AliasType(name=name, type=base):
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
return self._members[name][member_name]
|
return self._members[name][member_name].type
|
||||||
return self.lookup_member(base, member_name)
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
case AppliedType(name=name, body=body, args=args):
|
case AppliedType(name=name, body=body, args=args):
|
||||||
@@ -317,7 +406,7 @@ class TypesRegistry:
|
|||||||
}
|
}
|
||||||
if name in self._members:
|
if name in self._members:
|
||||||
if member_name in self._members[name]:
|
if member_name in self._members[name]:
|
||||||
member_type: Type = self._members[name][member_name]
|
member_type: Type = self._members[name][member_name].type
|
||||||
return substitute_typevars(member_type, substitutions)
|
return substitute_typevars(member_type, substitutions)
|
||||||
|
|
||||||
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
member_type2: Optional[Type] = self.lookup_member(body, member_name)
|
||||||
@@ -339,9 +428,18 @@ class TypesRegistry:
|
|||||||
)
|
)
|
||||||
return self.lookup_member(base, member_name)
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
|
case ConstraintType(type=base):
|
||||||
|
return self.lookup_member(base, member_name)
|
||||||
|
|
||||||
|
case TypeVar(bound=bound) if bound is not None:
|
||||||
|
return self.lookup_member(bound, member_name)
|
||||||
|
|
||||||
case UnknownType():
|
case UnknownType():
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
self.logger.debug(f"Can't get member on {type}")
|
self.logger.debug(f"Can't get member on {type}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def lookup_predicate(self, name: str) -> Optional[Predicate]:
|
||||||
|
return self._predicates.get(name)
|
||||||
|
|||||||
@@ -61,3 +61,10 @@ class FileReporter:
|
|||||||
location=location,
|
location=location,
|
||||||
message=message,
|
message=message,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def debug(self, location: Location, message: str):
|
||||||
|
self.report(
|
||||||
|
type=DiagnosticType.DEBUG,
|
||||||
|
location=location,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|||||||
@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
|
|
||||||
case p.GetExpr():
|
case p.GetExpr():
|
||||||
target.accept(self)
|
target.accept(self)
|
||||||
|
|
||||||
|
case p.SubscriptExpr():
|
||||||
|
target.accept(self)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported assignment to {target}")
|
raise Exception(f"Unsupported assignment to {target}")
|
||||||
|
|
||||||
@@ -232,5 +236,9 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
|
|||||||
if expr.step is not None:
|
if expr.step is not None:
|
||||||
self.resolve(expr.step)
|
self.resolve(expr.step)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
self.resolve(item)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
def visit_raw_expr(self, expr: p.RawExpr) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from enum import StrEnum
|
||||||
|
from typing import Optional, assert_never, cast
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -99,15 +103,27 @@ class ExtensionType:
|
|||||||
return f"{self.base} & {self.extension}"
|
return f"{self.base} & {self.extension}"
|
||||||
|
|
||||||
|
|
||||||
|
class Variance(StrEnum):
|
||||||
|
INVARIANT = "INVARIANT"
|
||||||
|
COVARIANT = "COVARIANT"
|
||||||
|
CONTRAVARIANT = "CONTRAVARIANT"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class TypeVar:
|
class TypeVar:
|
||||||
name: str
|
name: str
|
||||||
bound: Optional[Type]
|
bound: Optional[Type]
|
||||||
|
variance: Variance = Variance.INVARIANT
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
variance: str = {
|
||||||
|
Variance.COVARIANT: "+",
|
||||||
|
Variance.CONTRAVARIANT: "-",
|
||||||
|
}.get(self.variance, "")
|
||||||
|
res: str = f"{variance}{self.name}"
|
||||||
if self.bound is not None:
|
if self.bound is not None:
|
||||||
return f"{self.name} <: {self.bound}"
|
res = f"{res} <: {self.bound}"
|
||||||
return self.name
|
return res
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -130,6 +146,47 @@ class AppliedType:
|
|||||||
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
return f"{self.name}[{', '.join(map(str, self.args))}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ConstraintType:
|
||||||
|
type: Type
|
||||||
|
constraint: m.Expr
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
return f"{self.type} where {printer.print(self.constraint)}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class TupleType:
|
||||||
|
items: tuple[Type, ...]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"({', '.join(map(str, self.items))})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ColumnType:
|
||||||
|
type: Type
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Column[{self.type}]"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class DataFrameType:
|
||||||
|
columns: list[Column]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
|
||||||
|
return f"Frame[{', '.join(schema)}]"
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Column:
|
||||||
|
index: int
|
||||||
|
name: Optional[str]
|
||||||
|
type: ColumnType
|
||||||
|
|
||||||
|
|
||||||
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
||||||
def sub_argument(arg: Function.Argument):
|
def sub_argument(arg: Function.Argument):
|
||||||
return Function.Argument(
|
return Function.Argument(
|
||||||
@@ -139,7 +196,17 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
required=arg.required,
|
required=arg.required,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sub_column(col: DataFrameType.Column):
|
||||||
|
return DataFrameType.Column(
|
||||||
|
index=col.index,
|
||||||
|
name=col.name,
|
||||||
|
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
|
||||||
|
)
|
||||||
|
|
||||||
match type:
|
match type:
|
||||||
|
case TopType():
|
||||||
|
return type
|
||||||
|
|
||||||
case BaseType(name=name) if name in substitutions:
|
case BaseType(name=name) if name in substitutions:
|
||||||
return substitutions[name]
|
return substitutions[name]
|
||||||
|
|
||||||
@@ -195,17 +262,58 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
|
|||||||
body=substitute_typevars(body, substitutions),
|
body=substitute_typevars(body, substitutions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return ConstraintType(
|
||||||
|
type=substitute_typevars(type.type, substitutions),
|
||||||
|
constraint=type.constraint,
|
||||||
|
)
|
||||||
|
|
||||||
case TypeVar(name=name):
|
case TypeVar(name=name):
|
||||||
if name in substitutions:
|
if name in substitutions:
|
||||||
return substitutions[name]
|
return substitutions[name]
|
||||||
raise ValueError(f"Missing TypeVar substitution for {name}")
|
raise ValueError(f"Missing TypeVar substitution for {name}")
|
||||||
|
|
||||||
|
case GenericType(name=name, params=params, body=body):
|
||||||
|
params2: list[TypeVar] = []
|
||||||
|
for param in params:
|
||||||
|
param2: Type = substitute_typevars(param, substitutions)
|
||||||
|
if not isinstance(param2, TypeVar):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid type parameter substitution, expected TypeVar, got {param2}"
|
||||||
|
)
|
||||||
|
params2.append(param2)
|
||||||
|
return GenericType(
|
||||||
|
name=name,
|
||||||
|
params=params2,
|
||||||
|
body=substitute_typevars(body, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return TupleType(
|
||||||
|
items=tuple(substitute_typevars(item, substitutions) for item in items),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=items_type):
|
||||||
|
return ColumnType(
|
||||||
|
type=substitute_typevars(items_type, substitutions),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
return DataFrameType(
|
||||||
|
columns=list(map(sub_column, columns)),
|
||||||
|
)
|
||||||
|
|
||||||
case UnknownType() | UnitType():
|
case UnknownType() | UnitType():
|
||||||
return type
|
return type
|
||||||
|
|
||||||
case _:
|
case TopType() | GenericType():
|
||||||
|
|
||||||
raise NotImplementedError(f"Unsupported type {type}")
|
raise NotImplementedError(f"Unsupported type {type}")
|
||||||
|
|
||||||
|
# Ensure exhaustiveness
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
|
||||||
def unfold_type(type: Type) -> Type:
|
def unfold_type(type: Type) -> Type:
|
||||||
match type:
|
match type:
|
||||||
@@ -215,6 +323,74 @@ def unfold_type(type: Type) -> Type:
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
|
|
||||||
|
def to_annotation(type: Type) -> str:
|
||||||
|
def _args_annotation(func: Function) -> str:
|
||||||
|
if len(func.kw_args) != 0:
|
||||||
|
return "..."
|
||||||
|
|
||||||
|
args: str = ", ".join(
|
||||||
|
to_annotation(arg.type) for arg in func.pos_args + func.args
|
||||||
|
)
|
||||||
|
return f"[{args}]"
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case TopType():
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
case BaseType(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case AliasType(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case UnknownType():
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
case UnitType():
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
case Function(returns=returns):
|
||||||
|
params_annot: str = _args_annotation(type)
|
||||||
|
return f"Callable[{params_annot}, {to_annotation(returns)}]"
|
||||||
|
|
||||||
|
case OverloadedFunction():
|
||||||
|
return "Callable"
|
||||||
|
|
||||||
|
case ComplexType() | ExtensionType():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case TypeVar(name=name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
case GenericType(name=name, params=params):
|
||||||
|
return f"{name}[{', '.join(map(to_annotation, params))}]"
|
||||||
|
|
||||||
|
case AppliedType(name=name, args=args):
|
||||||
|
return f"{name}[{', '.join(map(to_annotation, args))}]"
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return str(type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return f"Tuple[{', '.join(map(to_annotation, items))}]"
|
||||||
|
|
||||||
|
case ColumnType():
|
||||||
|
return "pd.Series"
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
return "pd.DataFrame"
|
||||||
|
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Predicate:
|
||||||
|
type: Type
|
||||||
|
body: m.Expr
|
||||||
|
alias: bool
|
||||||
|
|
||||||
|
|
||||||
Type = (
|
Type = (
|
||||||
TopType
|
TopType
|
||||||
| BaseType
|
| BaseType
|
||||||
@@ -228,4 +404,8 @@ Type = (
|
|||||||
| TypeVar
|
| TypeVar
|
||||||
| GenericType
|
| GenericType
|
||||||
| AppliedType
|
| AppliedType
|
||||||
|
| ConstraintType
|
||||||
|
| TupleType
|
||||||
|
| ColumnType
|
||||||
|
| DataFrameType
|
||||||
)
|
)
|
||||||
|
|||||||
169
midas/checker/unifier.py
Normal file
169
midas/checker/unifier.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
AppliedType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
TopType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnificationError(Exception): ...
|
||||||
|
|
||||||
|
|
||||||
|
class Unifier:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Unifier")
|
||||||
|
|
||||||
|
def unify_call(
|
||||||
|
self,
|
||||||
|
type: GenericType,
|
||||||
|
positional: list[Type],
|
||||||
|
keywords: dict[str, Type],
|
||||||
|
) -> Optional[Type]:
|
||||||
|
concrete_func: Function = Function(
|
||||||
|
pos_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=i,
|
||||||
|
name=str(i),
|
||||||
|
type=arg,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
for i, arg in enumerate(positional)
|
||||||
|
],
|
||||||
|
args=[],
|
||||||
|
kw_args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=len(positional) + i,
|
||||||
|
name=name,
|
||||||
|
type=arg,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
for i, (name, arg) in enumerate(keywords.items())
|
||||||
|
],
|
||||||
|
returns=TopType(), # TODO: use expected type
|
||||||
|
)
|
||||||
|
return self.unify_generic(type, concrete_func, match_return=False)
|
||||||
|
|
||||||
|
def unify_generic(
|
||||||
|
self,
|
||||||
|
template: GenericType,
|
||||||
|
concrete: Type,
|
||||||
|
match_return: bool = True,
|
||||||
|
) -> Optional[Type]:
|
||||||
|
substitutions: dict[str, Type]
|
||||||
|
try:
|
||||||
|
substitutions = self.match(template.body, concrete, match_return)
|
||||||
|
except UnificationError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
args: list[Type] = []
|
||||||
|
for param in template.params:
|
||||||
|
if param.name not in substitutions:
|
||||||
|
return None
|
||||||
|
args.append(substitutions[param.name])
|
||||||
|
|
||||||
|
applied: Type = self.types.apply_generic(template, args)
|
||||||
|
return applied
|
||||||
|
|
||||||
|
def match(
|
||||||
|
self,
|
||||||
|
template: Type,
|
||||||
|
concrete: Type,
|
||||||
|
match_return: bool = True,
|
||||||
|
) -> dict[str, Type]:
|
||||||
|
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
|
||||||
|
# substitutions, check that the constraint is respected
|
||||||
|
match (template, concrete):
|
||||||
|
case (TypeVar(name=name), _):
|
||||||
|
return {name: concrete}
|
||||||
|
|
||||||
|
case (
|
||||||
|
AppliedType(name=template_name, args=template_args),
|
||||||
|
AppliedType(name=concrete_name, args=concrete_args),
|
||||||
|
) if template_name == concrete_name and len(template_args) == len(
|
||||||
|
concrete_args
|
||||||
|
):
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for template_arg, concrete_arg in zip(template_args, concrete_args):
|
||||||
|
new_substistutions: dict[str, Type] = self.match(
|
||||||
|
template_arg, concrete_arg
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, new_substistutions)
|
||||||
|
|
||||||
|
return substitutions
|
||||||
|
|
||||||
|
case (Function(), Function()):
|
||||||
|
mapped: list[tuple[Function.Argument, Function.Argument]] = (
|
||||||
|
self.map_params(template, concrete)
|
||||||
|
)
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
for template_arg, concrete_arg in mapped:
|
||||||
|
arg_subs: dict[str, Type] = self.match(
|
||||||
|
template_arg.type, concrete_arg.type
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, arg_subs)
|
||||||
|
|
||||||
|
if match_return:
|
||||||
|
return_subs: dict[str, Type] = self.match(
|
||||||
|
template.returns, concrete.returns
|
||||||
|
)
|
||||||
|
substitutions = self.merge(substitutions, return_subs)
|
||||||
|
|
||||||
|
return substitutions
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
|
||||||
|
merged: dict[str, Type] = subs1.copy()
|
||||||
|
|
||||||
|
for k, v in subs2.items():
|
||||||
|
if k in merged and merged[k] != v:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
|
||||||
|
)
|
||||||
|
raise UnificationError
|
||||||
|
merged[k] = v
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def map_params(
|
||||||
|
self, func1: Function, func2: Function
|
||||||
|
) -> list[tuple[Function.Argument, Function.Argument]]:
|
||||||
|
pos1: list[Function.Argument] = func1.pos_args
|
||||||
|
mixed1: list[Function.Argument] = func1.args
|
||||||
|
kw1: list[Function.Argument] = func1.kw_args
|
||||||
|
|
||||||
|
pos2: list[Function.Argument] = func2.pos_args
|
||||||
|
mixed2: list[Function.Argument] = func2.args
|
||||||
|
kw2: list[Function.Argument] = func2.kw_args
|
||||||
|
|
||||||
|
mapped: list[tuple[Function.Argument, Function.Argument]] = []
|
||||||
|
|
||||||
|
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
|
||||||
|
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
|
||||||
|
|
||||||
|
for arg1 in pos1:
|
||||||
|
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
for arg1 in mixed1:
|
||||||
|
# Match both positionally and by name, conflicts are caught
|
||||||
|
# when merging substitutions
|
||||||
|
if (arg2 := by_pos2.get(arg1.pos)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
if (arg2 := by_name2.get(arg1.name)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
for arg1 in kw1:
|
||||||
|
if (arg2 := by_name2.get(arg1.name)) is not None:
|
||||||
|
mapped.append((arg1, arg2))
|
||||||
|
|
||||||
|
return mapped
|
||||||
129
midas/checker/variance.py
Normal file
129
midas/checker/variance.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
from typing import Literal, Optional, cast
|
||||||
|
|
||||||
|
from midas.checker.registry import Member, TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
AppliedType,
|
||||||
|
ConstraintType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Variance,
|
||||||
|
)
|
||||||
|
|
||||||
|
Polarity = Literal[-1, 0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
class Tracker:
|
||||||
|
def __init__(self, vars: list[TypeVar]) -> None:
|
||||||
|
self.vars: list[TypeVar] = vars
|
||||||
|
self.refs: dict[str, set[Polarity]] = {var.name: set() for var in self.vars}
|
||||||
|
|
||||||
|
def record(self, var: TypeVar, polarity: Polarity):
|
||||||
|
self.refs[var.name].add(polarity)
|
||||||
|
|
||||||
|
def get_updated_vars(self) -> list[TypeVar]:
|
||||||
|
return [
|
||||||
|
TypeVar(
|
||||||
|
name=var.name, bound=var.bound, variance=self.get_variance(var.name)
|
||||||
|
)
|
||||||
|
for var in self.vars
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_variance(self, name: str) -> Variance:
|
||||||
|
refs: set[Polarity] = self.refs[name]
|
||||||
|
if refs == {-1}:
|
||||||
|
return Variance.CONTRAVARIANT
|
||||||
|
if refs == {1}:
|
||||||
|
return Variance.COVARIANT
|
||||||
|
return Variance.INVARIANT
|
||||||
|
|
||||||
|
def __contains__(self, item: TypeVar | str):
|
||||||
|
if isinstance(item, TypeVar):
|
||||||
|
return item.name in self
|
||||||
|
return item in self.refs
|
||||||
|
|
||||||
|
|
||||||
|
class VarianceInferrer:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.tracker: Tracker = Tracker([])
|
||||||
|
|
||||||
|
def infer(self, type: GenericType) -> GenericType:
|
||||||
|
self.tracker = Tracker(type.params)
|
||||||
|
|
||||||
|
self.walk(type.body, 1, type.name)
|
||||||
|
members: dict[str, Member] = self.types._members.get(type.name, {})
|
||||||
|
for name, member in members.items():
|
||||||
|
self.walk(member.type, 1, type.name, [f"member:'{name}'"])
|
||||||
|
|
||||||
|
return GenericType(
|
||||||
|
name=type.name,
|
||||||
|
params=self.tracker.get_updated_vars(),
|
||||||
|
body=type.body,
|
||||||
|
)
|
||||||
|
|
||||||
|
def walk(
|
||||||
|
self,
|
||||||
|
type: Type,
|
||||||
|
polarity: Polarity,
|
||||||
|
base_name: str,
|
||||||
|
path: Optional[list[str]] = None,
|
||||||
|
):
|
||||||
|
if path is None:
|
||||||
|
path = []
|
||||||
|
|
||||||
|
match type:
|
||||||
|
# Arguments are negative positions -> flip polarity
|
||||||
|
# Return is positive position -> keep polarity
|
||||||
|
case Function(pos_args=pos_args, args=mixed_args, kw_args=kw_args):
|
||||||
|
all_args: list[Function.Argument] = pos_args + mixed_args + kw_args
|
||||||
|
for arg in all_args:
|
||||||
|
self.walk(
|
||||||
|
arg.type,
|
||||||
|
-polarity,
|
||||||
|
base_name,
|
||||||
|
path + [f"arg:'{arg.name}'"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.walk(type.returns, polarity, base_name, path + ["return"])
|
||||||
|
|
||||||
|
# Walk all overloads
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
for overload in overloads:
|
||||||
|
self.walk(overload, polarity, base_name, path)
|
||||||
|
|
||||||
|
# If same name as root generic -> skip
|
||||||
|
# Get inferred variance of parameters and multiply with current
|
||||||
|
# polarity to recurse through arguments
|
||||||
|
case AppliedType(name=name, args=args):
|
||||||
|
# TODO: handle mutually recursive types
|
||||||
|
if name == base_name:
|
||||||
|
return
|
||||||
|
generic: Type = self.types.get_type(name)
|
||||||
|
assert isinstance(generic, GenericType)
|
||||||
|
params: list[TypeVar] = generic.params
|
||||||
|
polarities: dict[Variance, Polarity] = {
|
||||||
|
Variance.INVARIANT: 0,
|
||||||
|
Variance.COVARIANT: 1,
|
||||||
|
Variance.CONTRAVARIANT: -1,
|
||||||
|
}
|
||||||
|
for arg, param in zip(args, params):
|
||||||
|
param_polarity: Polarity = polarities[param.variance]
|
||||||
|
self.walk(
|
||||||
|
arg,
|
||||||
|
cast(Polarity, polarity * param_polarity),
|
||||||
|
base_name,
|
||||||
|
path + [f"applied:'{name}'"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Walk base type
|
||||||
|
case ConstraintType(type=base):
|
||||||
|
self.walk(base, polarity, base_name, path + ["constraint"])
|
||||||
|
|
||||||
|
# Reached end
|
||||||
|
# If tracked, record polarity
|
||||||
|
case TypeVar():
|
||||||
|
if type in self.tracker:
|
||||||
|
self.tracker.record(type, polarity)
|
||||||
@@ -4,5 +4,6 @@ from .format import format as format
|
|||||||
from .highlight import highlight as highlight
|
from .highlight import highlight as highlight
|
||||||
from .parse import parse as parse
|
from .parse import parse as parse
|
||||||
from .registry import dump_registry as dump_registry
|
from .registry import dump_registry as dump_registry
|
||||||
|
from .stubs import stubs as stubs
|
||||||
from .types import types as types
|
from .types import types as types
|
||||||
from .validate import validate as validate
|
from .validate import validate as validate
|
||||||
|
|||||||
@@ -19,9 +19,11 @@ from midas.utils import TypedAST
|
|||||||
@click.command(help="Compile source")
|
@click.command(help="Compile source")
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
@click.option("-t", "--types", type=click.File("r"), multiple=True)
|
||||||
|
@click.option("--ignore-errors", is_flag=True)
|
||||||
def compile(
|
def compile(
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
types: tuple[TextIO],
|
types: tuple[TextIO],
|
||||||
|
ignore_errors: bool,
|
||||||
):
|
):
|
||||||
source: str = file.read()
|
source: str = file.read()
|
||||||
source_path: Path = Path(file.name).resolve()
|
source_path: Path = Path(file.name).resolve()
|
||||||
@@ -35,8 +37,10 @@ def compile(
|
|||||||
printer = DiagnosticPrinter()
|
printer = DiagnosticPrinter()
|
||||||
printer.print_all(diagnostics)
|
printer.print_all(diagnostics)
|
||||||
|
|
||||||
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
|
if not ignore_errors and any(
|
||||||
|
map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)
|
||||||
|
):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
generator = Generator(workdir=source_path.parent)
|
generator = Generator(workdir=source_path.parent, types=checker.types)
|
||||||
generator.generate(typed_ast, source_path)
|
generator.generate(typed_ast, source_path)
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from typing import TextIO
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.checker import TypeChecker
|
from midas.checker.checker import TypeChecker
|
||||||
|
from midas.checker.registry import Member
|
||||||
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
from midas.checker.types import AliasType, AppliedType, BaseType, GenericType, Type
|
||||||
|
|
||||||
|
|
||||||
@@ -35,10 +37,30 @@ def dump_registry(
|
|||||||
for types_file in types:
|
for types_file in types:
|
||||||
checker.import_midas(Path(types_file.name).resolve())
|
checker.import_midas(Path(types_file.name).resolve())
|
||||||
|
|
||||||
|
print("##### Types #####")
|
||||||
for name, type in checker.types._types.items():
|
for name, type in checker.types._types.items():
|
||||||
members: dict[str, Type] = checker.types._members.get(name, {})
|
members: dict[str, Member] = checker.types._members.get(name, {})
|
||||||
print(f"{name} = {base_type(type)}")
|
params: str = ""
|
||||||
|
if isinstance(type, GenericType):
|
||||||
|
params = ", ".join(map(str, type.params))
|
||||||
|
params = f"[{params}]"
|
||||||
|
print(f"{name}{params} = {base_type(type)}")
|
||||||
if len(members) != 0:
|
if len(members) != 0:
|
||||||
print(" " * 4 + "Members:")
|
print(" " * 4 + "Members:")
|
||||||
for member_name, member_type in members.items():
|
for member_name, member in members.items():
|
||||||
print(" " * 8 + f"{member_name}: {member_type}")
|
kind: str = member.kind.name
|
||||||
|
print(" " * 8 + f"({kind:8}) {member_name}: {member.type}")
|
||||||
|
|
||||||
|
print("##### Predicates #####")
|
||||||
|
printer = MidasPrinter()
|
||||||
|
for name, predicate in checker.types._predicates.items():
|
||||||
|
body: str = printer.print(predicate.body)
|
||||||
|
if predicate.alias:
|
||||||
|
print(f"{name}: {predicate.type} = {body}")
|
||||||
|
else:
|
||||||
|
print(f"{name}{predicate.type}:")
|
||||||
|
body = "\n".join(
|
||||||
|
" " + ("return " if i == 0 else "") + line
|
||||||
|
for i, line in enumerate(body.split("\n"))
|
||||||
|
)
|
||||||
|
print(body)
|
||||||
|
|||||||
64
midas/cli/commands/stubs.py
Normal file
64
midas/cli/commands/stubs.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import ast
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
|
import black
|
||||||
|
import click
|
||||||
|
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
|
||||||
|
from watchdog.observers import Observer
|
||||||
|
|
||||||
|
from midas.checker.checker import TypeChecker
|
||||||
|
from midas.generator.stubs import StubsGenerator
|
||||||
|
|
||||||
|
|
||||||
|
def generate_stubs(in_path: Path, out_path: Path):
|
||||||
|
checker = TypeChecker()
|
||||||
|
checker.import_midas(in_path)
|
||||||
|
|
||||||
|
generator = StubsGenerator(checker.types)
|
||||||
|
module: ast.Module = generator.generate_stubs()
|
||||||
|
module = ast.fix_missing_locations(module)
|
||||||
|
|
||||||
|
output: str = ast.unparse(module)
|
||||||
|
output = black.format_str(output, mode=black.Mode(is_pyi=True))
|
||||||
|
|
||||||
|
out_path.write_text(output)
|
||||||
|
|
||||||
|
|
||||||
|
class Handler(FileSystemEventHandler):
|
||||||
|
def __init__(self, in_path: Path, out_path: Path) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.in_path: Path = in_path
|
||||||
|
self.out_path: Path = out_path
|
||||||
|
|
||||||
|
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
|
||||||
|
generate_stubs(self.in_path, self.out_path)
|
||||||
|
|
||||||
|
|
||||||
|
@click.command(help="Generate stubs from Midas definitions")
|
||||||
|
@click.argument("file", type=click.File("r"))
|
||||||
|
@click.option("-o", "--output", type=click.File("w"), default="-")
|
||||||
|
@click.option("-w", "--watch", is_flag=True)
|
||||||
|
def stubs(
|
||||||
|
file: TextIO,
|
||||||
|
output: TextIO,
|
||||||
|
watch: bool,
|
||||||
|
):
|
||||||
|
source_path: Path = Path(file.name).resolve()
|
||||||
|
out_path: Path = Path(output.name).resolve()
|
||||||
|
generate_stubs(source_path, out_path)
|
||||||
|
|
||||||
|
if watch:
|
||||||
|
print(f"Watching {source_path}...")
|
||||||
|
print("Press CTRL+C to stop")
|
||||||
|
handler = Handler(source_path, out_path)
|
||||||
|
observer = Observer()
|
||||||
|
observer.schedule(handler, str(source_path))
|
||||||
|
observer.start()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
observer.stop()
|
||||||
|
observer.join()
|
||||||
@@ -41,6 +41,7 @@ def types(
|
|||||||
message=f"Type: {type}",
|
message=f"Type: {type}",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
diagnostics.extend(checker.diagnostics)
|
||||||
printer = DiagnosticPrinter()
|
printer = DiagnosticPrinter()
|
||||||
printer.print_all(diagnostics)
|
printer.print_all(diagnostics)
|
||||||
|
|
||||||
|
|||||||
@@ -134,9 +134,9 @@ class PythonHighlighter(
|
|||||||
|
|
||||||
def visit_base_type(self, node: p.BaseType) -> None:
|
def visit_base_type(self, node: p.BaseType) -> None:
|
||||||
self.wrap(node, "base-type")
|
self.wrap(node, "base-type")
|
||||||
if node.param is not None:
|
for arg in node.args:
|
||||||
self.wrap(node.param, "param")
|
self.wrap(arg, "arg")
|
||||||
node.param.accept(self)
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
def visit_constraint_type(self, node: p.ConstraintType) -> None:
|
||||||
self.wrap(node, "constraint-type")
|
self.wrap(node, "constraint-type")
|
||||||
@@ -228,6 +228,13 @@ class PythonHighlighter(
|
|||||||
for item in expr.items:
|
for item in expr.items:
|
||||||
item.accept(self)
|
item.accept(self)
|
||||||
|
|
||||||
|
def visit_dict_expr(self, expr: p.DictExpr) -> None:
|
||||||
|
for key in expr.keys:
|
||||||
|
if key is not None:
|
||||||
|
key.accept(self)
|
||||||
|
for value in expr.values:
|
||||||
|
value.accept(self)
|
||||||
|
|
||||||
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
|
||||||
expr.object.accept(self)
|
expr.object.accept(self)
|
||||||
expr.index.accept(self)
|
expr.index.accept(self)
|
||||||
@@ -240,6 +247,14 @@ class PythonHighlighter(
|
|||||||
if expr.step is not None:
|
if expr.step is not None:
|
||||||
expr.step.accept(self)
|
expr.step.accept(self)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> None:
|
||||||
|
for item in expr.items:
|
||||||
|
item.accept(self)
|
||||||
|
|
||||||
|
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
|
||||||
|
|
||||||
|
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class MidasHighlighter(
|
class MidasHighlighter(
|
||||||
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
|
||||||
@@ -266,8 +281,9 @@ class MidasHighlighter(
|
|||||||
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
|
||||||
self.wrap(stmt, "predicate")
|
self.wrap(stmt, "predicate")
|
||||||
self.wrap(LocatableToken(stmt.name), "predicate-name")
|
self.wrap(LocatableToken(stmt.name), "predicate-name")
|
||||||
stmt.type.accept(self)
|
for spec in stmt.params:
|
||||||
stmt.condition.accept(self)
|
self._visit_param_spec(spec)
|
||||||
|
stmt.body.accept(self)
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
|
||||||
self.wrap(expr, "logical-expr")
|
self.wrap(expr, "logical-expr")
|
||||||
@@ -283,6 +299,14 @@ class MidasHighlighter(
|
|||||||
self.wrap(expr, "unary-expr")
|
self.wrap(expr, "unary-expr")
|
||||||
expr.right.accept(self)
|
expr.right.accept(self)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> None:
|
||||||
|
self.wrap(expr, "call-expr")
|
||||||
|
expr.callee.accept(self)
|
||||||
|
for arg in expr.arguments:
|
||||||
|
arg.accept(self)
|
||||||
|
for arg in expr.keywords.values():
|
||||||
|
arg.accept(self)
|
||||||
|
|
||||||
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
def visit_get_expr(self, expr: m.GetExpr) -> None:
|
||||||
self.wrap(expr, "get-expr")
|
self.wrap(expr, "get-expr")
|
||||||
expr.expr.accept(self)
|
expr.expr.accept(self)
|
||||||
@@ -318,8 +342,7 @@ class MidasHighlighter(
|
|||||||
|
|
||||||
def visit_function_type(self, type: m.FunctionType) -> None:
|
def visit_function_type(self, type: m.FunctionType) -> None:
|
||||||
self.wrap(type, "function")
|
self.wrap(type, "function")
|
||||||
for arg in type.pos_args + type.args + type.kw_args:
|
self._visit_param_spec(type.params)
|
||||||
arg.type.accept(self)
|
|
||||||
type.returns.accept(self)
|
type.returns.accept(self)
|
||||||
|
|
||||||
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
def visit_extension_type(self, type: m.ExtensionType) -> None:
|
||||||
@@ -327,6 +350,18 @@ class MidasHighlighter(
|
|||||||
type.base.accept(self)
|
type.base.accept(self)
|
||||||
type.extension.accept(self)
|
type.extension.accept(self)
|
||||||
|
|
||||||
|
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
|
||||||
|
for param in spec.pos + spec.mixed + spec.kw:
|
||||||
|
param.type.accept(self)
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: m.FrameType) -> None:
|
||||||
|
self.wrap(type, "frame")
|
||||||
|
for column in type.columns:
|
||||||
|
self._visit_frame_column(column)
|
||||||
|
|
||||||
|
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
|
||||||
|
self.wrap(column, "column")
|
||||||
|
|
||||||
|
|
||||||
class DiagnosticsHighlighter(Highlighter):
|
class DiagnosticsHighlighter(Highlighter):
|
||||||
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ span {
|
|||||||
--col: 108, 233, 108;
|
--col: 108, 233, 108;
|
||||||
}
|
}
|
||||||
|
|
||||||
&.param {
|
&.arg {
|
||||||
--col: 103, 192, 224;
|
--col: 103, 192, 224;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ midas.add_command(commands.highlight)
|
|||||||
midas.add_command(commands.parse)
|
midas.add_command(commands.parse)
|
||||||
midas.add_command(commands.dump_registry)
|
midas.add_command(commands.dump_registry)
|
||||||
midas.add_command(commands.types)
|
midas.add_command(commands.types)
|
||||||
|
midas.add_command(commands.stubs)
|
||||||
midas.add_command(commands.validate)
|
midas.add_command(commands.validate)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -7,6 +8,13 @@ from midas.cli.ansi import Ansi
|
|||||||
|
|
||||||
|
|
||||||
class DiagnosticPrinter:
|
class DiagnosticPrinter:
|
||||||
|
COLORS: dict[DiagnosticType, int] = {
|
||||||
|
DiagnosticType.ERROR: Ansi.RED,
|
||||||
|
DiagnosticType.WARNING: Ansi.YELLOW,
|
||||||
|
DiagnosticType.INFO: Ansi.CYAN,
|
||||||
|
DiagnosticType.DEBUG: Ansi.MAGENTA,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.files: dict[Optional[str], list[str]] = {}
|
self.files: dict[Optional[str], list[str]] = {}
|
||||||
|
|
||||||
@@ -22,10 +30,25 @@ class DiagnosticPrinter:
|
|||||||
return self.files[filename]
|
return self.files[filename]
|
||||||
|
|
||||||
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
|
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
|
||||||
|
by_type: dict[DiagnosticType, int] = defaultdict(int)
|
||||||
for diagnostic in diagnostics:
|
for diagnostic in diagnostics:
|
||||||
filename: Optional[str] = diagnostic.file_path
|
filename: Optional[str] = diagnostic.file_path
|
||||||
lines = self.get_lines(filename)
|
lines = self.get_lines(filename)
|
||||||
self.print(lines, diagnostic, indent=indent)
|
self.print(lines, diagnostic, indent=indent)
|
||||||
|
by_type[diagnostic.type] += 1
|
||||||
|
|
||||||
|
if len(diagnostics) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
counts: list[str] = []
|
||||||
|
for type in DiagnosticType:
|
||||||
|
if type not in by_type:
|
||||||
|
continue
|
||||||
|
count: int = by_type[type]
|
||||||
|
color: int = self.COLORS.get(type, Ansi.WHITE)
|
||||||
|
counts.append(f"{Ansi.FG(color)}{type.value}s{Ansi.RESET}: {count}")
|
||||||
|
|
||||||
|
print(" ".join(counts))
|
||||||
|
|
||||||
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
|
||||||
"""Pretty-print a diagnostic, showing some context if possible
|
"""Pretty-print a diagnostic, showing some context if possible
|
||||||
@@ -45,7 +68,7 @@ class DiagnosticPrinter:
|
|||||||
|
|
||||||
loc: Location = diagnostic.location
|
loc: Location = diagnostic.location
|
||||||
if loc.lineno != loc.end_lineno:
|
if loc.lineno != loc.end_lineno:
|
||||||
print(diagnostic)
|
self.print_multiline(lines, diagnostic, indent)
|
||||||
return
|
return
|
||||||
|
|
||||||
start_offset: int = loc.col_offset
|
start_offset: int = loc.col_offset
|
||||||
@@ -55,11 +78,7 @@ class DiagnosticPrinter:
|
|||||||
before: str = line[:start_offset]
|
before: str = line[:start_offset]
|
||||||
after: str = line[end_offset:]
|
after: str = line[end_offset:]
|
||||||
|
|
||||||
color: int = {
|
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
|
||||||
DiagnosticType.ERROR: Ansi.RED,
|
|
||||||
DiagnosticType.WARNING: Ansi.YELLOW,
|
|
||||||
DiagnosticType.INFO: Ansi.CYAN,
|
|
||||||
}.get(diagnostic.type, Ansi.WHITE)
|
|
||||||
|
|
||||||
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
|
||||||
cursor: str = (
|
cursor: str = (
|
||||||
@@ -76,3 +95,27 @@ class DiagnosticPrinter:
|
|||||||
print(indent_str + before + subject + after)
|
print(indent_str + before + subject + after)
|
||||||
print(indent_str + cursor)
|
print(indent_str + cursor)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
def print_multiline(
|
||||||
|
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
|
||||||
|
):
|
||||||
|
loc: Location = diagnostic.location
|
||||||
|
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
|
||||||
|
|
||||||
|
start_offset: int = loc.col_offset
|
||||||
|
end_offset: int = loc.end_col_offset or (start_offset + 1)
|
||||||
|
|
||||||
|
indent_str: str = " " * indent
|
||||||
|
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
|
||||||
|
res: str = indent_str + lines[0][:start_offset]
|
||||||
|
res += Ansi.FG(color) + lines[0][start_offset:]
|
||||||
|
for line in lines[1:-1]:
|
||||||
|
res += "\n" + indent_str + line
|
||||||
|
res += "\n" + indent_str + lines[-1][:end_offset]
|
||||||
|
res += Ansi.RESET + lines[-1][end_offset:]
|
||||||
|
|
||||||
|
print(diagnostic.location_str + ":")
|
||||||
|
print(res)
|
||||||
|
print()
|
||||||
|
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
|
||||||
|
print()
|
||||||
|
|||||||
224
midas/generator/constraints.py
Normal file
224
midas/generator/constraints.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
Function,
|
||||||
|
Predicate,
|
||||||
|
Type,
|
||||||
|
to_annotation,
|
||||||
|
)
|
||||||
|
from midas.lexer.token import TokenType
|
||||||
|
|
||||||
|
LOGICAL_OPERATORS: dict[TokenType, type[ast.boolop]] = {
|
||||||
|
TokenType.AND: ast.And,
|
||||||
|
# TokenType.OR: ast.Or,
|
||||||
|
}
|
||||||
|
|
||||||
|
BINARY_OPERATORS: dict[TokenType, type[ast.operator]] = {
|
||||||
|
# TokenType.PLUS: ast.Add,
|
||||||
|
TokenType.MINUS: ast.Sub,
|
||||||
|
TokenType.STAR: ast.Mult,
|
||||||
|
TokenType.SLASH: ast.Div,
|
||||||
|
}
|
||||||
|
|
||||||
|
UNARY_OPERATORS: dict[TokenType, type[ast.unaryop]] = {
|
||||||
|
# TokenType.PLUS: ast.UAdd,
|
||||||
|
TokenType.MINUS: ast.USub,
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPARISON_OPERATORS: dict[TokenType, type[ast.cmpop]] = {
|
||||||
|
TokenType.GREATER: ast.Gt,
|
||||||
|
TokenType.GREATER_EQUAL: ast.GtE,
|
||||||
|
TokenType.LESS: ast.Lt,
|
||||||
|
TokenType.LESS_EQUAL: ast.LtE,
|
||||||
|
TokenType.EQUAL_EQUAL: ast.Eq,
|
||||||
|
TokenType.BANG_EQUAL: ast.NotEq,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ConstraintGenerator(m.Expr.Visitor[ast.expr]):
|
||||||
|
def __init__(self, types: TypesRegistry):
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self._id: int = 0
|
||||||
|
self._definitions: list[ast.stmt] = []
|
||||||
|
self._aliases: dict[str, str] = {}
|
||||||
|
|
||||||
|
def get_definitions(self) -> list[ast.stmt]:
|
||||||
|
return self._definitions
|
||||||
|
|
||||||
|
def generate(self, expr: m.Expr) -> ast.expr:
|
||||||
|
match expr:
|
||||||
|
case m.VariableExpr():
|
||||||
|
return expr.accept(self)
|
||||||
|
case _:
|
||||||
|
func = Function(
|
||||||
|
pos_args=[],
|
||||||
|
args=[
|
||||||
|
Function.Argument(
|
||||||
|
pos=0,
|
||||||
|
name="_",
|
||||||
|
type=self.types.get_type("Any"),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
kw_args=[],
|
||||||
|
returns=self.types.get_type("bool"),
|
||||||
|
)
|
||||||
|
alias: str = self.make_alias(None)
|
||||||
|
definition: ast.stmt = self.make_definition(
|
||||||
|
alias, Predicate(type=func, body=expr, alias=False)
|
||||||
|
)
|
||||||
|
self._definitions.append(definition)
|
||||||
|
return ast.Name(id=alias)
|
||||||
|
|
||||||
|
def make_alias(self, name: Optional[str]) -> str:
|
||||||
|
suffix: str
|
||||||
|
if name is None:
|
||||||
|
suffix = f"p{self._id}"
|
||||||
|
self._id += 1
|
||||||
|
else:
|
||||||
|
suffix = name
|
||||||
|
alias: str = f"__midas_{suffix}__"
|
||||||
|
return alias
|
||||||
|
|
||||||
|
def make_definition(self, name: str, predicate: Predicate) -> ast.stmt:
|
||||||
|
body: ast.expr = predicate.body.accept(self)
|
||||||
|
if predicate.alias:
|
||||||
|
return ast.Assign(
|
||||||
|
targets=[
|
||||||
|
ast.Name(id=name),
|
||||||
|
],
|
||||||
|
value=body,
|
||||||
|
)
|
||||||
|
return self.make_func(name, [ast.Return(value=body)], predicate.type)
|
||||||
|
|
||||||
|
def make_args(self, func: Function) -> ast.arguments:
|
||||||
|
return ast.arguments(
|
||||||
|
posonlyargs=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.pos_args
|
||||||
|
],
|
||||||
|
args=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.args
|
||||||
|
],
|
||||||
|
kwonlyargs=[
|
||||||
|
ast.arg(
|
||||||
|
arg=arg.name,
|
||||||
|
annotation=ast.Constant(value=to_annotation(arg.type)),
|
||||||
|
)
|
||||||
|
for arg in func.kw_args
|
||||||
|
],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_func(
|
||||||
|
self, name: str, inner_body: list[ast.stmt], type: Type, level: int = 0
|
||||||
|
) -> ast.stmt:
|
||||||
|
match type:
|
||||||
|
case Function(returns=Function()):
|
||||||
|
inner_name: str = f"inner{level}"
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=self.make_args(type),
|
||||||
|
body=[
|
||||||
|
self.make_func(inner_name, inner_body, type.returns, level + 1),
|
||||||
|
ast.Return(value=ast.Name(id=inner_name)),
|
||||||
|
],
|
||||||
|
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
case Function():
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=self.make_args(type),
|
||||||
|
body=inner_body,
|
||||||
|
returns=ast.Constant(value=to_annotation(type.returns)),
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Expected function, got {type!r}")
|
||||||
|
|
||||||
|
def get_predicate(self, name: str) -> Optional[ast.expr]:
|
||||||
|
if name not in self._aliases:
|
||||||
|
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
|
||||||
|
if predicate is None:
|
||||||
|
return None
|
||||||
|
alias: str = self.make_alias(name)
|
||||||
|
self._aliases[name] = alias
|
||||||
|
self._definitions.append(self.make_definition(alias, predicate))
|
||||||
|
|
||||||
|
return ast.Name(id=self._aliases[name])
|
||||||
|
|
||||||
|
def visit_logical_expr(self, expr: m.LogicalExpr) -> ast.expr:
|
||||||
|
return ast.BoolOp(
|
||||||
|
op=LOGICAL_OPERATORS[expr.operator.type](),
|
||||||
|
values=[
|
||||||
|
expr.left.accept(self),
|
||||||
|
expr.right.accept(self),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_binary_expr(self, expr: m.BinaryExpr) -> ast.expr:
|
||||||
|
op: TokenType = expr.operator.type
|
||||||
|
if op in BINARY_OPERATORS:
|
||||||
|
return ast.BinOp(
|
||||||
|
left=expr.left.accept(self),
|
||||||
|
op=BINARY_OPERATORS[op](),
|
||||||
|
right=expr.right.accept(self),
|
||||||
|
)
|
||||||
|
if op in COMPARISON_OPERATORS:
|
||||||
|
return ast.Compare(
|
||||||
|
left=expr.left.accept(self),
|
||||||
|
ops=[COMPARISON_OPERATORS[op]()],
|
||||||
|
comparators=[expr.right.accept(self)],
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unexpected binary operator {op}")
|
||||||
|
|
||||||
|
def visit_unary_expr(self, expr: m.UnaryExpr) -> ast.expr:
|
||||||
|
return ast.UnaryOp(
|
||||||
|
op=UNARY_OPERATORS[expr.operator.type](),
|
||||||
|
operand=expr.right.accept(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: m.CallExpr) -> ast.expr:
|
||||||
|
return ast.Call(
|
||||||
|
func=expr.callee.accept(self),
|
||||||
|
args=[arg.accept(self) for arg in expr.arguments],
|
||||||
|
keywords=[
|
||||||
|
ast.keyword(arg=name, value=arg.accept(self))
|
||||||
|
for name, arg in expr.keywords.items()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_get_expr(self, expr: m.GetExpr) -> ast.expr:
|
||||||
|
return ast.Attribute(
|
||||||
|
value=expr.expr.accept(self),
|
||||||
|
attr=expr.name.lexeme,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit_variable_expr(self, expr: m.VariableExpr) -> ast.expr:
|
||||||
|
name: str = expr.name.lexeme
|
||||||
|
if (p := self.get_predicate(name)) is not None:
|
||||||
|
return p
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
def visit_grouping_expr(self, expr: m.GroupingExpr) -> ast.expr:
|
||||||
|
return expr.accept(self)
|
||||||
|
|
||||||
|
def visit_literal_expr(self, expr: m.LiteralExpr) -> ast.expr:
|
||||||
|
return ast.Constant(value=expr.value)
|
||||||
|
|
||||||
|
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> ast.expr:
|
||||||
|
return ast.Name(id="_")
|
||||||
@@ -1,54 +1,83 @@
|
|||||||
import ast
|
import ast
|
||||||
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
|
from midas.ast.printer import MidasPrinter
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
AliasType,
|
AliasType,
|
||||||
AppliedType,
|
AppliedType,
|
||||||
BaseType,
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
Function,
|
Function,
|
||||||
GenericType,
|
GenericType,
|
||||||
OverloadedFunction,
|
OverloadedFunction,
|
||||||
TopType,
|
TopType,
|
||||||
|
TupleType,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
UnitType,
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
)
|
)
|
||||||
|
from midas.generator.constraints import ConstraintGenerator
|
||||||
from midas.utils import TypedAST
|
from midas.utils import TypedAST
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Scope:
|
class Scope:
|
||||||
pre_assertions: list[ast.stmt] = field(default_factory=list)
|
pre_assertions: list[ast.stmt] = field(default_factory=list[ast.stmt])
|
||||||
aliases: list[str] = field(default_factory=list)
|
aliases: list[str] = field(default_factory=list[str])
|
||||||
|
|
||||||
|
|
||||||
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||||
def __init__(self, workdir: Path) -> None:
|
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
|
||||||
|
IS_COLUMN_FUNC = "__midas_is_column__"
|
||||||
|
|
||||||
|
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
|
||||||
self.workdir: Path = workdir.resolve()
|
self.workdir: Path = workdir.resolve()
|
||||||
self.build_dir: Path = self.workdir / "build" / "midas"
|
self.build_dir: Path = self.workdir / "build" / "midas"
|
||||||
if self.build_dir.exists():
|
|
||||||
shutil.rmtree(self.build_dir)
|
|
||||||
self.build_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
self.rel_src_path: Path = Path()
|
self.rel_src_path: Path = Path()
|
||||||
|
self.logger: logging.Logger = logging.getLogger("Generator")
|
||||||
|
|
||||||
self._typed_ast: TypedAST = TypedAST(
|
self._typed_ast: TypedAST = TypedAST(
|
||||||
stmts=[],
|
stmts=[],
|
||||||
judgements=[],
|
judgements=[],
|
||||||
|
evaluated_casts=[],
|
||||||
)
|
)
|
||||||
self._alias_count: int = 0
|
self._alias_count: int = 0
|
||||||
|
self._predicate_count: int = 0
|
||||||
self._scopes: list[Scope] = []
|
self._scopes: list[Scope] = []
|
||||||
|
|
||||||
|
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
|
||||||
|
self._constraints: list[tuple[m.Expr, ast.expr]] = []
|
||||||
|
|
||||||
|
self.define_is_dataframe: bool = False
|
||||||
|
self.define_is_column: bool = False
|
||||||
|
|
||||||
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
|
||||||
self.rel_src_path = src_path.relative_to(self.workdir)
|
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
|
||||||
self._typed_ast = typed_ast
|
self._typed_ast = typed_ast
|
||||||
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
|
||||||
|
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
|
||||||
|
|
||||||
|
body = predicates + body
|
||||||
|
|
||||||
|
if self.define_is_dataframe:
|
||||||
|
body = [self._is_dataframe_definition()] + body
|
||||||
|
|
||||||
|
if self.define_is_column:
|
||||||
|
body = [self._is_column_definition()] + body
|
||||||
|
|
||||||
module = ast.Module(body=body, type_ignores=[])
|
module = ast.Module(body=body, type_ignores=[])
|
||||||
module = ast.fix_missing_locations(module)
|
module = ast.fix_missing_locations(module)
|
||||||
return module
|
return module
|
||||||
@@ -59,6 +88,9 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
module: ast.AST = self.generate_ast(typed_ast, src_path)
|
||||||
compiled: str = ast.unparse(module)
|
compiled: str = ast.unparse(module)
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
|
if self.build_dir.exists():
|
||||||
|
shutil.rmtree(self.build_dir)
|
||||||
|
self.build_dir.mkdir(parents=True, exist_ok=True)
|
||||||
out_path = (self.build_dir / self.rel_src_path).resolve()
|
out_path = (self.build_dir / self.rel_src_path).resolve()
|
||||||
try:
|
try:
|
||||||
_ = out_path.relative_to(self.build_dir)
|
_ = out_path.relative_to(self.build_dir)
|
||||||
@@ -120,10 +152,16 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
|
|
||||||
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
|
||||||
expr2: ast.expr = expr.expr.accept(self)
|
expr2: ast.expr = expr.expr.accept(self)
|
||||||
|
|
||||||
|
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
|
||||||
|
return expr2
|
||||||
|
|
||||||
alias: ast.expr = self._make_alias(expr2)
|
alias: ast.expr = self._make_alias(expr2)
|
||||||
|
|
||||||
type: Type = self._get_expr_type(expr)
|
type: Type = self._get_expr_type(expr)
|
||||||
self._make_cast_asserts(expr.location, alias, type)
|
asserts: list[ast.stmt] = self._make_cast_asserts(expr.location, alias, type)
|
||||||
|
for assert_ in asserts:
|
||||||
|
self._add_assert(assert_)
|
||||||
|
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
@@ -158,6 +196,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
step=expr.step.accept(self) if expr.step is not None else None,
|
step=expr.step.accept(self) if expr.step is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: p.TupleExpr) -> ast.expr:
|
||||||
|
return ast.Tuple(
|
||||||
|
elts=[item.accept(self) for item in expr.items],
|
||||||
|
)
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
def visit_raw_expr(self, expr: p.RawExpr) -> ast.expr:
|
||||||
return expr.expr
|
return expr.expr
|
||||||
|
|
||||||
@@ -246,7 +289,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
return generated
|
return generated
|
||||||
|
|
||||||
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
def _make_alias(self, expr: ast.expr) -> ast.expr:
|
||||||
name: str = f"__midas_alias_{self._alias_count}__"
|
name: str = f"__midas_a{self._alias_count}__"
|
||||||
alias = ast.Name(id=name)
|
alias = ast.Name(id=name)
|
||||||
self._alias_count += 1
|
self._alias_count += 1
|
||||||
self._scopes[-1].aliases.append(name)
|
self._scopes[-1].aliases.append(name)
|
||||||
@@ -258,51 +301,156 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
)
|
)
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
def _add_assert(self, expr: ast.expr, message: str | ast.expr):
|
def _build_assert(self, expr: ast.expr, message: str | ast.expr) -> ast.stmt:
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
message = ast.Constant(value=message)
|
message = ast.Constant(value=message)
|
||||||
self._scopes[-1].pre_assertions.append(
|
return ast.Assert(
|
||||||
ast.Assert(
|
test=expr,
|
||||||
test=expr,
|
msg=message,
|
||||||
msg=message,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_assert(self, assertion: ast.stmt):
|
||||||
|
self._scopes[-1].pre_assertions.append(assertion)
|
||||||
|
|
||||||
def _get_expr_type(self, query: p.Expr) -> Type:
|
def _get_expr_type(self, query: p.Expr) -> Type:
|
||||||
for expr, type in self._typed_ast.judgements:
|
for expr, type in self._typed_ast.judgements:
|
||||||
if expr == query:
|
if expr == query:
|
||||||
return type
|
return type
|
||||||
raise RuntimeError(f"Cannot get type judgement for {query}")
|
raise RuntimeError(f"Cannot get type judgement for {query}")
|
||||||
|
|
||||||
def _make_cast_asserts(self, src_location: Location, expr: ast.expr, type: Type):
|
def _make_cast_asserts(
|
||||||
|
self, src_location: Location, expr: ast.expr, type: Type
|
||||||
|
) -> list[ast.stmt]:
|
||||||
match type:
|
match type:
|
||||||
|
case UnknownType():
|
||||||
|
return []
|
||||||
|
|
||||||
case BaseType(name=name):
|
case BaseType(name=name):
|
||||||
self._add_assert(
|
return [
|
||||||
ast.Call(
|
self._build_assert(
|
||||||
func=ast.Name(id="isinstance"),
|
ast.Call(
|
||||||
args=[expr, ast.Name(id=name)],
|
func=ast.Name(id="isinstance"),
|
||||||
keywords=[],
|
args=[expr, ast.Name(id=name)],
|
||||||
),
|
keywords=[],
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
),
|
||||||
)
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
case AliasType(type=base):
|
case AliasType(type=base):
|
||||||
self._make_cast_asserts(src_location, expr, base)
|
return self._make_cast_asserts(src_location, expr, base)
|
||||||
|
|
||||||
case UnitType():
|
case UnitType():
|
||||||
self._add_assert(
|
return [
|
||||||
ast.Compare(
|
self._build_assert(
|
||||||
left=expr,
|
ast.Compare(
|
||||||
ops=[ast.Is()],
|
left=expr,
|
||||||
comparators=[
|
ops=[ast.Is()],
|
||||||
ast.Constant(value=None),
|
comparators=[
|
||||||
],
|
ast.Constant(value=None),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
),
|
),
|
||||||
self._make_cast_assert_message(src_location, expr, type),
|
]
|
||||||
)
|
|
||||||
|
|
||||||
case AppliedType():
|
case AppliedType(body=body):
|
||||||
self._make_cast_asserts(src_location, expr, type.body)
|
return self._make_cast_asserts(src_location, expr, body)
|
||||||
|
|
||||||
|
case ConstraintType(type=base, constraint=constraint):
|
||||||
|
asserts: list[ast.stmt] = self._make_cast_asserts(
|
||||||
|
src_location, expr, base
|
||||||
|
)
|
||||||
|
asserts.append(
|
||||||
|
self._make_constraint_assert(src_location, expr, constraint)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
|
case TypeVar(bound=bound):
|
||||||
|
# TODO: check with type from arguments / use call-site context
|
||||||
|
if bound is None:
|
||||||
|
return []
|
||||||
|
return self._make_cast_asserts(src_location, expr, bound)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
asserts: list[ast.stmt] = [
|
||||||
|
self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[expr, ast.Name(id="tuple")],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(src_location, expr, type),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert isinstance(expr, ast.Tuple)
|
||||||
|
for item, item_type in zip(expr.elts, items):
|
||||||
|
asserts.extend(
|
||||||
|
self._make_cast_asserts(src_location, item, item_type)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
|
case DataFrameType(columns=columns):
|
||||||
|
self.define_is_dataframe = True
|
||||||
|
asserts: list[ast.stmt] = [
|
||||||
|
self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a dataframe"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for column in columns:
|
||||||
|
asserts.append(
|
||||||
|
self._build_assert(
|
||||||
|
ast.Compare(
|
||||||
|
left=ast.Constant(value=column.name),
|
||||||
|
ops=[ast.In()],
|
||||||
|
comparators=[expr],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location,
|
||||||
|
expr,
|
||||||
|
type,
|
||||||
|
f": Missing column {column.name}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asserts.extend(
|
||||||
|
self._make_cast_asserts(
|
||||||
|
src_location,
|
||||||
|
ast.Subscript(
|
||||||
|
value=expr, slice=ast.Constant(value=column.name)
|
||||||
|
),
|
||||||
|
column.type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return asserts
|
||||||
|
|
||||||
|
case ColumnType():
|
||||||
|
self.define_is_column = True
|
||||||
|
asserts: list[ast.stmt] = [
|
||||||
|
self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=ast.Name(id=self.IS_COLUMN_FUNC),
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_cast_assert_message(
|
||||||
|
src_location, expr, type, ": Not a column"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
|
||||||
|
src_location, expr, type
|
||||||
|
)
|
||||||
|
if inner_assert is not None:
|
||||||
|
asserts.append(inner_assert)
|
||||||
|
return asserts
|
||||||
|
|
||||||
case (
|
case (
|
||||||
TopType()
|
TopType()
|
||||||
@@ -312,13 +460,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
| ExtensionType()
|
| ExtensionType()
|
||||||
| GenericType()
|
| GenericType()
|
||||||
):
|
):
|
||||||
raise NotImplementedError(f"Can't make assertion for type {type}")
|
self.logger.warning(f"Can't make assertion for type {type}")
|
||||||
|
return []
|
||||||
|
|
||||||
case TypeVar():
|
# Ensure exhaustiveness
|
||||||
raise RuntimeError("Unexpected TypeVar")
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
def _make_cast_assert_message(
|
def _make_cast_assert_message(
|
||||||
self, location: Location, expr: ast.expr, type: Type
|
self,
|
||||||
|
location: Location,
|
||||||
|
expr: ast.expr,
|
||||||
|
type: Type,
|
||||||
|
extra: Optional[str] = None,
|
||||||
) -> ast.expr:
|
) -> ast.expr:
|
||||||
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
|
||||||
@@ -336,6 +490,126 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
|||||||
),
|
),
|
||||||
conversion=-1,
|
conversion=-1,
|
||||||
),
|
),
|
||||||
ast.Constant(f" to {type}"),
|
ast.Constant(f" to {type}{extra or ''}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _make_constraint_assert(
|
||||||
|
self, src_location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
|
) -> ast.stmt:
|
||||||
|
test_func: ast.expr = self._get_constraint(constraint)
|
||||||
|
return self._build_assert(
|
||||||
|
ast.Call(
|
||||||
|
func=test_func,
|
||||||
|
args=[expr],
|
||||||
|
keywords=[],
|
||||||
|
),
|
||||||
|
self._make_constraint_assert_message(src_location, expr, constraint),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_constraint_assert_message(
|
||||||
|
self, location: Location, expr: ast.expr, constraint: m.Expr
|
||||||
|
) -> ast.expr:
|
||||||
|
printer = MidasPrinter()
|
||||||
|
constraint_str: str = printer.print(constraint)
|
||||||
|
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
|
||||||
|
# f"file.py:L1:1: ConstraintError: Value does not fit constraint 'v > 0'"
|
||||||
|
return ast.Constant(
|
||||||
|
f"{loc_str}: ConstraintError: Value does not fit constraint '{constraint_str}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_constraint(self, expr: m.Expr) -> ast.expr:
|
||||||
|
for expr2, constraint in self._constraints:
|
||||||
|
if expr2 == expr:
|
||||||
|
return constraint
|
||||||
|
|
||||||
|
constraint: ast.expr = self._constraint_generator.generate(expr)
|
||||||
|
self._constraints.append((expr, constraint))
|
||||||
|
return constraint
|
||||||
|
|
||||||
|
def _is_dataframe_definition(self) -> ast.stmt:
|
||||||
|
"""
|
||||||
|
def IS_DATAFRAME_FUNC(obj) -> bool:
|
||||||
|
import pandas as pd
|
||||||
|
return isinstance(obj, pd.DataFrame)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=self.IS_DATAFRAME_FUNC,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[ast.arg(arg="obj")],
|
||||||
|
args=[],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="obj"),
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=ast.Name(id="bool"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_column_definition(self) -> ast.stmt:
|
||||||
|
"""
|
||||||
|
def IS_COLUMN_FUNC(obj) -> bool:
|
||||||
|
import pandas as pd
|
||||||
|
return isinstance(obj, pd.Series)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ast.FunctionDef(
|
||||||
|
name=self.IS_COLUMN_FUNC,
|
||||||
|
args=ast.arguments(
|
||||||
|
posonlyargs=[ast.arg(arg="obj")],
|
||||||
|
args=[],
|
||||||
|
kwonlyargs=[],
|
||||||
|
defaults=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
),
|
||||||
|
body=[
|
||||||
|
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
|
||||||
|
ast.Return(
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="isinstance"),
|
||||||
|
args=[
|
||||||
|
ast.Name(id="obj"),
|
||||||
|
ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
keywords=[],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=ast.Name(id="bool"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_column_inner_assert(
|
||||||
|
self, src_location: Location, column: ast.expr, type: ColumnType
|
||||||
|
) -> Optional[ast.stmt]:
|
||||||
|
# TODO: improve message, maybe chain contexts
|
||||||
|
col: ast.expr = ast.Name(id="col")
|
||||||
|
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
|
||||||
|
if len(body) == 0:
|
||||||
|
return None
|
||||||
|
return ast.For(
|
||||||
|
target=col,
|
||||||
|
iter=column,
|
||||||
|
body=body,
|
||||||
|
orelse=[],
|
||||||
|
)
|
||||||
|
|||||||
427
midas/generator/stubs.py
Normal file
427
midas/generator/stubs.py
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
import ast
|
||||||
|
from typing import Optional, assert_never
|
||||||
|
|
||||||
|
import midas.ast.midas as m
|
||||||
|
from midas.checker.registry import Member, TypesRegistry
|
||||||
|
from midas.checker.types import (
|
||||||
|
AliasType,
|
||||||
|
AppliedType,
|
||||||
|
BaseType,
|
||||||
|
ColumnType,
|
||||||
|
ComplexType,
|
||||||
|
ConstraintType,
|
||||||
|
DataFrameType,
|
||||||
|
ExtensionType,
|
||||||
|
Function,
|
||||||
|
GenericType,
|
||||||
|
OverloadedFunction,
|
||||||
|
TopType,
|
||||||
|
TupleType,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
UnitType,
|
||||||
|
UnknownType,
|
||||||
|
Variance,
|
||||||
|
substitute_typevars,
|
||||||
|
)
|
||||||
|
|
||||||
|
Empty = ast.Constant(value=...)
|
||||||
|
|
||||||
|
|
||||||
|
class StubsGenerator:
|
||||||
|
def __init__(self, types: TypesRegistry) -> None:
|
||||||
|
self.types: TypesRegistry = types
|
||||||
|
self.stubs: list[ast.stmt] = []
|
||||||
|
self.typing_imports: set[str] = set()
|
||||||
|
self.import_pandas: bool = False
|
||||||
|
self.protocol_idx: int = 0
|
||||||
|
self.stub_idx: int = 0
|
||||||
|
self.type_var_idx: int = 0
|
||||||
|
self.substitutions: dict[str, dict[str, Type]] = {}
|
||||||
|
|
||||||
|
def generate_stubs(self) -> ast.Module:
|
||||||
|
self.stubs = []
|
||||||
|
self.typing_imports = set()
|
||||||
|
self.import_pandas = False
|
||||||
|
for name, type in self.types._types.items():
|
||||||
|
# Skip builtin types, not just based on name so the user can override
|
||||||
|
# TODO: check if added members on builtin type
|
||||||
|
match type:
|
||||||
|
case BaseType(name=name_) if name == name_:
|
||||||
|
continue
|
||||||
|
case GenericType(
|
||||||
|
name=name1,
|
||||||
|
body=BaseType(name=name2),
|
||||||
|
) if (
|
||||||
|
name == name1 == name2
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
self.generate_stub(name, type)
|
||||||
|
|
||||||
|
imports: list[ast.stmt] = [
|
||||||
|
ast.ImportFrom(
|
||||||
|
module="__future__",
|
||||||
|
names=[ast.alias(name="annotations")],
|
||||||
|
level=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if len(self.typing_imports) != 0:
|
||||||
|
imports.append(
|
||||||
|
ast.ImportFrom(
|
||||||
|
module="typing",
|
||||||
|
names=[
|
||||||
|
ast.alias(name=name) for name in sorted(self.typing_imports)
|
||||||
|
],
|
||||||
|
level=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.import_pandas:
|
||||||
|
imports.append(
|
||||||
|
ast.Import(
|
||||||
|
names=[
|
||||||
|
ast.alias(
|
||||||
|
name="pandas",
|
||||||
|
asname="pd",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ast.Module(body=imports + self.stubs, type_ignores=[])
|
||||||
|
|
||||||
|
def generate_stub(self, name: str, type: Type):
|
||||||
|
base_type: Type = type
|
||||||
|
|
||||||
|
members: dict[str, Member] = self.types._members.get(name, {})
|
||||||
|
if isinstance(base_type, (BaseType, TopType, UnitType)) and len(members) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
bases: list[ast.expr] = []
|
||||||
|
substitutions: dict[str, Type] = {}
|
||||||
|
bases, substitutions = self.get_bases(type)
|
||||||
|
self.substitutions[name] = substitutions
|
||||||
|
|
||||||
|
body = self.generate_body(members, substitutions)
|
||||||
|
stub = ast.ClassDef(
|
||||||
|
name=name,
|
||||||
|
bases=bases,
|
||||||
|
body=body,
|
||||||
|
keywords=[],
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
self.add_stub(stub)
|
||||||
|
|
||||||
|
def get_bases(self, type: Type) -> tuple[list[ast.expr], dict[str, Type]]:
|
||||||
|
match type:
|
||||||
|
case AliasType(type=base):
|
||||||
|
return [self.dump_type(base)], {}
|
||||||
|
|
||||||
|
case GenericType(params=params, body=body):
|
||||||
|
self.add_typing_import("Generic")
|
||||||
|
type_vars: ast.expr
|
||||||
|
|
||||||
|
params2: list[TypeVar] = self.define_type_vars(params)
|
||||||
|
if len(params) == 1:
|
||||||
|
type_vars = ast.Name(id=params2[0].name)
|
||||||
|
else:
|
||||||
|
type_vars = ast.Tuple(
|
||||||
|
elts=[ast.Name(id=param.name) for param in params2]
|
||||||
|
)
|
||||||
|
|
||||||
|
substitutions: dict[str, TypeVar] = {
|
||||||
|
param.name: param2 for param, param2 in zip(params, params2)
|
||||||
|
}
|
||||||
|
|
||||||
|
body_bases, body_subsitutions = self.get_bases(body)
|
||||||
|
return (
|
||||||
|
body_bases
|
||||||
|
+ [
|
||||||
|
ast.Subscript(
|
||||||
|
value=ast.Name(id="Generic"),
|
||||||
|
slice=type_vars,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
body_subsitutions | substitutions,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ConstraintType(type=base):
|
||||||
|
return self.get_bases(base)
|
||||||
|
|
||||||
|
case TypeVar(bound=bound) if bound is not None:
|
||||||
|
return [self.dump_type(bound)], {}
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return [], {}
|
||||||
|
|
||||||
|
def generate_body(
|
||||||
|
self, members: dict[str, Member], substitutions: dict[str, Type]
|
||||||
|
) -> list[ast.stmt]:
|
||||||
|
if len(members) == 0:
|
||||||
|
return [ast.Expr(value=Empty)]
|
||||||
|
|
||||||
|
body: list[ast.stmt] = []
|
||||||
|
for name, member in members.items():
|
||||||
|
type: Type = member.type
|
||||||
|
type = substitute_typevars(type, substitutions)
|
||||||
|
match member.kind:
|
||||||
|
case m.MemberKind.PROPERTY:
|
||||||
|
body.append(
|
||||||
|
ast.AnnAssign(
|
||||||
|
target=ast.Name(id=name),
|
||||||
|
annotation=self.dump_type(type),
|
||||||
|
simple=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case m.MemberKind.METHOD:
|
||||||
|
body.extend(self.dump_method(name, type))
|
||||||
|
return body
|
||||||
|
|
||||||
|
def dump_type(self, type: Type) -> ast.expr:
|
||||||
|
match type:
|
||||||
|
case AliasType(name=name) | GenericType(name=name) if (
|
||||||
|
name in self.substitutions
|
||||||
|
):
|
||||||
|
type = substitute_typevars(type, self.substitutions[name])
|
||||||
|
|
||||||
|
match type:
|
||||||
|
case TopType() | UnknownType():
|
||||||
|
self.add_typing_import("Any")
|
||||||
|
return ast.Name(id="Any")
|
||||||
|
|
||||||
|
case BaseType(name=name):
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
case AliasType(name=name):
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
case UnitType():
|
||||||
|
return ast.Constant(value=None)
|
||||||
|
|
||||||
|
case Function():
|
||||||
|
name: str = self.define_protocol(type)
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
if len(overloads) == 1:
|
||||||
|
return self.dump_type(overloads[0])
|
||||||
|
return ast.BinOp(
|
||||||
|
left=self.dump_type(OverloadedFunction(overloads=overloads[:-1])),
|
||||||
|
op=ast.BitOr(),
|
||||||
|
right=self.dump_type(overloads[-1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ComplexType():
|
||||||
|
name: str = self.new_stub_name()
|
||||||
|
self.generate_stub(name, type)
|
||||||
|
return ast.Name(id=name)
|
||||||
|
|
||||||
|
case ExtensionType():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
case TypeVar():
|
||||||
|
return ast.Name(id=type.name)
|
||||||
|
|
||||||
|
case GenericType(name=name):
|
||||||
|
params: ast.expr
|
||||||
|
if len(type.params) == 1:
|
||||||
|
params = self.dump_type(type.params[0])
|
||||||
|
else:
|
||||||
|
params = ast.Tuple(
|
||||||
|
elts=[self.dump_type(param) for param in type.params]
|
||||||
|
)
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id=type.name),
|
||||||
|
slice=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
case AppliedType():
|
||||||
|
args: ast.expr
|
||||||
|
if len(type.args) == 1:
|
||||||
|
args = self.dump_type(type.args[0])
|
||||||
|
else:
|
||||||
|
args = ast.Tuple(elts=[self.dump_type(arg) for arg in type.args])
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id=type.name),
|
||||||
|
slice=args,
|
||||||
|
)
|
||||||
|
|
||||||
|
case ConstraintType():
|
||||||
|
return self.dump_type(type.type)
|
||||||
|
|
||||||
|
case TupleType(items=items):
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id="tuple"),
|
||||||
|
slice=ast.Tuple(
|
||||||
|
elts=[self.dump_type(item) for item in items],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
case ColumnType(type=inner):
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="Series",
|
||||||
|
),
|
||||||
|
slice=self.dump_type(inner),
|
||||||
|
)
|
||||||
|
|
||||||
|
case DataFrameType():
|
||||||
|
self.import_pandas = True
|
||||||
|
return ast.Attribute(
|
||||||
|
value=ast.Name(id="pd"),
|
||||||
|
attr="DataFrame",
|
||||||
|
)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
assert_never(type)
|
||||||
|
|
||||||
|
def dump_method(
|
||||||
|
self, name: str, method: Type, overloaded: bool = False
|
||||||
|
) -> list[ast.stmt]:
|
||||||
|
match method:
|
||||||
|
case Function():
|
||||||
|
if overloaded:
|
||||||
|
self.add_typing_import("overload")
|
||||||
|
return [
|
||||||
|
ast.FunctionDef(
|
||||||
|
name=name,
|
||||||
|
args=self.dump_args(method, with_self=True),
|
||||||
|
returns=self.dump_type(method.returns),
|
||||||
|
body=[ast.Expr(value=Empty)],
|
||||||
|
decorator_list=[ast.Name(id="overload")] if overloaded else [],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
case OverloadedFunction(overloads=overloads):
|
||||||
|
stmts: list[ast.stmt] = []
|
||||||
|
for overload in overloads:
|
||||||
|
stmts.extend(self.dump_method(name, overload, True))
|
||||||
|
return stmts
|
||||||
|
case _:
|
||||||
|
return [
|
||||||
|
ast.AnnAssign(
|
||||||
|
target=ast.Name(id=name),
|
||||||
|
annotation=self.dump_type(method),
|
||||||
|
simple=1,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def dump_args(self, func: Function, with_self: bool = False) -> ast.arguments:
|
||||||
|
pos: list[ast.arg] = [
|
||||||
|
ast.arg(arg=f"_{arg.pos}", annotation=self.dump_type(arg.type))
|
||||||
|
for arg in func.pos_args
|
||||||
|
]
|
||||||
|
mixed: list[ast.arg] = [
|
||||||
|
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||||
|
for arg in func.args
|
||||||
|
]
|
||||||
|
kw: list[ast.arg] = [
|
||||||
|
ast.arg(arg=arg.name, annotation=self.dump_type(arg.type))
|
||||||
|
for arg in func.kw_args
|
||||||
|
]
|
||||||
|
defaults: list[ast.expr] = [
|
||||||
|
Empty for arg in func.pos_args + func.args if not arg.required
|
||||||
|
]
|
||||||
|
kw_defaults: list[Optional[ast.expr]] = [
|
||||||
|
None if arg.required else Empty for arg in func.kw_args
|
||||||
|
]
|
||||||
|
if with_self:
|
||||||
|
arg = ast.arg(arg="self", annotation=None)
|
||||||
|
if len(pos) != 0:
|
||||||
|
pos.insert(0, arg)
|
||||||
|
else:
|
||||||
|
mixed.insert(0, arg)
|
||||||
|
return ast.arguments(
|
||||||
|
posonlyargs=pos,
|
||||||
|
args=mixed,
|
||||||
|
kwonlyargs=kw,
|
||||||
|
defaults=defaults,
|
||||||
|
kw_defaults=kw_defaults,
|
||||||
|
)
|
||||||
|
|
||||||
|
def define_protocol(self, func: Function) -> str:
|
||||||
|
self.add_typing_import("Protocol")
|
||||||
|
name: str = self.new_protocol_name()
|
||||||
|
protocol = ast.ClassDef(
|
||||||
|
name=name,
|
||||||
|
bases=[ast.Name(id="Protocol")],
|
||||||
|
keywords=[],
|
||||||
|
body=[
|
||||||
|
ast.FunctionDef(
|
||||||
|
name="__call__",
|
||||||
|
args=self.dump_args(func, with_self=True),
|
||||||
|
returns=self.dump_type(func.returns),
|
||||||
|
body=[ast.Expr(value=Empty)],
|
||||||
|
decorator_list=[],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
decorator_list=[],
|
||||||
|
)
|
||||||
|
self.add_stub(protocol)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def new_protocol_name(self) -> str:
|
||||||
|
name: str = f"_Protocol{self.protocol_idx}"
|
||||||
|
self.protocol_idx += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def new_stub_name(self) -> str:
|
||||||
|
name: str = f"_Stub_{self.stub_idx}"
|
||||||
|
self.stub_idx += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def new_type_var_name(self) -> str:
|
||||||
|
name: str = f"_T{self.type_var_idx}"
|
||||||
|
self.type_var_idx += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def add_stub(self, stub: ast.stmt):
|
||||||
|
self.stubs.append(stub)
|
||||||
|
|
||||||
|
def add_typing_import(self, name: str):
|
||||||
|
self.typing_imports.add(name)
|
||||||
|
|
||||||
|
def define_type_vars(self, vars: list[TypeVar]) -> list[TypeVar]:
|
||||||
|
vars2: list[TypeVar] = []
|
||||||
|
for var in vars:
|
||||||
|
vars2.append(self.define_type_var(var))
|
||||||
|
return vars2
|
||||||
|
|
||||||
|
def define_type_var(self, var: TypeVar) -> TypeVar:
|
||||||
|
name: str = self.new_type_var_name()
|
||||||
|
self.add_typing_import("TypeVar")
|
||||||
|
|
||||||
|
kwargs: list[ast.keyword] = []
|
||||||
|
if var.bound is not None:
|
||||||
|
kwargs.append(
|
||||||
|
ast.keyword(
|
||||||
|
arg="bound",
|
||||||
|
value=self.dump_type(var.bound),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if var.variance == Variance.COVARIANT:
|
||||||
|
kwargs.append(
|
||||||
|
ast.keyword(
|
||||||
|
arg="covariant",
|
||||||
|
value=ast.Constant(value=True),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif var.variance == Variance.CONTRAVARIANT:
|
||||||
|
kwargs.append(
|
||||||
|
ast.keyword(
|
||||||
|
arg="contravariant",
|
||||||
|
value=ast.Constant(value=True),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.add_stub(
|
||||||
|
ast.Assign(
|
||||||
|
targets=[ast.Name(id=name)],
|
||||||
|
value=ast.Call(
|
||||||
|
func=ast.Name(id="TypeVar"),
|
||||||
|
args=[
|
||||||
|
ast.Constant(value=name),
|
||||||
|
],
|
||||||
|
keywords=kwargs,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return TypeVar(name=name, bound=None)
|
||||||
@@ -69,6 +69,8 @@ class MidasLexer(Lexer):
|
|||||||
):
|
):
|
||||||
self.advance()
|
self.advance()
|
||||||
self.add_token(TokenType.WHITESPACE)
|
self.add_token(TokenType.WHITESPACE)
|
||||||
|
case '"' | "'":
|
||||||
|
self.scan_string(char)
|
||||||
case _:
|
case _:
|
||||||
if char.isdigit():
|
if char.isdigit():
|
||||||
self.scan_number()
|
self.scan_number()
|
||||||
@@ -78,6 +80,17 @@ class MidasLexer(Lexer):
|
|||||||
self.error("Unexpected character")
|
self.error("Unexpected character")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def scan_string(self, opening: str):
|
||||||
|
while self.peek() != opening and not self.is_at_end():
|
||||||
|
self.advance()
|
||||||
|
|
||||||
|
if self.is_at_end():
|
||||||
|
self.error("Unterminated string")
|
||||||
|
|
||||||
|
self.advance()
|
||||||
|
value: str = self.source[self.start + 1 : self.idx - 1]
|
||||||
|
self.add_token(TokenType.STRING, value)
|
||||||
|
|
||||||
def scan_number(self):
|
def scan_number(self):
|
||||||
"""Scan the rest of number and add it as a token
|
"""Scan the rest of number and add it as a token
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class TokenType(Enum):
|
|||||||
TRUE = auto()
|
TRUE = auto()
|
||||||
FALSE = auto()
|
FALSE = auto()
|
||||||
NONE = auto()
|
NONE = auto()
|
||||||
|
STRING = auto()
|
||||||
|
|
||||||
# Keywords
|
# Keywords
|
||||||
TYPE = auto()
|
TYPE = auto()
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ from typing import Optional
|
|||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
|
FrameType,
|
||||||
FunctionType,
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
@@ -17,6 +19,7 @@ from midas.ast.midas import (
|
|||||||
MemberKind,
|
MemberKind,
|
||||||
MemberStmt,
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
|
ParamSpec,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
@@ -202,8 +205,10 @@ class MidasParser(Parser):
|
|||||||
return self.generic_type()
|
return self.generic_type()
|
||||||
|
|
||||||
def generic_type(self) -> Type:
|
def generic_type(self) -> Type:
|
||||||
type: Type = self.named_type()
|
type: NamedType = self.named_type()
|
||||||
if self.check(TokenType.LEFT_BRACKET):
|
if self.check(TokenType.LEFT_BRACKET):
|
||||||
|
if type.name.lexeme == "Frame":
|
||||||
|
return self.frame_type()
|
||||||
args: list[Type] = self.type_args()
|
args: list[Type] = self.type_args()
|
||||||
return GenericType(
|
return GenericType(
|
||||||
location=Location.span(type.location, self.previous().get_location()),
|
location=Location.span(type.location, self.previous().get_location()),
|
||||||
@@ -222,7 +227,7 @@ class MidasParser(Parser):
|
|||||||
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def named_type(self) -> Type:
|
def named_type(self) -> NamedType:
|
||||||
name: Token = self.consume_identifier("Expected type name")
|
name: Token = self.consume_identifier("Expected type name")
|
||||||
return NamedType(
|
return NamedType(
|
||||||
location=name.get_location(),
|
location=name.get_location(),
|
||||||
@@ -257,6 +262,32 @@ class MidasParser(Parser):
|
|||||||
members=members,
|
members=members,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def frame_type(self) -> FrameType:
|
||||||
|
keyword: Token = self.previous()
|
||||||
|
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
|
||||||
|
|
||||||
|
columns: list[FrameType.Column] = []
|
||||||
|
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
|
||||||
|
name: Token = self.advance()
|
||||||
|
self.consume(TokenType.COLON, "Expected ':' between column name and type")
|
||||||
|
type: Type = self.type_expr()
|
||||||
|
columns.append(
|
||||||
|
FrameType.Column(
|
||||||
|
location=name.location_to(self.previous()),
|
||||||
|
name=name,
|
||||||
|
type=type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
|
||||||
|
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
|
||||||
|
|
||||||
|
return FrameType(
|
||||||
|
location=keyword.location_to(self.previous()),
|
||||||
|
columns=columns,
|
||||||
|
)
|
||||||
|
|
||||||
def constraint(self) -> Expr:
|
def constraint(self) -> Expr:
|
||||||
"""Parse a constraint
|
"""Parse a constraint
|
||||||
|
|
||||||
@@ -265,6 +296,9 @@ class MidasParser(Parser):
|
|||||||
Returns:
|
Returns:
|
||||||
Expr: the parsed constraint expression
|
Expr: the parsed constraint expression
|
||||||
"""
|
"""
|
||||||
|
return self.expression()
|
||||||
|
|
||||||
|
def expression(self) -> Expr:
|
||||||
return self.and_()
|
return self.and_()
|
||||||
|
|
||||||
def and_(self) -> Expr:
|
def and_(self) -> Expr:
|
||||||
@@ -331,7 +365,55 @@ class MidasParser(Parser):
|
|||||||
right: Expr = self.unary()
|
right: Expr = self.unary()
|
||||||
location: Location = Location.span(operator.get_location(), right.location)
|
location: Location = Location.span(operator.get_location(), right.location)
|
||||||
return UnaryExpr(location=location, operator=operator, right=right)
|
return UnaryExpr(location=location, operator=operator, right=right)
|
||||||
return self.reference()
|
return self.call()
|
||||||
|
|
||||||
|
def call(self) -> Expr:
|
||||||
|
expr: Expr = self.reference()
|
||||||
|
while self.match(TokenType.LEFT_PAREN):
|
||||||
|
expr = self.finish_call(expr)
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def finish_call(self, callee: Expr) -> Expr:
|
||||||
|
pos_args: list[Expr] = []
|
||||||
|
kw_args: dict[str, Expr] = {}
|
||||||
|
keywords: bool = False
|
||||||
|
while not self.check(TokenType.RIGHT_PAREN):
|
||||||
|
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||||
|
keywords = True
|
||||||
|
keyword: Token = self.advance()
|
||||||
|
self.advance()
|
||||||
|
value: Expr = self.expression()
|
||||||
|
name: str = keyword.lexeme
|
||||||
|
if name in kw_args:
|
||||||
|
self.error(
|
||||||
|
self.peek(),
|
||||||
|
f"Multiple values passed for '{name}', only the last occurrence will be used",
|
||||||
|
)
|
||||||
|
kw_args[name] = value
|
||||||
|
else:
|
||||||
|
value = self.expression()
|
||||||
|
if self.check(TokenType.EQUAL):
|
||||||
|
if keywords:
|
||||||
|
raise self.error(self.peek(), "Invalid keyword argument name")
|
||||||
|
else:
|
||||||
|
raise self.error(
|
||||||
|
self.peek(),
|
||||||
|
"Cannot pass positional arguments after a keyword argument",
|
||||||
|
)
|
||||||
|
pos_args.append(value)
|
||||||
|
|
||||||
|
if not self.match(TokenType.COMMA):
|
||||||
|
break
|
||||||
|
|
||||||
|
r_paren: Token = self.consume(
|
||||||
|
TokenType.RIGHT_PAREN, "Expected ')' after arguments."
|
||||||
|
)
|
||||||
|
return CallExpr(
|
||||||
|
location=Location.span(callee.location, r_paren.get_location()),
|
||||||
|
callee=callee,
|
||||||
|
arguments=pos_args,
|
||||||
|
keywords=kw_args,
|
||||||
|
)
|
||||||
|
|
||||||
def reference(self) -> Expr:
|
def reference(self) -> Expr:
|
||||||
"""Parse an attribute access expression or a simpler expression
|
"""Parse an attribute access expression or a simpler expression
|
||||||
@@ -365,6 +447,9 @@ class MidasParser(Parser):
|
|||||||
if self.match(TokenType.NUMBER):
|
if self.match(TokenType.NUMBER):
|
||||||
return LiteralExpr(location=token.get_location(), value=token.value)
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
|
if self.match(TokenType.STRING):
|
||||||
|
return LiteralExpr(location=token.get_location(), value=token.value)
|
||||||
|
|
||||||
if self.match_identifier():
|
if self.match_identifier():
|
||||||
return VariableExpr(location=token.get_location(), name=token)
|
return VariableExpr(location=token.get_location(), name=token)
|
||||||
|
|
||||||
@@ -453,23 +538,35 @@ class MidasParser(Parser):
|
|||||||
PredicateStmt: the parsed predicate declaration statement
|
PredicateStmt: the parsed predicate declaration statement
|
||||||
"""
|
"""
|
||||||
keyword: Token = self.previous()
|
keyword: Token = self.previous()
|
||||||
|
|
||||||
name: Token = self.consume_identifier("Expected predicate name")
|
name: Token = self.consume_identifier("Expected predicate name")
|
||||||
self.consume(TokenType.LEFT_PAREN, "Expected '(' before predicate subject")
|
|
||||||
subject: Token = self.consume_identifier("Expected subject name")
|
params: list[ParamSpec] = []
|
||||||
self.consume(TokenType.COLON, "Expected ':' after subject name")
|
while self.check(TokenType.LEFT_PAREN):
|
||||||
type: Type = self.type_expr()
|
params.append(self.function_args())
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after predicate subject")
|
|
||||||
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
self.consume(TokenType.EQUAL, "Expected '=' after predicate subject")
|
||||||
condition: Expr = self.constraint()
|
body: Expr = self.constraint()
|
||||||
return PredicateStmt(
|
return PredicateStmt(
|
||||||
location=keyword.location_to(self.previous()),
|
location=keyword.location_to(self.previous()),
|
||||||
name=name,
|
name=name,
|
||||||
subject=subject,
|
params=params,
|
||||||
type=type,
|
body=body,
|
||||||
condition=condition,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def function(self) -> FunctionType:
|
def function(self) -> FunctionType:
|
||||||
|
params: ParamSpec = self.function_args()
|
||||||
|
|
||||||
|
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
||||||
|
result: Type = self.type_expr()
|
||||||
|
|
||||||
|
return FunctionType(
|
||||||
|
location=params.l_paren.location_to(self.previous()),
|
||||||
|
params=params,
|
||||||
|
returns=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def function_args(self) -> ParamSpec:
|
||||||
l_paren: Token = self.consume(
|
l_paren: Token = self.consume(
|
||||||
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
TokenType.LEFT_PAREN, "Expected '(' before function parameters"
|
||||||
)
|
)
|
||||||
@@ -526,14 +623,4 @@ class MidasParser(Parser):
|
|||||||
self.error(token, "Unnamed mixed argument")
|
self.error(token, "Unnamed mixed argument")
|
||||||
|
|
||||||
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
self.consume(TokenType.RIGHT_PAREN, "Expected ')' after function parameters")
|
||||||
|
return ParamSpec(l_paren=l_paren, pos=pos_args, mixed=args, kw=kw_args)
|
||||||
self.consume(TokenType.ARROW, "Expected '->' before result type")
|
|
||||||
result: Type = self.type_expr()
|
|
||||||
|
|
||||||
return FunctionType(
|
|
||||||
location=l_paren.location_to(self.previous()),
|
|
||||||
pos_args=pos_args,
|
|
||||||
args=args,
|
|
||||||
kw_args=kw_args,
|
|
||||||
returns=result,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from midas.ast.python import (
|
|||||||
Stmt,
|
Stmt,
|
||||||
SubscriptExpr,
|
SubscriptExpr,
|
||||||
TernaryExpr,
|
TernaryExpr,
|
||||||
|
TupleExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -49,6 +50,7 @@ class UnsupportedSyntaxError(Exception):
|
|||||||
|
|
||||||
class PythonParser:
|
class PythonParser:
|
||||||
CAST_FUNCTION = "cast"
|
CAST_FUNCTION = "cast"
|
||||||
|
UNSAFE_CAST_FUNCTION = "unsafe_cast"
|
||||||
|
|
||||||
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
def parse_module(self, node: ast.Module) -> list[Stmt]:
|
||||||
statements: list[Stmt] = []
|
statements: list[Stmt] = []
|
||||||
@@ -299,26 +301,28 @@ class PythonParser:
|
|||||||
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
case ast.Subscript(value=ast.Name(id="Frame"), slice=schema):
|
||||||
return self._parse_frame_type(schema)
|
return self._parse_frame_type(schema)
|
||||||
|
|
||||||
case ast.Subscript(value=ast.Name(id=name), slice=param):
|
case ast.Subscript(value=ast.Name(id=name), slice=arg):
|
||||||
|
args: tuple[MidasType, ...] = (
|
||||||
|
tuple(self._parse_type(a) for a in arg.elts)
|
||||||
|
if isinstance(arg, ast.Tuple)
|
||||||
|
else (self._parse_type(arg),)
|
||||||
|
)
|
||||||
return BaseType(
|
return BaseType(
|
||||||
location=loc,
|
location=loc,
|
||||||
base=name,
|
base=name,
|
||||||
param=self._parse_type(param),
|
args=args,
|
||||||
)
|
)
|
||||||
|
|
||||||
case ast.Name(id=name):
|
case ast.Name(id=name):
|
||||||
return BaseType(
|
return BaseType(
|
||||||
location=loc,
|
location=loc,
|
||||||
base=name,
|
base=name,
|
||||||
param=None,
|
args=(),
|
||||||
)
|
)
|
||||||
|
|
||||||
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
case ast.BinOp(left=left_expr, op=ast.Add(), right=right_expr):
|
||||||
left = self._parse_type(left_expr)
|
left = self._parse_type(left_expr)
|
||||||
match left:
|
match left:
|
||||||
case None:
|
|
||||||
raise InvalidSyntaxError()
|
|
||||||
|
|
||||||
# If chained constraints, separate base type and rebuild constraint
|
# If chained constraints, separate base type and rebuild constraint
|
||||||
case ConstraintType(type=left_type, constraint=left_constraint):
|
case ConstraintType(type=left_type, constraint=left_constraint):
|
||||||
constraint = ast.BinOp(
|
constraint = ast.BinOp(
|
||||||
@@ -344,7 +348,7 @@ class PythonParser:
|
|||||||
return BaseType(
|
return BaseType(
|
||||||
location=loc,
|
location=loc,
|
||||||
base="None",
|
base="None",
|
||||||
param=None,
|
args=(),
|
||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
@@ -423,6 +427,9 @@ class PythonParser:
|
|||||||
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
|
||||||
return self.parse_cast(node)
|
return self.parse_cast(node)
|
||||||
|
|
||||||
|
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
|
||||||
|
return self.parse_cast(node)
|
||||||
|
|
||||||
case ast.Call():
|
case ast.Call():
|
||||||
return self.parse_call(node)
|
return self.parse_call(node)
|
||||||
|
|
||||||
@@ -473,6 +480,12 @@ class PythonParser:
|
|||||||
step=self.parse_expr(step) if step is not None else None,
|
step=self.parse_expr(step) if step is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
case ast.Tuple(elts=items):
|
||||||
|
return TupleExpr(
|
||||||
|
location=location,
|
||||||
|
items=tuple(self.parse_expr(item) for item in items),
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Unsupported expression: {ast.unparse(node)}")
|
print(f"Unsupported expression: {ast.unparse(node)}")
|
||||||
return RawExpr(location=location, expr=node)
|
return RawExpr(location=location, expr=node)
|
||||||
@@ -527,16 +540,19 @@ class PythonParser:
|
|||||||
return expr
|
return expr
|
||||||
|
|
||||||
def parse_cast(self, node: ast.Call) -> CastExpr:
|
def parse_cast(self, node: ast.Call) -> CastExpr:
|
||||||
|
assert isinstance(node.func, ast.Name)
|
||||||
|
func: str = node.func.id
|
||||||
match node:
|
match node:
|
||||||
case ast.Call(args=[type, expr], keywords=[]):
|
case ast.Call(args=[type, expr], keywords=[]):
|
||||||
return CastExpr(
|
return CastExpr(
|
||||||
location=Location.from_ast(node),
|
location=Location.from_ast(node),
|
||||||
type=self._parse_type(type),
|
type=self._parse_type(type),
|
||||||
expr=self.parse_expr(expr),
|
expr=self.parse_expr(expr),
|
||||||
|
unsafe=func == self.UNSAFE_CAST_FUNCTION,
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
raise InvalidSyntaxError(
|
raise InvalidSyntaxError(
|
||||||
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
|
f"Invalid call to {func}, expected type and expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_call(self, node: ast.Call) -> CallExpr:
|
def parse_call(self, node: ast.Call) -> CallExpr:
|
||||||
|
|||||||
52
midas/typing.py
Normal file
52
midas/typing.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from typing import Generic, TypeVar
|
||||||
|
from typing import cast as typing_cast
|
||||||
|
|
||||||
|
cast = typing_cast
|
||||||
|
"""### Midas documentation
|
||||||
|
Cast a value to a type.
|
||||||
|
|
||||||
|
- **Compile-time**: tells the type checker that the return value has the designated type.
|
||||||
|
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
|
||||||
|
|
||||||
|
---
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
_**Internal Python documentation**_
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
unsafe_cast = typing_cast
|
||||||
|
"""### Midas documentation
|
||||||
|
Cast a value to a type.
|
||||||
|
|
||||||
|
- **Compile-time**: tells the type checker that the return value has the designated type.
|
||||||
|
- **Run-time**: -
|
||||||
|
|
||||||
|
This operation is unsound, use at your own risk!
|
||||||
|
|
||||||
|
---
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
|
||||||
|
_**Internal Python documentation**_
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Frame(Generic[T]):
|
||||||
|
"""A `Frame` is the abstract type implemented by `DataFrame`
|
||||||
|
|
||||||
|
A frame contains any number of named columns (see :class:`Column`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Column(Generic[T]):
|
||||||
|
"""A `Column` is the abstract type implemented by `Series`
|
||||||
|
|
||||||
|
A column contains a any number of values of the same type
|
||||||
|
"""
|
||||||
@@ -62,3 +62,4 @@ class UniversalJSONDumper:
|
|||||||
class TypedAST:
|
class TypedAST:
|
||||||
stmts: list[p.Stmt]
|
stmts: list[p.Stmt]
|
||||||
judgements: list[tuple[p.Expr, Type]]
|
judgements: list[tuple[p.Expr, Type]]
|
||||||
|
evaluated_casts: list[p.CastExpr]
|
||||||
|
|||||||
@@ -8,7 +8,11 @@ authors = [
|
|||||||
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
|
||||||
]
|
]
|
||||||
classifiers = ["Programming Language :: Python :: 3"]
|
classifiers = ["Programming Language :: Python :: 3"]
|
||||||
dependencies = ["click>=8.4.1"]
|
dependencies = [
|
||||||
|
"black>=26.5.1",
|
||||||
|
"click>=8.4.1",
|
||||||
|
"watchdog>=6.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://git.kbk28.ch/HEL/midas"
|
Homepage = "https://git.kbk28.ch/HEL/midas"
|
||||||
|
|||||||
@@ -4,7 +4,35 @@
|
|||||||
"type": "Warning",
|
"type": "Warning",
|
||||||
"location": {
|
"location": {
|
||||||
"start": [
|
"start": [
|
||||||
6,
|
8,
|
||||||
|
12
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
8,
|
||||||
|
43
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "ConstraintType not yet supported"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
10,
|
||||||
|
10
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
10,
|
||||||
|
18
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Unknown type 'datetime'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Warning",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
13,
|
||||||
4
|
4
|
||||||
],
|
],
|
||||||
"end": [
|
"end": [
|
||||||
@@ -12,7 +40,7 @@
|
|||||||
5
|
5
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"message": "FrameType not yet supported"
|
"message": "Unknown type '_'"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"judgments": []
|
"judgments": []
|
||||||
|
|||||||
@@ -328,6 +328,19 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:9",
|
||||||
|
"to": "L6:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L6:5",
|
"from": "L6:5",
|
||||||
@@ -373,19 +386,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L6:9",
|
|
||||||
"to": "L6:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L6:5",
|
"from": "L6:5",
|
||||||
@@ -407,6 +407,32 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L7:9",
|
||||||
|
"to": "L7:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L7:12",
|
||||||
|
"to": "L7:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L7:5",
|
"from": "L7:5",
|
||||||
@@ -452,32 +478,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L7:9",
|
|
||||||
"to": "L7:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L7:12",
|
|
||||||
"to": "L7:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L7:5",
|
"from": "L7:5",
|
||||||
@@ -503,6 +503,32 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:9",
|
||||||
|
"to": "L8:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L8:14",
|
||||||
|
"to": "L8:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L8:5",
|
"from": "L8:5",
|
||||||
@@ -548,32 +574,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L8:9",
|
|
||||||
"to": "L8:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L8:14",
|
|
||||||
"to": "L8:17"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L8:5",
|
"from": "L8:5",
|
||||||
@@ -600,6 +600,45 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:9",
|
||||||
|
"to": "L9:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:12",
|
||||||
|
"to": "L9:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:17",
|
||||||
|
"to": "L9:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L9:5",
|
"from": "L9:5",
|
||||||
@@ -645,45 +684,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L9:9",
|
|
||||||
"to": "L9:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L9:12",
|
|
||||||
"to": "L9:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L9:17",
|
|
||||||
"to": "L9:23"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L9:5",
|
"from": "L9:5",
|
||||||
@@ -713,6 +713,45 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:9",
|
||||||
|
"to": "L10:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:12",
|
||||||
|
"to": "L10:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:19",
|
||||||
|
"to": "L10:22"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L10:5",
|
"from": "L10:5",
|
||||||
@@ -758,45 +797,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L10:9",
|
|
||||||
"to": "L10:10"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L10:12",
|
|
||||||
"to": "L10:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L10:19",
|
|
||||||
"to": "L10:22"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 3.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L10:5",
|
"from": "L10:5",
|
||||||
@@ -827,6 +827,19 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:11",
|
||||||
|
"to": "L11:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
@@ -872,19 +885,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L11:11",
|
|
||||||
"to": "L11:12"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
@@ -906,6 +906,19 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:11",
|
||||||
|
"to": "L12:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L12:5",
|
"from": "L12:5",
|
||||||
@@ -951,19 +964,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L12:11",
|
|
||||||
"to": "L12:17"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L12:5",
|
"from": "L12:5",
|
||||||
@@ -985,6 +985,45 @@
|
|||||||
},
|
},
|
||||||
"type": {}
|
"type": {}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:10",
|
||||||
|
"to": "L14:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:13",
|
||||||
|
"to": "L14:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:20",
|
||||||
|
"to": "L14:26"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L14:6",
|
"from": "L14:6",
|
||||||
@@ -1030,45 +1069,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:10",
|
|
||||||
"to": "L14:11"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:13",
|
|
||||||
"to": "L14:16"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L14:20",
|
|
||||||
"to": "L14:26"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L14:6",
|
"from": "L14:6",
|
||||||
@@ -1101,6 +1101,45 @@
|
|||||||
"name": "bool"
|
"name": "bool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:10",
|
||||||
|
"to": "L15:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:15",
|
||||||
|
"to": "L15:18"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L15:22",
|
||||||
|
"to": "L15:28"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L15:6",
|
"from": "L15:6",
|
||||||
@@ -1146,45 +1185,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L15:10",
|
|
||||||
"to": "L15:11"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L15:15",
|
|
||||||
"to": "L15:18"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L15:22",
|
|
||||||
"to": "L15:28"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L15:6",
|
"from": "L15:6",
|
||||||
@@ -1217,6 +1217,45 @@
|
|||||||
"name": "bool"
|
"name": "bool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:10",
|
||||||
|
"to": "L16:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:15",
|
||||||
|
"to": "L16:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "test"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L16:25",
|
||||||
|
"to": "L16:28"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2.0
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L16:6",
|
"from": "L16:6",
|
||||||
@@ -1262,45 +1301,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L16:10",
|
|
||||||
"to": "L16:11"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 1
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L16:15",
|
|
||||||
"to": "L16:21"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "test"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L16:25",
|
|
||||||
"to": "L16:28"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 2.0
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L16:6",
|
"from": "L16:6",
|
||||||
@@ -1333,6 +1333,45 @@
|
|||||||
"name": "bool"
|
"name": "bool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:10",
|
||||||
|
"to": "L18:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": "a"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "str"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:15",
|
||||||
|
"to": "L18:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L18:20",
|
||||||
|
"to": "L18:25"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": false
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L18:6",
|
"from": "L18:6",
|
||||||
@@ -1378,45 +1417,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L18:10",
|
|
||||||
"to": "L18:13"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": "a"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "str"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L18:15",
|
|
||||||
"to": "L18:16"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": 3
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L18:20",
|
|
||||||
"to": "L18:25"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "LiteralExpr",
|
|
||||||
"value": false
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "bool"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L18:6",
|
"from": "L18:6",
|
||||||
|
|||||||
@@ -1,6 +1,19 @@
|
|||||||
{
|
{
|
||||||
"diagnostics": [],
|
"diagnostics": [],
|
||||||
"judgments": [
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L4:30",
|
||||||
|
"to": "L4:36"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 123.45
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L4:18",
|
"from": "L4:18",
|
||||||
@@ -11,12 +24,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Meter",
|
"base": "Meter",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "LiteralExpr",
|
"_type": "LiteralExpr",
|
||||||
"value": 123.45
|
"value": 123.45
|
||||||
}
|
},
|
||||||
|
"unsafe": false
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"name": "Meter",
|
"name": "Meter",
|
||||||
@@ -25,6 +39,19 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L5:28",
|
||||||
|
"to": "L5:31"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 6.7
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L5:15",
|
"from": "L5:15",
|
||||||
@@ -35,12 +62,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Second",
|
"base": "Second",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "LiteralExpr",
|
"_type": "LiteralExpr",
|
||||||
"value": 6.7
|
"value": 6.7
|
||||||
}
|
},
|
||||||
|
"unsafe": false
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"name": "Second",
|
"name": "Second",
|
||||||
|
|||||||
@@ -100,6 +100,32 @@
|
|||||||
"name": "float"
|
"name": "float"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:13",
|
||||||
|
"to": "L11:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L11:17",
|
||||||
|
"to": "L11:19"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "v2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
@@ -135,32 +161,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L11:13",
|
|
||||||
"to": "L11:15"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "v1"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "int"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"location": {
|
|
||||||
"from": "L11:17",
|
|
||||||
"to": "L11:19"
|
|
||||||
},
|
|
||||||
"expr": {
|
|
||||||
"_type": "VariableExpr",
|
|
||||||
"name": "v2"
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"name": "float"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"location": {
|
"location": {
|
||||||
"from": "L11:5",
|
"from": "L11:5",
|
||||||
|
|||||||
59
tests/cases/checker/07_variance.midas
Normal file
59
tests/cases/checker/07_variance.midas
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
// T is invariant (unused)
|
||||||
|
type Unused[T] = object
|
||||||
|
|
||||||
|
// T is covariant
|
||||||
|
type Covariant[T] = object
|
||||||
|
|
||||||
|
// T is contravariant
|
||||||
|
type Contravariant[T] = object
|
||||||
|
|
||||||
|
// T is invariant
|
||||||
|
type Invariant[T] = object
|
||||||
|
|
||||||
|
extend Covariant[T] {
|
||||||
|
def foo: fn() -> T
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Contravariant[T] {
|
||||||
|
def foo: fn(T, /) -> None
|
||||||
|
}
|
||||||
|
|
||||||
|
extend Invariant[T] {
|
||||||
|
def foo: fn(T, /) -> T
|
||||||
|
}
|
||||||
|
|
||||||
|
// T is covariant
|
||||||
|
type Coco[T] = object
|
||||||
|
extend Coco[T] {
|
||||||
|
def foo: fn() -> Covariant[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
// T is contravariant
|
||||||
|
type Cocontra[T] = object
|
||||||
|
extend Cocontra[T] {
|
||||||
|
def foo: fn() -> Contravariant[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
// T is contravariant
|
||||||
|
type Contraco[T] = object
|
||||||
|
extend Contraco[T] {
|
||||||
|
def foo: fn(Covariant[T], /) -> None
|
||||||
|
}
|
||||||
|
|
||||||
|
// T is covariant
|
||||||
|
type Contracontra[T] = object
|
||||||
|
extend Contracontra[T] {
|
||||||
|
def foo: fn(Contravariant[T], /) -> None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
type T1[T] = object
|
||||||
|
type T2[T] = object
|
||||||
|
|
||||||
|
extend T1[T] {
|
||||||
|
def foo: fn() -> T2[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
extend T2[T] {
|
||||||
|
def foo: fn() -> T1[T]
|
||||||
|
}
|
||||||
52
tests/cases/checker/07_variance.py
Normal file
52
tests/cases/checker/07_variance.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from _ import (
|
||||||
|
T1,
|
||||||
|
T2,
|
||||||
|
Coco,
|
||||||
|
Cocontra,
|
||||||
|
Contraco,
|
||||||
|
Contracontra,
|
||||||
|
Contravariant,
|
||||||
|
Covariant,
|
||||||
|
Invariant,
|
||||||
|
Unused,
|
||||||
|
)
|
||||||
|
|
||||||
|
unused: Unused
|
||||||
|
covariant: Covariant
|
||||||
|
contravariant: Contravariant
|
||||||
|
invariant: Invariant
|
||||||
|
coco: Coco
|
||||||
|
cocontra: Cocontra
|
||||||
|
contraco: Contraco
|
||||||
|
contracontra: Contracontra
|
||||||
|
t1: T1
|
||||||
|
t2: T2
|
||||||
|
|
||||||
|
# Dummy print to prudce judgements for the expressions
|
||||||
|
print(
|
||||||
|
unused,
|
||||||
|
covariant,
|
||||||
|
contravariant,
|
||||||
|
invariant,
|
||||||
|
coco,
|
||||||
|
cocontra,
|
||||||
|
contraco,
|
||||||
|
contracontra,
|
||||||
|
t1,
|
||||||
|
t2,
|
||||||
|
)
|
||||||
|
|
||||||
|
cov1: Covariant[float]
|
||||||
|
cov2: Covariant[int]
|
||||||
|
cov1 = cov2 # Ok because int <: float => Covariant[int] <: Covariant[float]
|
||||||
|
cov2 = cov1 # Invalid
|
||||||
|
|
||||||
|
contra1: Contravariant[float]
|
||||||
|
contra2: Contravariant[int]
|
||||||
|
contra1 = contra2 # Invalid
|
||||||
|
contra2 = contra1 # Ok because int <: float => Covariant[float] <: Covariant[int]
|
||||||
|
|
||||||
|
inv1: Invariant[float]
|
||||||
|
inv2: Invariant[int]
|
||||||
|
inv1 = inv2 # Invalid
|
||||||
|
inv2 = inv1 # Invalid
|
||||||
512
tests/cases/checker/07_variance.py.ref.json
Normal file
512
tests/cases/checker/07_variance.py.ref.json
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
28,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
28,
|
||||||
|
13
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Too many positional arguments"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
42,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
42,
|
||||||
|
11
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Cannot assign Covariant[float] to variable 'cov2' of type Covariant[int]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
46,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
46,
|
||||||
|
17
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Cannot assign Contravariant[int] to variable 'contra1' of type Contravariant[float]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
51,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
51,
|
||||||
|
11
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Cannot assign Invariant[int] to variable 'inv1' of type Invariant[float]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
52,
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
52,
|
||||||
|
11
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Cannot assign Invariant[float] to variable 'inv2' of type Invariant[int]"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L27:4",
|
||||||
|
"to": "L27:10"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "unused"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Unused",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L28:4",
|
||||||
|
"to": "L28:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "covariant"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Covariant",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "COVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L29:4",
|
||||||
|
"to": "L29:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contravariant"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Contravariant",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "CONTRAVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L30:4",
|
||||||
|
"to": "L30:13"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "invariant"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Invariant",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L31:4",
|
||||||
|
"to": "L31:8"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "coco"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Coco",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "COVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L32:4",
|
||||||
|
"to": "L32:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "cocontra"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Cocontra",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "CONTRAVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L33:4",
|
||||||
|
"to": "L33:12"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contraco"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Contraco",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "CONTRAVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L34:4",
|
||||||
|
"to": "L34:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contracontra"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Contracontra",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "COVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L35:4",
|
||||||
|
"to": "L35:6"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "t1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "T1",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L36:4",
|
||||||
|
"to": "L36:6"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "t2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "T2",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L26:0",
|
||||||
|
"to": "L26:5"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "print"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "object",
|
||||||
|
"type": {},
|
||||||
|
"required": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L26:0",
|
||||||
|
"to": "L37:1"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "print"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "unused"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "covariant"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contravariant"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "invariant"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "coco"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "cocontra"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contraco"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contracontra"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "t1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "t2"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L41:7",
|
||||||
|
"to": "L41:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "cov2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Covariant",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L42:7",
|
||||||
|
"to": "L42:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "cov1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Covariant",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L46:10",
|
||||||
|
"to": "L46:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contra2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Contravariant",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L47:10",
|
||||||
|
"to": "L47:17"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "contra1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Contravariant",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L51:7",
|
||||||
|
"to": "L51:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "inv2"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Invariant",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L52:7",
|
||||||
|
"to": "L52:11"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "inv1"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "Invariant",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
14
tests/cases/checker/08_unification.py
Normal file
14
tests/cases/checker/08_unification.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
def double(value: float) -> float:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
|
||||||
|
def is_odd(value: int) -> bool:
|
||||||
|
return bool(value % 2)
|
||||||
|
|
||||||
|
|
||||||
|
floats: list[float] = [0.2, 0.5, 0.1, 1.2]
|
||||||
|
ints: list[int] = [1, 2, 6, -3]
|
||||||
|
|
||||||
|
doubled_floats = map(double, floats)
|
||||||
|
doubled_ints = map(double, ints)
|
||||||
|
odd_ints = map(is_odd, ints)
|
||||||
874
tests/cases/checker/08_unification.py.ref.json
Normal file
874
tests/cases/checker/08_unification.py.ref.json
Normal file
@@ -0,0 +1,874 @@
|
|||||||
|
{
|
||||||
|
"diagnostics": [
|
||||||
|
{
|
||||||
|
"type": "Error",
|
||||||
|
"location": {
|
||||||
|
"start": [
|
||||||
|
13,
|
||||||
|
15
|
||||||
|
],
|
||||||
|
"end": [
|
||||||
|
13,
|
||||||
|
32
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"judgments": [
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:11",
|
||||||
|
"to": "L2:16"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "value"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:19",
|
||||||
|
"to": "L2:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L2:11",
|
||||||
|
"to": "L2:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "value"
|
||||||
|
},
|
||||||
|
"operator": "*",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:16",
|
||||||
|
"to": "L6:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "value"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:24",
|
||||||
|
"to": "L6:25"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:16",
|
||||||
|
"to": "L6:25"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "value"
|
||||||
|
},
|
||||||
|
"operator": "%",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:11",
|
||||||
|
"to": "L6:15"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "bool"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "object",
|
||||||
|
"type": {},
|
||||||
|
"required": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L6:11",
|
||||||
|
"to": "L6:26"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "bool"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "BinaryExpr",
|
||||||
|
"left": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "value"
|
||||||
|
},
|
||||||
|
"operator": "%",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:23",
|
||||||
|
"to": "L9:26"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 0.2
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:28",
|
||||||
|
"to": "L9:31"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 0.5
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:33",
|
||||||
|
"to": "L9:36"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 0.1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:38",
|
||||||
|
"to": "L9:41"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1.2
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L9:22",
|
||||||
|
"to": "L9:42"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "ListExpr",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 0.2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1.2
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:19",
|
||||||
|
"to": "L10:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:22",
|
||||||
|
"to": "L10:23"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:25",
|
||||||
|
"to": "L10:26"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 6
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:29",
|
||||||
|
"to": "L10:30"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:28",
|
||||||
|
"to": "L10:30"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "UnaryExpr",
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L10:18",
|
||||||
|
"to": "L10:31"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "ListExpr",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 6
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "UnaryExpr",
|
||||||
|
"operator": "-",
|
||||||
|
"right": {
|
||||||
|
"_type": "LiteralExpr",
|
||||||
|
"value": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:21",
|
||||||
|
"to": "L12:27"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "double"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "value",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:29",
|
||||||
|
"to": "L12:35"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "floats"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:17",
|
||||||
|
"to": "L12:20"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "map"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "map",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "transform",
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"pos": 1,
|
||||||
|
"name": "iterable",
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L12:17",
|
||||||
|
"to": "L12:36"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "map"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "double"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "floats"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L13:19",
|
||||||
|
"to": "L13:25"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "double"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "value",
|
||||||
|
"type": {
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "float"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L13:27",
|
||||||
|
"to": "L13:31"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "ints"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L13:15",
|
||||||
|
"to": "L13:18"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "map"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "map",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "transform",
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"pos": 1,
|
||||||
|
"name": "iterable",
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L13:15",
|
||||||
|
"to": "L13:32"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "map"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "double"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "ints"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:15",
|
||||||
|
"to": "L14:21"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "is_odd"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"pos_args": [],
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "value",
|
||||||
|
"type": {
|
||||||
|
"name": "int"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:23",
|
||||||
|
"to": "L14:27"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "ints"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "int"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:11",
|
||||||
|
"to": "L14:14"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "map"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "map",
|
||||||
|
"params": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "transform",
|
||||||
|
"type": {
|
||||||
|
"pos_args": [
|
||||||
|
{
|
||||||
|
"pos": 0,
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"pos": 1,
|
||||||
|
"name": "iterable",
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "T",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"args": [],
|
||||||
|
"kw_args": [],
|
||||||
|
"returns": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "U",
|
||||||
|
"bound": null,
|
||||||
|
"variance": "INVARIANT"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"location": {
|
||||||
|
"from": "L14:11",
|
||||||
|
"to": "L14:28"
|
||||||
|
},
|
||||||
|
"expr": {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": {
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "map"
|
||||||
|
},
|
||||||
|
"arguments": [
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "is_odd"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_type": "VariableExpr",
|
||||||
|
"name": "ints"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"keywords": {}
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"name": "list",
|
||||||
|
"args": [
|
||||||
|
{
|
||||||
|
"name": "bool"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
|
"name": "list"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -7,68 +7,14 @@ Module(
|
|||||||
alias(name='Meter'),
|
alias(name='Meter'),
|
||||||
alias(name='Second')],
|
alias(name='Second')],
|
||||||
level=0),
|
level=0),
|
||||||
Assign(
|
|
||||||
targets=[
|
|
||||||
Name(id='__midas_alias_0__')],
|
|
||||||
value=Constant(value=123.45)),
|
|
||||||
Assert(
|
|
||||||
test=Call(
|
|
||||||
func=Name(id='isinstance'),
|
|
||||||
args=[
|
|
||||||
Name(id='__midas_alias_0__'),
|
|
||||||
Name(id='float')],
|
|
||||||
keywords=[]),
|
|
||||||
msg=JoinedStr(
|
|
||||||
values=[
|
|
||||||
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
|
|
||||||
FormattedValue(
|
|
||||||
value=Attribute(
|
|
||||||
value=Call(
|
|
||||||
func=Name(id='type'),
|
|
||||||
args=[
|
|
||||||
Name(id='__midas_alias_0__')],
|
|
||||||
keywords=[]),
|
|
||||||
attr='__name__'),
|
|
||||||
conversion=-1),
|
|
||||||
Constant(value=' to float')])),
|
|
||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='distance')],
|
Name(id='distance')],
|
||||||
value=Name(id='__midas_alias_0__')),
|
value=Constant(value=123.45)),
|
||||||
Delete(
|
|
||||||
targets=[
|
|
||||||
Name(id='__midas_alias_0__')]),
|
|
||||||
Assign(
|
|
||||||
targets=[
|
|
||||||
Name(id='__midas_alias_1__')],
|
|
||||||
value=Constant(value=6.7)),
|
|
||||||
Assert(
|
|
||||||
test=Call(
|
|
||||||
func=Name(id='isinstance'),
|
|
||||||
args=[
|
|
||||||
Name(id='__midas_alias_1__'),
|
|
||||||
Name(id='float')],
|
|
||||||
keywords=[]),
|
|
||||||
msg=JoinedStr(
|
|
||||||
values=[
|
|
||||||
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
|
|
||||||
FormattedValue(
|
|
||||||
value=Attribute(
|
|
||||||
value=Call(
|
|
||||||
func=Name(id='type'),
|
|
||||||
args=[
|
|
||||||
Name(id='__midas_alias_1__')],
|
|
||||||
keywords=[]),
|
|
||||||
attr='__name__'),
|
|
||||||
conversion=-1),
|
|
||||||
Constant(value=' to float')])),
|
|
||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='time')],
|
Name(id='time')],
|
||||||
value=Name(id='__midas_alias_1__')),
|
value=Constant(value=6.7)),
|
||||||
Delete(
|
|
||||||
targets=[
|
|
||||||
Name(id='__midas_alias_1__')]),
|
|
||||||
Assign(
|
Assign(
|
||||||
targets=[
|
targets=[
|
||||||
Name(id='speed')],
|
Name(id='speed')],
|
||||||
|
|||||||
14
tests/cases/generator/02_constraints.midas
Normal file
14
tests/cases/generator/02_constraints.midas
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
// Inline
|
||||||
|
type T1 = float where _ > 0
|
||||||
|
|
||||||
|
// Named
|
||||||
|
predicate is_positive(v: float) = v > 0
|
||||||
|
type T2 = float where is_positive(_)
|
||||||
|
|
||||||
|
// Curried
|
||||||
|
predicate in_range(mn: float, mx: float)(v: float) = v >= mn & v < mx
|
||||||
|
type T3 = float where in_range(100, 200)(_)
|
||||||
|
|
||||||
|
// Alias
|
||||||
|
predicate minor = in_range(0, 18)
|
||||||
|
type T4 = float where minor(_)
|
||||||
8
tests/cases/generator/02_constraints.py
Normal file
8
tests/cases/generator/02_constraints.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from midas import T1, T2, T3, T4, cast
|
||||||
|
|
||||||
|
t: float = 12.5
|
||||||
|
|
||||||
|
t1: T1 = cast(T1, t)
|
||||||
|
t2: T2 = cast(T2, t)
|
||||||
|
t3: T3 = cast(T3, t)
|
||||||
|
t4: T4 = cast(T4, t)
|
||||||
333
tests/cases/generator/02_constraints.py.ref.txt
Normal file
333
tests/cases/generator/02_constraints.py.ref.txt
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
Module(
|
||||||
|
body=[
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p0__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Compare(
|
||||||
|
left=Name(id='_'),
|
||||||
|
ops=[
|
||||||
|
Gt()],
|
||||||
|
comparators=[
|
||||||
|
Constant(value=0.0)]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_is_positive__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='v',
|
||||||
|
annotation=Constant(value='float'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Compare(
|
||||||
|
left=Name(id='v'),
|
||||||
|
ops=[
|
||||||
|
Gt()],
|
||||||
|
comparators=[
|
||||||
|
Constant(value=0.0)]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p1__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='__midas_is_positive__'),
|
||||||
|
args=[
|
||||||
|
Name(id='_')],
|
||||||
|
keywords=[]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_in_range__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='mn',
|
||||||
|
annotation=Constant(value='float')),
|
||||||
|
arg(
|
||||||
|
arg='mx',
|
||||||
|
annotation=Constant(value='float'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
FunctionDef(
|
||||||
|
name='inner0',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='v',
|
||||||
|
annotation=Constant(value='float'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=BoolOp(
|
||||||
|
op=And(),
|
||||||
|
values=[
|
||||||
|
Compare(
|
||||||
|
left=Name(id='v'),
|
||||||
|
ops=[
|
||||||
|
GtE()],
|
||||||
|
comparators=[
|
||||||
|
Name(id='mn')]),
|
||||||
|
Compare(
|
||||||
|
left=Name(id='v'),
|
||||||
|
ops=[
|
||||||
|
Lt()],
|
||||||
|
comparators=[
|
||||||
|
Name(id='mx')])]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
Return(
|
||||||
|
value=Name(id='inner0'))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='Callable[[float], bool]')),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p2__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Call(
|
||||||
|
func=Call(
|
||||||
|
func=Name(id='__midas_in_range__'),
|
||||||
|
args=[
|
||||||
|
Constant(value=100.0),
|
||||||
|
Constant(value=200.0)],
|
||||||
|
keywords=[]),
|
||||||
|
args=[
|
||||||
|
Name(id='_')],
|
||||||
|
keywords=[]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_minor__')],
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='__midas_in_range__'),
|
||||||
|
args=[
|
||||||
|
Constant(value=0.0),
|
||||||
|
Constant(value=18.0)],
|
||||||
|
keywords=[])),
|
||||||
|
FunctionDef(
|
||||||
|
name='__midas_p3__',
|
||||||
|
args=arguments(
|
||||||
|
posonlyargs=[],
|
||||||
|
args=[
|
||||||
|
arg(
|
||||||
|
arg='_',
|
||||||
|
annotation=Constant(value='Any'))],
|
||||||
|
kwonlyargs=[],
|
||||||
|
kw_defaults=[],
|
||||||
|
defaults=[]),
|
||||||
|
body=[
|
||||||
|
Return(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='__midas_minor__'),
|
||||||
|
args=[
|
||||||
|
Name(id='_')],
|
||||||
|
keywords=[]))],
|
||||||
|
decorator_list=[],
|
||||||
|
returns=Constant(value='bool')),
|
||||||
|
ImportFrom(
|
||||||
|
module='midas',
|
||||||
|
names=[
|
||||||
|
alias(name='T1'),
|
||||||
|
alias(name='T2'),
|
||||||
|
alias(name='T3'),
|
||||||
|
alias(name='T4'),
|
||||||
|
alias(name='cast')],
|
||||||
|
level=0),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t')],
|
||||||
|
value=Constant(value=12.5)),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a0__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a0__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L5:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a0__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p0__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a0__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L5:10: ConstraintError: Value does not fit constraint '_ > 0.0'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t1')],
|
||||||
|
value=Name(id='__midas_a0__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a0__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a1__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a1__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L6:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a1__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p1__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a1__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L6:10: ConstraintError: Value does not fit constraint 'is_positive(_)'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t2')],
|
||||||
|
value=Name(id='__midas_a1__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a1__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a2__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a2__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L7:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a2__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p2__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a2__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L7:10: ConstraintError: Value does not fit constraint 'in_range(100.0, 200.0)(_)'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t3')],
|
||||||
|
value=Name(id='__midas_a2__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a2__')]),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a3__')],
|
||||||
|
value=Name(id='t')),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='isinstance'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a3__'),
|
||||||
|
Name(id='float')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=JoinedStr(
|
||||||
|
values=[
|
||||||
|
Constant(value='02_constraints.py:L8:10: CastError: Cannot cast '),
|
||||||
|
FormattedValue(
|
||||||
|
value=Attribute(
|
||||||
|
value=Call(
|
||||||
|
func=Name(id='type'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a3__')],
|
||||||
|
keywords=[]),
|
||||||
|
attr='__name__'),
|
||||||
|
conversion=-1),
|
||||||
|
Constant(value=' to float')])),
|
||||||
|
Assert(
|
||||||
|
test=Call(
|
||||||
|
func=Name(id='__midas_p3__'),
|
||||||
|
args=[
|
||||||
|
Name(id='__midas_a3__')],
|
||||||
|
keywords=[]),
|
||||||
|
msg=Constant(value="02_constraints.py:L8:10: ConstraintError: Value does not fit constraint 'minor(_)'")),
|
||||||
|
Assign(
|
||||||
|
targets=[
|
||||||
|
Name(id='t4')],
|
||||||
|
value=Name(id='__midas_a3__')),
|
||||||
|
Delete(
|
||||||
|
targets=[
|
||||||
|
Name(id='__midas_a3__')])],
|
||||||
|
type_ignores=[])
|
||||||
@@ -2582,18 +2582,21 @@
|
|||||||
"name": "__sub__",
|
"name": "__sub__",
|
||||||
"type": {
|
"type": {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [
|
"params": {
|
||||||
{
|
"_type": "ParamSpec",
|
||||||
"name": null,
|
"pos": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"name": null,
|
||||||
"name": "GeoLocation"
|
"type": {
|
||||||
},
|
"_type": "NamedType",
|
||||||
"required": true
|
"name": "GeoLocation"
|
||||||
}
|
},
|
||||||
],
|
"required": true
|
||||||
"args": [],
|
}
|
||||||
"kw_args": [],
|
],
|
||||||
|
"mixed": [],
|
||||||
|
"kw": []
|
||||||
|
},
|
||||||
"returns": {
|
"returns": {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": {
|
"type": {
|
||||||
@@ -2673,18 +2676,21 @@
|
|||||||
"name": "__sub__",
|
"name": "__sub__",
|
||||||
"type": {
|
"type": {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [
|
"params": {
|
||||||
{
|
"_type": "ParamSpec",
|
||||||
"name": null,
|
"pos": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"name": null,
|
||||||
"name": "Latitude"
|
"type": {
|
||||||
},
|
"_type": "NamedType",
|
||||||
"required": true
|
"name": "Latitude"
|
||||||
}
|
},
|
||||||
],
|
"required": true
|
||||||
"args": [],
|
}
|
||||||
"kw_args": [],
|
],
|
||||||
|
"mixed": [],
|
||||||
|
"kw": []
|
||||||
|
},
|
||||||
"returns": {
|
"returns": {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": {
|
"type": {
|
||||||
@@ -2713,18 +2719,21 @@
|
|||||||
"name": "__sub__",
|
"name": "__sub__",
|
||||||
"type": {
|
"type": {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [
|
"params": {
|
||||||
{
|
"_type": "ParamSpec",
|
||||||
"name": null,
|
"pos": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"name": null,
|
||||||
"name": "Longitude"
|
"type": {
|
||||||
},
|
"_type": "NamedType",
|
||||||
"required": true
|
"name": "Longitude"
|
||||||
}
|
},
|
||||||
],
|
"required": true
|
||||||
"args": [],
|
}
|
||||||
"kw_args": [],
|
],
|
||||||
|
"mixed": [],
|
||||||
|
"kw": []
|
||||||
|
},
|
||||||
"returns": {
|
"returns": {
|
||||||
"_type": "GenericType",
|
"_type": "GenericType",
|
||||||
"type": {
|
"type": {
|
||||||
@@ -2745,12 +2754,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "Positive",
|
"name": "Positive",
|
||||||
"subject": "v",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "float"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
"left": {
|
"left": {
|
||||||
"_type": "VariableExpr",
|
"_type": "VariableExpr",
|
||||||
@@ -2766,12 +2787,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "StrictlyPositive",
|
"name": "StrictlyPositive",
|
||||||
"subject": "v",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "float"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "v",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "float"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
"left": {
|
"left": {
|
||||||
"_type": "VariableExpr",
|
"_type": "VariableExpr",
|
||||||
@@ -2787,12 +2820,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "Equatorial",
|
"name": "Equatorial",
|
||||||
"subject": "loc",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "GeoLocation"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "loc",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "GeoLocation"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "GroupingExpr",
|
"_type": "GroupingExpr",
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
@@ -2827,12 +2872,24 @@
|
|||||||
{
|
{
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": "Arctic",
|
"name": "Arctic",
|
||||||
"subject": "loc",
|
"params": [
|
||||||
"type": {
|
{
|
||||||
"_type": "NamedType",
|
"_type": "ParamSpec",
|
||||||
"name": "GeoLocation"
|
"pos": [],
|
||||||
},
|
"mixed": [
|
||||||
"condition": {
|
{
|
||||||
|
"name": "loc",
|
||||||
|
"type": {
|
||||||
|
"_type": "NamedType",
|
||||||
|
"name": "GeoLocation"
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"kw": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"body": {
|
||||||
"_type": "GroupingExpr",
|
"_type": "GroupingExpr",
|
||||||
"expr": {
|
"expr": {
|
||||||
"_type": "BinaryExpr",
|
"_type": "BinaryExpr",
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "bool",
|
"base": "bool",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -25,7 +25,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -36,7 +36,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "(_ > 0) + (_ < 250)"
|
"constraint": "(_ > 0) + (_ < 250)"
|
||||||
}
|
}
|
||||||
@@ -47,7 +47,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "str",
|
"base": "str",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -56,7 +56,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "datetime",
|
"base": "datetime",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -65,7 +65,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -79,7 +79,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "_",
|
"base": "_",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "GeoLocation",
|
"base": "GeoLocation",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -28,11 +28,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "BaseType",
|
{
|
||||||
"base": "GeoLocation",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "GeoLocation",
|
||||||
}
|
"args": []
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -65,11 +67,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "BaseType",
|
{
|
||||||
"base": "GeoLocation",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "GeoLocation",
|
||||||
}
|
"args": []
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -117,7 +121,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Latitude",
|
"base": "Latitude",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -146,7 +150,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Latitude",
|
"base": "Latitude",
|
||||||
"param": null
|
"args": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -175,11 +179,13 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Difference",
|
"base": "Difference",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "BaseType",
|
{
|
||||||
"base": "Latitude",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "Latitude",
|
||||||
}
|
"args": []
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -217,7 +223,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "_ >= 0"
|
"constraint": "_ >= 0"
|
||||||
}
|
}
|
||||||
@@ -230,7 +236,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "_ >= 0"
|
"constraint": "_ >= 0"
|
||||||
}
|
}
|
||||||
@@ -252,7 +258,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "Positive"
|
"constraint": "Positive"
|
||||||
}
|
}
|
||||||
@@ -265,7 +271,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"constraint": "Positive"
|
"constraint": "Positive"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,15 +14,17 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 1"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 1"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
},
|
},
|
||||||
@@ -31,15 +33,17 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 1"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 1"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
@@ -50,15 +54,17 @@
|
|||||||
"returns": {
|
"returns": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 2"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 2"
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"body": [
|
"body": [
|
||||||
{
|
{
|
||||||
@@ -67,15 +73,17 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "Column",
|
"base": "Column",
|
||||||
"param": {
|
"args": [
|
||||||
"_type": "ConstraintType",
|
{
|
||||||
"type": {
|
"_type": "ConstraintType",
|
||||||
"_type": "BaseType",
|
"type": {
|
||||||
"base": "float",
|
"_type": "BaseType",
|
||||||
"param": null
|
"base": "float",
|
||||||
},
|
"args": []
|
||||||
"constraint": "0 <= _ <= 2"
|
},
|
||||||
}
|
"constraint": "0 <= _ <= 2"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -117,7 +125,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "int",
|
"base": "int",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
@@ -128,7 +136,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "float",
|
"base": "float",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
@@ -140,7 +148,7 @@
|
|||||||
"type": {
|
"type": {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": "str",
|
"base": "str",
|
||||||
"param": null
|
"args": []
|
||||||
},
|
},
|
||||||
"default": null
|
"default": null
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class GeneratorTester(Tester):
|
|||||||
typed_ast: TypedAST = checker.type_check(path)
|
typed_ast: TypedAST = checker.type_check(path)
|
||||||
|
|
||||||
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
if not any(d.type == DiagnosticType.ERROR for d in checker.diagnostics):
|
||||||
generator = Generator(workdir=path.parent)
|
generator = Generator(workdir=path.parent, types=checker.types)
|
||||||
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
result.compiled_ast = generator.generate_ast(typed_ast, path)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -2,11 +2,13 @@ from typing import Optional, Sequence
|
|||||||
|
|
||||||
from midas.ast.midas import (
|
from midas.ast.midas import (
|
||||||
BinaryExpr,
|
BinaryExpr,
|
||||||
|
CallExpr,
|
||||||
ComplexType,
|
ComplexType,
|
||||||
ConstraintType,
|
ConstraintType,
|
||||||
Expr,
|
Expr,
|
||||||
ExtendStmt,
|
ExtendStmt,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
|
FrameType,
|
||||||
FunctionType,
|
FunctionType,
|
||||||
GenericType,
|
GenericType,
|
||||||
GetExpr,
|
GetExpr,
|
||||||
@@ -15,6 +17,7 @@ from midas.ast.midas import (
|
|||||||
LogicalExpr,
|
LogicalExpr,
|
||||||
MemberStmt,
|
MemberStmt,
|
||||||
NamedType,
|
NamedType,
|
||||||
|
ParamSpec,
|
||||||
PredicateStmt,
|
PredicateStmt,
|
||||||
Stmt,
|
Stmt,
|
||||||
Type,
|
Type,
|
||||||
@@ -78,9 +81,8 @@ class MidasAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "PredicateStmt",
|
"_type": "PredicateStmt",
|
||||||
"name": stmt.name.lexeme,
|
"name": stmt.name.lexeme,
|
||||||
"subject": stmt.subject.lexeme,
|
"params": [self._serialize_param_spec(spec) for spec in stmt.params],
|
||||||
"type": stmt.type.accept(self),
|
"body": stmt.body.accept(self),
|
||||||
"condition": stmt.condition.accept(self),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
def visit_logical_expr(self, expr: LogicalExpr) -> dict:
|
||||||
@@ -106,6 +108,14 @@ class MidasAstJsonSerializer(
|
|||||||
"right": expr.right.accept(self),
|
"right": expr.right.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_call_expr(self, expr: CallExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "CallExpr",
|
||||||
|
"callee": expr.callee.accept(self),
|
||||||
|
"arguments": self._serialize_list(expr.arguments),
|
||||||
|
"keywords": {name: arg.accept(self) for name, arg in expr.keywords.items()},
|
||||||
|
}
|
||||||
|
|
||||||
def visit_get_expr(self, expr: GetExpr) -> dict:
|
def visit_get_expr(self, expr: GetExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "GetExpr",
|
"_type": "GetExpr",
|
||||||
@@ -163,15 +173,21 @@ class MidasAstJsonSerializer(
|
|||||||
def visit_function_type(self, type: FunctionType) -> dict:
|
def visit_function_type(self, type: FunctionType) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "FunctionType",
|
"_type": "FunctionType",
|
||||||
"pos_args": [self._serialize_func_arg(arg) for arg in type.pos_args],
|
"params": self._serialize_param_spec(type.params),
|
||||||
"args": [self._serialize_func_arg(arg) for arg in type.args],
|
|
||||||
"kw_args": [self._serialize_func_arg(arg) for arg in type.kw_args],
|
|
||||||
"returns": type.returns.accept(self),
|
"returns": type.returns.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _serialize_param_spec(self, spec: ParamSpec) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "ParamSpec",
|
||||||
|
"pos": [self._serialize_func_arg(arg) for arg in spec.pos],
|
||||||
|
"mixed": [self._serialize_func_arg(arg) for arg in spec.mixed],
|
||||||
|
"kw": [self._serialize_func_arg(arg) for arg in spec.kw],
|
||||||
|
}
|
||||||
|
|
||||||
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
def _serialize_func_arg(self, arg: FunctionType.Argument) -> dict:
|
||||||
return {
|
return {
|
||||||
"name": arg.name,
|
"name": arg.name.lexeme if arg.name is not None else None,
|
||||||
"type": arg.type.accept(self),
|
"type": arg.type.accept(self),
|
||||||
"required": arg.required,
|
"required": arg.required,
|
||||||
}
|
}
|
||||||
@@ -182,3 +198,15 @@ class MidasAstJsonSerializer(
|
|||||||
"base": type.base.accept(self),
|
"base": type.base.accept(self),
|
||||||
"extension": type.extension.accept(self),
|
"extension": type.extension.accept(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_frame_type(self, type: FrameType) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "FrameType",
|
||||||
|
"columns": [self._serialize_column(col) for col in type.columns],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _serialize_column(self, column: FrameType.Column):
|
||||||
|
return {
|
||||||
|
"name": column.name.lexeme,
|
||||||
|
"type": column.type.accept(self),
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from midas.ast.python import (
|
|||||||
Stmt,
|
Stmt,
|
||||||
SubscriptExpr,
|
SubscriptExpr,
|
||||||
TernaryExpr,
|
TernaryExpr,
|
||||||
|
TupleExpr,
|
||||||
TypeAssign,
|
TypeAssign,
|
||||||
UnaryExpr,
|
UnaryExpr,
|
||||||
VariableExpr,
|
VariableExpr,
|
||||||
@@ -98,7 +99,7 @@ class PythonAstJsonSerializer(
|
|||||||
return {
|
return {
|
||||||
"_type": "BaseType",
|
"_type": "BaseType",
|
||||||
"base": node.base,
|
"base": node.base,
|
||||||
"param": self._serialize_optional(node.param),
|
"args": self._serialize_list(node.args),
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
def visit_constraint_type(self, node: ConstraintType) -> dict:
|
||||||
@@ -263,6 +264,7 @@ class PythonAstJsonSerializer(
|
|||||||
"_type": "CastExpr",
|
"_type": "CastExpr",
|
||||||
"type": expr.type.accept(self),
|
"type": expr.type.accept(self),
|
||||||
"expr": expr.expr.accept(self),
|
"expr": expr.expr.accept(self),
|
||||||
|
"unsafe": expr.unsafe,
|
||||||
}
|
}
|
||||||
|
|
||||||
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
def visit_ternary_expr(self, expr: TernaryExpr) -> dict:
|
||||||
@@ -301,6 +303,12 @@ class PythonAstJsonSerializer(
|
|||||||
"step": self._serialize_optional(expr.step),
|
"step": self._serialize_optional(expr.step),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def visit_tuple_expr(self, expr: TupleExpr) -> dict:
|
||||||
|
return {
|
||||||
|
"_type": "TupleExpr",
|
||||||
|
"items": [item.accept(self) for item in expr.items],
|
||||||
|
}
|
||||||
|
|
||||||
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
def visit_raw_expr(self, expr: RawExpr) -> dict:
|
||||||
return {
|
return {
|
||||||
"_type": "RawExpr",
|
"_type": "RawExpr",
|
||||||
|
|||||||
Reference in New Issue
Block a user