Compare commits

64 Commits

Author SHA1 Message Date
9f59366289 feat(gen): generate asserts for dataframes and columns 2026-06-26 14:56:15 +02:00
fd0b410d74 fix(checker): change heterogeneous errors to warnings 2026-06-26 11:55:31 +02:00
5b0c5c01ad feat(checker): add mean method on frames 2026-06-26 11:21:38 +02:00
43e40396a1 fix(checker): type check None literal 2026-06-26 11:21:17 +02:00
0d265ef24c feat(checker): lookup dunders on dataframes 2026-06-26 10:35:50 +02:00
88c56c9d15 tests: update with reordered argument typing 2026-06-26 10:28:12 +02:00
d1c217a335 refactor: use metaclass to collect frame methods 2026-06-25 22:31:59 +02:00
5b3e87afcb refactor: add MethodResolver class 2026-06-25 22:14:25 +02:00
894d5a7196 feat: add dummy classes for typing frames and columns 2026-06-25 21:35:47 +02:00
eb809c6341 fix(checker): improve heterogeneous error message 2026-06-25 21:35:19 +02:00
bd68d1003f feat(checker): lookup dataframe methods 2026-06-25 21:34:59 +02:00
72c9236650 feat(checker): defined add method of dataframes 2026-06-25 21:34:00 +02:00
90051c7981 feat(checker): add structural subtyping rule for dataframes 2026-06-25 21:09:14 +02:00
dd1e2e693c feat(cli): print context for multiline diagnostics 2026-06-25 16:32:15 +02:00
78e10e0895 feat(checker): process frame type definitions 2026-06-24 14:36:53 +02:00
c81e4a9560 feat(cli): add frame type to highlighter 2026-06-24 14:36:53 +02:00
6d0cf1a055 feat(parser): add frame type to midas syntax 2026-06-24 14:36:52 +02:00
cc5e7af143 feat(gen): add support for tuples and dataframes 2026-06-24 14:36:51 +02:00
3bdbc80079 feat(checker): handle setting dataframe column 2026-06-24 14:36:51 +02:00
c1b5284f72 feat(checker): type check subscript on dataframes 2026-06-24 14:36:28 +02:00
5e9ccd4e13 feat(types): add TupleType 2026-06-24 14:36:04 +02:00
cf083fc0c3 fix(types): add str methods to dataframe types 2026-06-24 14:35:31 +02:00
a80da5db2c feat(types): add DataFrameType and ColumnType 2026-06-24 14:35:30 +02:00
f7c43837b5 Merge pull request 'CLI tweaks' (#22) from fix/cli-tweaks into main
Reviewed-on: #22
2026-06-24 12:18:07 +00:00
32ed62a6f1 fix(cli): show summary of diagnostic counts 2026-06-24 14:11:39 +02:00
66f39acec0 fix(cli): show all diagnostics in types command
combine type checker diagnostics with judgements info diagnostics
2026-06-24 14:11:15 +02:00
6c04e2fee4 feat(cli): add compile option to ignore errors 2026-06-24 14:10:30 +02:00
2bb2e0a684 Merge pull request 'Unsafe cast' (#21) from feat/unsafe-cast into main
Reviewed-on: #21
2026-06-24 12:00:03 +00:00
5630320d21 chore: use unsafe_cast in demo script 2026-06-24 13:57:38 +02:00
9f05ba3224 feat: handle unsafe casts 2026-06-24 13:51:14 +02:00
5fbe965919 feat(checker): add typing submodule with cast functions 2026-06-24 13:40:23 +02:00
252a5abdfd Merge pull request 'Static evalution of casts on literals' (#20) from feat/literal-static-constraints into main
Reviewed-on: #20
2026-06-24 09:32:54 +00:00
55fba6a088 tests: update test without evaluated casts 2026-06-24 11:28:44 +02:00
70ce263ea2 feat(gen): skip assertions for evaluated casts
avoid generating a runtime assertion for a cast which has already been checked statically
2026-06-24 11:28:43 +02:00
e1d5eac8b8 feat(checker): evaluate constraints statically on literals 2026-06-24 11:10:09 +02:00
82666a4918 feat(checker): add evaluator
add an evaluator class to evaluate expressions using literal values
2026-06-24 11:08:15 +02:00
45f84a2f23 feat(checker): add debug diagnostics 2026-06-24 11:07:42 +02:00
dedfcb4dbb feat(checker): store builtin python functions in preamble 2026-06-24 11:05:36 +02:00
d9ea6365ea tests: update with cast expression judgement 2026-06-23 16:49:38 +02:00
9c7a93412c Merge pull request 'Fixes and small demo' (#19) from feat/demonstration into main
Reviewed-on: #19
2026-06-23 08:15:56 +00:00
d6b8fbfb60 chore: improve demo example 2026-06-23 10:03:24 +02:00
b290c59ac4 fix(gen): add bases for ConstraintType and TypeVar 2026-06-23 00:25:43 +02:00
093f2bc477 fix(checker): lookup member on typevar bound 2026-06-23 00:24:37 +02:00
7c771c4070 feat(checker): add input function to preamble 2026-06-23 00:22:38 +02:00
a50a207385 fix(gen): don't generate stubs for builtin types 2026-06-22 15:40:31 +02:00
7e5ea5e414 chore: add example to demonstrate some features 2026-06-22 15:29:39 +02:00
0ba0266bae fix(checker): check general subtype case for AppliedType
this adds the case where we check whether AppliedType <: Type, and delegates to the body

this may not be a legitimate rule, or may need to be refined
2026-06-22 15:27:06 +02:00
216c80f08c fix(checker): produce judgement for expression in cast 2026-06-22 15:24:51 +02:00
f75d7722a1 fix(checker): look up members on constraint type 2026-06-22 15:24:18 +02:00
2f29c47274 fix(gen): assert type var bound 2026-06-22 15:23:53 +02:00
80af2b9048 fix(checker): handle is_subtype of TypeVar 2026-06-22 14:44:51 +02:00
577454ee7e fix(checker): make UnknownType a top type for subtyping 2026-06-22 14:15:18 +02:00
878693383e feat(cli): add watch option to stubs command 2026-06-22 14:14:05 +02:00
0b91de75a8 feat(checker): handle type vars in python functions 2026-06-22 14:13:25 +02:00
739871c101 Merge pull request 'Generic call unification' (#18) from feat/unification into main
Reviewed-on: #18
2026-06-21 11:41:48 +00:00
4395e9339b fix(checker): abort unification on conflict 2026-06-21 13:36:07 +02:00
29e601128d tests: add unification test 2026-06-21 13:19:17 +02:00
b591f5508f fix(checker): make map definition generic 2026-06-21 13:17:35 +02:00
41d0c84bbe feat(checker): add unifier
add unifier class to infer type parameters from local call context
2026-06-21 13:12:27 +02:00
cccf2f8f9f Merge pull request 'Stubs generator' (#17) from feat/stubs-gen into main
Reviewed-on: #17
2026-06-20 15:44:34 +00:00
3f48c2138f chore: add stubs command to README 2026-06-20 17:44:15 +02:00
e4ab27673d fix(gen): handle TypeVar variance in stubs generator 2026-06-20 17:34:40 +02:00
b02ecc6326 fix(gen): handle ConstraintType in stubs generator 2026-06-20 17:34:22 +02:00
9e83079910 fix(cli): add missing methods to highlighter 2026-06-20 17:23:18 +02:00
44 changed files with 3200 additions and 525 deletions

View File

@@ -18,6 +18,7 @@ This framework is being developed as part of a Bachelor's Thesis by Louis Herede
- [Highlighting](#highlighting)
- [Dumping the AST](#dumping-the-ast)
- [Dumping the Registry](#dumping-the-registry)
- [Generating Stubs](#generating-stubs)
- [Showing Type Judgements](#showing-type-judgements)
- [Validating Definitions](#validating-definitions)
- [Tests](#tests)
@@ -116,6 +117,14 @@ midas dump-registry -t types.midas
This command processes the given Midas definitions and dumps the contents of the types registry.
### Generating Stubs
```shell
midas stubs types.midas -o stubs.pyi
```
This command generate Python stubs from a Midas definition file
### Showing Type Judgements
```shell

View File

@@ -0,0 +1,15 @@
predicate in_range(min: float, max: float)(v: float) = min <= v & v <= max
predicate is_ratio = in_range(0, 1)
type Currency = float
type Price[T <: Currency] = T where _ >= 0
extend Price[T <: Currency] {
def __add__: fn(Price[T], /) -> Price[T]
}
type EUR = Currency
type USD = Currency
type CHF = Currency
type Discount = float where is_ratio(_)

View File

@@ -0,0 +1,35 @@
from typing import TypeVar
from demo_stubs import CHF, EUR, USD, Currency, Discount, Price
from midas.typing import cast, unsafe_cast
T = TypeVar("T", bound=Currency)
def apply_discount(amount: Price[T], discount: Discount) -> Price[T]:
return cast(Price[T], (1.0 - discount) * amount)
a1 = cast(Price[EUR], 3.2)
a2 = cast(Price[USD], 10.4)
r1 = cast(Discount, 0.2)
print(apply_discount(a1, r1))
print(apply_discount(a2, r1))
a3 = a1 + a1
a4 = a1 + a2 # cannot add euros and dollars
a3 = a2 # cannot change variable type
dyn_price = float(input("Price (CHF): "))
dyn_discount = float(input("Discount (0.0-1.0): "))
discounted = apply_discount(
cast(Price[CHF], dyn_price),
cast(Discount, dyn_discount),
)
print(f"Discounted: CHF {discounted}")
large_data = [i * 10 for i in range(100)]
prices = unsafe_cast(list[Price[EUR]], large_data)

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from typing import Generic, TypeVar
class Currency(float): ...
_T0 = TypeVar("_T0", bound=Currency, covariant=True)
class Price(Currency, Generic[_T0]):
def __add__(self, _0: Price[_T0], /) -> Price[_T0]: ...
class EUR(Currency): ...
class USD(Currency): ...
class CHF(Currency): ...
class Discount(float): ...

View File

@@ -152,4 +152,14 @@ class FunctionType:
required: bool
class FrameType:
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
###<

View File

@@ -145,6 +145,7 @@ class LogicalExpr:
class CastExpr:
type: MidasType
expr: Expr
unsafe: bool
class TernaryExpr:

View File

@@ -253,6 +253,9 @@ class Type(ABC):
@abstractmethod
def visit_function_type(self, type: FunctionType) -> T: ...
@abstractmethod
def visit_frame_type(self, type: FrameType) -> T: ...
@dataclass(frozen=True)
class NamedType(Type):
@@ -311,3 +314,17 @@ class FunctionType(Type):
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_function_type(self)
@dataclass(frozen=True)
class FrameType(Type):
columns: list[Column]
@dataclass(frozen=True, kw_only=True)
class Column:
location: Optional[Location] = None
name: Token
type: Type
def accept(self, visitor: Type.Visitor[T]) -> T:
return visitor.visit_frame_type(self)

View File

@@ -350,6 +350,25 @@ class MidasAstPrinter(
arg.type.accept(self)
self._write_line(f"required: {arg.required}", last=True)
def visit_frame_type(self, type: m.FrameType) -> None:
self._write_line("FrameType")
with self._child_level(single=True):
self._write_line("columns")
with self._child_level():
for i, column in enumerate(type.columns):
self._idx = i
if i == len(type.columns) - 1:
self._mark_last()
self._print_frame_column(column)
def _print_frame_column(self, column: m.FrameType.Column) -> None:
self._write_line("Column")
with self._child_level():
self._write_line(f'name: "{column.name.lexeme}"')
self._write_line("type")
with self._child_level(single=True):
column.type.accept(self)
class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]):
def __init__(self, indent: int = 4):
@@ -502,6 +521,23 @@ class MidasPrinter(m.Expr.Visitor[str], m.Stmt.Visitor[str], m.Type.Visitor[str]
res += "?"
return res
def visit_frame_type(self, type: m.FrameType) -> str:
res: str = self.indented("Frame[")
if len(type.columns) != 0:
res += "\n"
self.level += 1
columns: list[str] = []
for column in type.columns:
columns.append(self.indented(self._print_frame_column(column)))
res += ",\n".join(columns)
self.level -= 1
res += "\n"
res += "]"
return res
def _print_frame_column(self, column: m.FrameType.Column) -> str:
return f"{column.name.lexeme}: {column.type.accept(self)}"
class PythonAstPrinter(
AstPrinter,
@@ -757,9 +793,10 @@ class PythonAstPrinter(
self._write_line("type")
with self._child_level(single=True):
expr.type.accept(self)
self._write_line("expr", last=True)
self._write_line("expr")
with self._child_level(single=True):
expr.expr.accept(self)
self._write_line(f"unsafe: {expr.unsafe}", last=True)
def visit_ternary_expr(self, expr: p.TernaryExpr) -> None:
self._write_line("TernaryExpr")

View File

@@ -350,6 +350,7 @@ class LogicalExpr(Expr):
class CastExpr(Expr):
type: MidasType
expr: Expr
unsafe: bool
def accept(self, visitor: Expr.Visitor[T]) -> T:
return visitor.visit_cast_expr(self)

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict"},
"object": {"float", "list", "dict", "str"},
"float": {"int"},
"int": {"bool"},
}

View File

@@ -9,6 +9,7 @@ class DiagnosticType(StrEnum):
ERROR = "Error"
WARNING = "Warning"
INFO = "Info"
DEBUG = "Debug"
@dataclass(frozen=True)

172
midas/checker/evaluator.py Normal file
View File

@@ -0,0 +1,172 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.midas as m
from midas.checker.preamble import Preamble
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Function, Predicate
from midas.lexer.token import TokenType
@dataclass(frozen=True, kw_only=True)
class PartialPredicate(Predicate):
scope: dict[str, Any]
class Evaluator(m.Expr.Visitor[Any]):
def __init__(self, types: TypesRegistry, reporter: Optional[FileReporter] = None):
self.types: TypesRegistry = types
self.reporter: Optional[FileReporter] = reporter
self.preamble: Preamble = Preamble(self.types)
self.scopes: list[dict[str, Any]] = [{}]
def evaluate(self, expr: m.Expr) -> Any:
value: Any = expr.accept(self)
if self.reporter is not None:
self.reporter.debug(expr.location, f"Value: {value}")
return value
def get_value(self, name: str) -> Any:
scope: dict[str, Any] = self.scopes[-1]
return scope[name]
def set_value(self, name: str, value: Any, force_declare: bool = False):
if not force_declare:
for scope in reversed(self.scopes):
if name in scope:
scope[name] = value
return
self.scopes[-1][name] = value
def visit_logical_expr(self, expr: m.LogicalExpr) -> Any:
def left():
return self.evaluate(expr.left)
def right():
return self.evaluate(expr.right)
match expr.operator.type:
case TokenType.AND:
return left() and right()
case _:
raise NotImplementedError
def visit_binary_expr(self, expr: m.BinaryExpr) -> Any:
left: Any = self.evaluate(expr.left)
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return left - right
case TokenType.STAR:
return left * right
case TokenType.SLASH:
return left / right
case TokenType.GREATER:
return left > right
case TokenType.GREATER_EQUAL:
return left >= right
case TokenType.LESS:
return left < right
case TokenType.LESS_EQUAL:
return left <= right
case TokenType.EQUAL_EQUAL:
return left == right
case TokenType.BANG_EQUAL:
return left != right
case _:
raise NotImplementedError
def visit_unary_expr(self, expr: m.UnaryExpr) -> Any:
right: Any = self.evaluate(expr.right)
match expr.operator.type:
case TokenType.MINUS:
return -right
case _:
raise NotImplementedError
def visit_call_expr(self, expr: m.CallExpr) -> Any:
callee: Any = self.evaluate(expr.callee)
args: list[Any] = [self.evaluate(arg) for arg in expr.arguments]
kwargs: dict[str, Any] = {
name: self.evaluate(arg) for name, arg in expr.keywords.items()
}
match callee:
case Predicate():
return self._evaluate_predicate(callee, args, kwargs)
case _ if callable(callee):
return callee(*args, **kwargs)
case _:
return NotImplementedError
def visit_get_expr(self, expr: m.GetExpr) -> Any:
obj: Any = self.evaluate(expr.expr)
return getattr(obj, expr.name.lexeme)
def visit_variable_expr(self, expr: m.VariableExpr) -> Any:
name: str = expr.name.lexeme
for scope in reversed(self.scopes):
if name in scope:
return scope[name]
predicate: Optional[Predicate] = self.types.lookup_predicate(name)
if predicate is not None:
if predicate.alias:
return self.evaluate(predicate.body)
return predicate
glob: Optional[Callable] = self.preamble.get_py_func(name)
if glob is not None:
return glob
raise NameError(f"Unknown variable '{name}'")
def visit_grouping_expr(self, expr: m.GroupingExpr) -> Any:
return self.evaluate(expr.expr)
def visit_literal_expr(self, expr: m.LiteralExpr) -> Any:
return expr.value
def visit_wildcard_expr(self, expr: m.WildcardExpr) -> Any:
return self.get_value("_")
def _evaluate_predicate(
self, predicate: Predicate, args: list[Any], kwargs: dict[str, Any]
) -> Any:
res: Any = None
if isinstance(predicate, PartialPredicate):
self.scopes.append(predicate.scope)
else:
self.scopes.append({})
match predicate.type:
case Function(returns=Function() as inner):
self._map_args(predicate.type, args, kwargs)
res = PartialPredicate(
type=inner,
body=predicate.body,
alias=False,
scope=self.scopes[-1],
)
case Function():
self._map_args(predicate.type, args, kwargs)
res = self.evaluate(predicate.body)
case _:
raise NotImplementedError
self.scopes.pop()
return res
def _map_args(self, function: Function, args: list[Any], kwargs: dict[str, Any]):
positional: list[Function.Argument] = function.pos_args + function.args
keywords: dict[str, Function.Argument] = {
arg.name: arg for arg in function.args + function.kw_args
}
for i, arg in enumerate(args):
param: Function.Argument = positional[i]
self.set_value(param.name, arg)
for name, arg in kwargs.items():
param: Function.Argument = keywords[name]
self.set_value(param.name, arg)

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnType,
DataFrameType,
Function,
OverloadedFunction,
TopType,
Type,
UnknownType,
unfold_type,
)
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
frame: DataFrameType
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class _MethodRegistryMeta(type):
_methods: dict[str, Callable] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr
return new_class
class MethodRegistry(metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
def call(
self,
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._methods.get(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
return UnknownType()
return func(self, call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
# TODO: support add with scalar, sequence, Series, dict
# TODO: check operation exists on inner column types
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
in_frame1: set[str] = set()
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
col_type1: Type = column.type
col_type: Type = ColumnType(type=UnknownType())
if column.name in by_name:
column2 = by_name[column.name]
col_type2: Type = column2.type
if self.types.are_equivalent(col_type2, col_type1):
col_type = col_type1
new_column = DataFrameType.Column(
index=column.index,
name=column.name,
type=col_type,
)
new_columns.append(new_column)
if frame2 is not None:
for column in frame2.columns:
if column.name in in_frame1:
continue
new_columns.append(
DataFrameType.Column(
index=len(new_columns),
name=column.name,
type=ColumnType(type=UnknownType()),
)
)
signature = Function(
args=[
Function.Argument(
pos=0,
name="other",
type=DataFrameType(columns=[]),
required=True,
),
],
returns=DataFrameType(columns=new_columns),
)
return (
self.typer._get_call_result(
location=call.location,
callee=signature,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
)
@frame_method()
def mean(self, call: Call) -> Type:
with_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("int"),
required=False,
)
],
returns=ColumnType(type=TopType()),
)
without_axis = Function(
kw_args=[
Function.Argument(
pos=0,
name="axis",
type=self.types.get_type("None"),
required=True,
)
],
returns=TopType(),
)
overload = OverloadedFunction(
overloads=[
with_axis,
without_axis,
]
)
return (
self.typer._get_call_result(
location=call.location,
callee=overload,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
)

154
midas/checker/frames.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
return all(isinstance(expr, p.LiteralExpr) for expr in exprs)
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
def assign(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
value_type: Type,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
return self.assign_column(reporter, location, frame, name, value_type)
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(idx, str) for idx in indices
):
raise NotImplementedError
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
def assign_column(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
name: str,
type: Type,
) -> Type:
if not isinstance(type, ColumnType):
reporter.error(
location,
f"Cannot assign {type} to dataframe column. Must be a ColumnType",
)
return frame
return self._set_column(frame, name, type)
def get(
self,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
index: p.Expr,
) -> Type:
match index:
case p.LiteralExpr(value=str() as name):
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
return column
case p.ListExpr(items=indices) if is_list_of_literals(indices) and all(
isinstance(index.value, str) for index in indices
):
names: list[str] = [cast(str, index.value) for index in indices]
columns: list[ColumnType] = []
for name in names:
column: Optional[ColumnType] = FrameManager._get_column(frame, name)
if column is None:
reporter.error(location, f"Unknown column '{name}' on {frame}")
return UnknownType()
columns.append(column)
return TupleType(items=tuple(columns))
case _:
reporter.error(location, f"Invalid index type {index} on {frame}")
return UnknownType()
@classmethod
def _set_column(
cls, frame: DataFrameType, name: str, column: ColumnType
) -> DataFrameType:
new_columns: list[DataFrameType.Column] = []
index: int = len(frame.columns)
replace: bool = False
for i, col in enumerate(frame.columns):
if col.name == name:
index = i
replace = True
# TODO: check column type here to prevent changing it
new_columns.append(col)
new_col: DataFrameType.Column = DataFrameType.Column(
index=index,
name=name,
type=column,
)
if replace:
new_columns[index] = new_col
else:
new_columns.append(new_col)
return DataFrameType(columns=new_columns)
@classmethod
def _set_columns(
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
) -> DataFrameType:
for name, col in zip(names, columns):
frame = cls._set_column(frame, name, col)
return frame
@classmethod
def _get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
for col in frame.columns:
if col.name == name:
return col.type
return None
@classmethod
def _get_columns(
cls, frame: DataFrameType, names: list[str]
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call(
self,
method: str,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
frame=frame,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -14,8 +14,10 @@ from midas.checker.reporter import FileReporter, Reporter
from midas.checker.types import (
AliasType,
AppliedType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
@@ -401,6 +403,18 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
kw=[process_arg(arg, i + n_pos + n_mixed) for i, arg in enumerate(spec.kw)],
)
def visit_frame_type(self, type: m.FrameType) -> Type:
def process_column(i: int, col: m.FrameType.Column) -> DataFrameType.Column:
return DataFrameType.Column(
index=i,
name=col.name.lexeme,
type=ColumnType(type=col.type.accept(self)),
)
return DataFrameType(
columns=[process_column(i, col) for i, col in enumerate(type.columns)]
)
def _resolve_type_params(self, params: list[m.TypeParam]):
vars: list[TypeVar] = []
for param in params:

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Optional
from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry
@@ -16,16 +17,18 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None:
super().__init__()
self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {}
self._def_type_constructor("object")
self._def_type_constructor("float")
self._def_type_constructor("int")
self._def_type_constructor("bool")
self._def_type_constructor("str")
self._def_type_constructor("object", object)
self._def_type_constructor("float", float)
self._def_type_constructor("int", int)
self._def_type_constructor("bool", bool)
self._def_type_constructor("str", str)
self._def_function(
name="list",
pos=[Param("object", TopType())],
returns=self._list_of(TopType()),
py_function=list,
)
# TODO: use sink
@@ -33,6 +36,7 @@ class Preamble(Environment):
name="print",
pos=[Param("object", TopType())],
returns=UnitType(),
py_function=print,
)
map_in = TypeVar(name="T", bound=None)
@@ -52,17 +56,25 @@ class Preamble(Environment):
),
],
returns=self._list_of(map_out), # TODO: replace with Iterable[U]
type_vars=[map_in, map_out],
py_function=map,
)
self._def_function(
name="input",
pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"),
)
def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str):
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
# TODO: more specific arg types
self._def_function(
name=name,
pos=[Param("object", TopType(), required=False)],
returns=self._types.get_type(name),
py_function=py_function,
)
def _make_function(
@@ -109,6 +121,7 @@ class Preamble(Environment):
kw: list[Param] = [],
returns: Type = UnitType(),
type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None,
):
function: Type = self._make_function(
name=name,
@@ -119,3 +132,8 @@ class Preamble(Environment):
type_vars=type_vars,
)
self.define(name, function)
if py_function is not None:
self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]:
return self._python_funcs.get(name)

View File

@@ -1,11 +1,14 @@
import ast
import logging
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.frames import FrameManager
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,
@@ -18,13 +21,22 @@ from midas.checker.resolver import Resolver
from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ConstraintType,
DataFrameType,
Function,
GenericType,
OverloadedFunction,
TupleType,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
unfold_type,
)
from midas.checker.unifier import Unifier
from midas.parser.python import PythonParser
from midas.utils import TypedAST
@@ -35,6 +47,10 @@ class ReturnException(Exception):
pass
class UndefinedMethodException(Exception):
pass
@dataclass(frozen=True, kw_only=True)
class MappedArgument:
expr: p.Expr
@@ -63,10 +79,12 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self)
self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
self.evaluated_casts: list[p.CastExpr] = []
def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path)
@@ -80,10 +98,15 @@ class PythonTyper(
self.env = self.global_env
self.locals = resolver.locals
self.judgements = []
self.evaluated_casts = []
self.check(stmts)
return TypedAST(stmts=stmts, judgements=self.judgements)
return TypedAST(
stmts=stmts,
judgements=self.judgements,
evaluated_casts=self.evaluated_casts,
)
def judge(self, expr: p.Expr, type: Type):
"""Record a typing judgement
@@ -176,6 +199,36 @@ class PythonTyper(
return self.env.get_at(distance, name)
return self.global_env.get(name)
def call_method(
self,
location: Location,
obj: Type,
method_name: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Optional[Type]:
unfolded: Type = unfold_type(obj)
match unfolded:
case DataFrameType():
return self.frame_mgr.call(
method=method_name,
location=location,
frame=unfolded,
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj, method_name)
if method is None:
raise UndefinedMethodException
return self._get_call_result(
location,
method,
positional,
keywords,
)
def is_subtype(self, type1: Type, type2: Type) -> bool:
return self.types.is_subtype(type1, type2)
@@ -227,7 +280,8 @@ class PythonTyper(
)
pos += 1
for arg in pos_args + args + kw_args:
all_args: list[Function.Argument] = pos_args + args + kw_args
for arg in all_args:
env.define(arg.name, arg.type)
returns_hint: Optional[Type] = None
@@ -268,12 +322,25 @@ class PythonTyper(
returns = inferred_return
# TODO: handle *args and **kwargs sinks
function: Function = Function(
function: Type = Function(
pos_args=pos_args,
args=args,
kw_args=kw_args,
returns=returns,
)
generic_params: list[TypeVar] = []
all_types: list[Type] = [arg.type for arg in all_args] + [returns]
for type in all_types:
if isinstance(type, TypeVar):
if type not in generic_params:
generic_params.append(type)
if len(generic_params) != 0:
function = GenericType(
name=stmt.name,
params=generic_params,
body=function,
)
self.env.define(stmt.name, function)
def visit_type_assign(self, stmt: p.TypeAssign) -> None:
@@ -291,9 +358,15 @@ class PythonTyper(
case p.VariableExpr():
self._assign_var(location, target, value_type)
# Allow any kind of object because we disallow creating new attributes
case p.GetExpr(object=object, name=name):
self._assign_attr(location, object, name, value_type)
# Only support variable expressions because modifying
# the underlying value would require reference types
case p.SubscriptExpr(object=p.VariableExpr() as var, index=index):
self._assign_sub(location, var, index, value_type)
case _:
if not isinstance(target, p.VariableExpr):
self.logger.warning(f"Unsupported assignment to {target}")
@@ -332,6 +405,27 @@ class PythonTyper(
f"Cannot assign {value_type} to member '{object_type}.{name}' of type {member}",
)
def _assign_sub(
self,
location: Location,
var: p.VariableExpr,
index: p.Expr,
value_type: Type,
):
var_type: Type = self.type_of(var)
# TODO: what happens if type is an alias of a dataframe type
match var_type:
case DataFrameType() as frame:
new_type: Type = self.frame_mgr.assign(
self.reporter, location, frame, index, value_type
)
self.env.assign(var.name, new_type)
case _:
self.reporter.error(
location,
f"Cannot assign {value_type} to index {index} of {var_type}",
)
def visit_return_stmt(self, stmt: p.ReturnStmt) -> None:
type: Type = self.type_of(stmt.value) if stmt.value is not None else UnitType()
self.env.return_types.append(type)
@@ -408,20 +502,16 @@ class PythonTyper(
left: Type = self.type_of(left_expr)
right: Type = self.type_of(right_expr)
operation: Optional[Type] = self.types.lookup_member(left, method)
if operation is None:
result: Optional[Type]
try:
result = self.call_method(location, left, method, [(right_expr, right)], {})
except UndefinedMethodException:
self.reporter.error(
location,
f"Undefined operation {method} between {left} and {right}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
location,
operation,
[(right_expr, right)],
{},
)
return result or UnknownType()
def visit_unary_expr(self, expr: p.UnaryExpr) -> Type:
@@ -434,30 +524,45 @@ class PythonTyper(
return UnknownType()
operand: Type = self.type_of(expr.right)
operation: Optional[Type] = self.types.lookup_member(operand, method)
if operation is None:
result: Optional[Type]
try:
result = self.call_method(expr.location, operand, method, [], {})
except UndefinedMethodException:
self.reporter.error(
expr.location,
f"Undefined operation {method} for {operand}",
)
return UnknownType()
result: Optional[Type] = self._get_call_result(
expr.location,
operation,
[],
{},
)
return result or UnknownType()
def visit_call_expr(self, expr: p.CallExpr) -> Type:
callee: Type = self.type_of(expr.callee)
match expr.callee:
case p.VariableExpr(name="TypeVar"):
return self.define_typevar(expr) or UnknownType()
positional: list[TypedExpr] = [
(arg, self.type_of(arg)) for arg in expr.arguments
]
keywords: dict[str, TypedExpr] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
match expr.callee:
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
positional,
keywords,
)
callee: Type = self.type_of(expr.callee)
return (
self._get_call_result(
location=expr.location,
@@ -489,6 +594,8 @@ class PythonTyper(
return self.types.get_type("float")
case str():
return self.types.get_type("str")
case None:
return self.types.get_type("None")
case _:
self.reporter.warning(expr.location, f"Unknown literal {expr}")
return UnknownType()
@@ -516,7 +623,16 @@ class PythonTyper(
return UnknownType()
def visit_cast_expr(self, expr: p.CastExpr) -> Type:
return self.resolve_type_expr(expr.type)
subject_type: Type = self.type_of(expr.expr)
target_type: Type = self.resolve_type_expr(expr.type)
is_lit, lit_value = self._get_literal(expr.expr)
if is_lit:
evaluated: bool = self._evaluate_cast_statically(
expr, subject_type, target_type, lit_value
)
if evaluated:
self.evaluated_casts.append(expr)
return target_type
def visit_ternary_expr(self, expr: p.TernaryExpr) -> Type:
test_type: Type = self.type_of(expr.test)
@@ -551,9 +667,9 @@ class PythonTyper(
if len(item_types) == 1:
item_type: Type = item_types[0]
return self.types.apply_generic(list_type, [item_type])
self.reporter.error(
self.reporter.warning(
expr.location,
f"Heterogeneous list items: {item_types}",
f"Heterogeneous list items: [{', '.join(map(str, item_types))}]",
)
return self.types.apply_generic(list_type, [UnknownType()])
@@ -583,22 +699,29 @@ class PythonTyper(
if len(key_types) == 1:
key_type = key_types[0]
else:
self.reporter.error(
self.reporter.warning(
expr.location,
f"Heterogeneous dict keys: {key_types}",
f"Heterogeneous dict keys: [{', '.join(map(str, key_types))}]",
)
if len(value_types) == 1:
value_type = value_types[0]
else:
self.reporter.error(
self.reporter.warning(
expr.location,
f"Heterogeneous dict values: {value_types}",
f"Heterogeneous dict values: [{', '.join(map(str, value_types))}]",
)
return self.types.apply_generic(dict_type, [key_type, value_type])
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> Type:
object: Type = self.type_of(expr.object)
unfolded: Type = unfold_type(object)
match unfolded:
case TupleType():
return self._visit_tuple_subscript(unfolded, expr)
case DataFrameType():
return self._visit_frame_subscript(unfolded, expr)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None:
self.reporter.error(
@@ -636,13 +759,26 @@ class PythonTyper(
self.reporter.warning(node.location, "ConstraintType not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> Type:
self.reporter.warning(node.location, "FrameColumn not yet supported")
return UnknownType()
def visit_frame_column(self, node: p.FrameColumn) -> ColumnType:
return ColumnType(
type=(
self.resolve_type_expr(node.type)
if node.type is not None
else UnknownType()
)
)
def visit_frame_type(self, node: p.FrameType) -> Type:
self.reporter.warning(node.location, "FrameType not yet supported")
return UnknownType()
return DataFrameType(
columns=[
DataFrameType.Column(
index=i,
name=column.name,
type=self.visit_frame_column(column),
)
for i, column in enumerate(node.columns)
]
)
def _get_call_result(
self,
@@ -704,6 +840,28 @@ class PythonTyper(
location, base, positional, keywords, report_errors
)
case GenericType():
unifier: Unifier = Unifier(self.types)
pos: list[Type] = [a[1] for a in positional]
kw: dict[str, Type] = {k: v[1] for k, v in keywords.items()}
unified: Optional[Type] = unifier.unify_call(callee, pos, kw)
if unified is None:
if report_errors:
pos_str: str = ", ".join(str(t) for t in pos)
kw_str: str = ", ".join(f"{k}: {v}" for k, v in kw.items())
self.reporter.error(
location,
f"Could not unify {callee}={callee.body} with pos=[{pos_str}] and kw={{{kw_str}}}",
)
return None
return self._get_call_result(
location,
unified,
positional,
keywords,
report_errors,
)
case _:
if report_errors:
self.reporter.error(
@@ -1009,3 +1167,173 @@ class PythonTyper(
report_errors=False,
)
return result
def define_typevar(self, call: p.CallExpr) -> Optional[TypeVar]:
def is_kw_true(name: str) -> bool:
match call.keywords.get(name):
case p.LiteralExpr(value=True):
return True
case _:
return False
match call:
case p.CallExpr(
arguments=[p.LiteralExpr(value=str() as name)],
):
bound: Optional[Type] = None
variance: Variance = Variance.INVARIANT
if "bound" in call.keywords:
bound_type: p.MidasType = self._parse_type_from_expr(
call.keywords["bound"]
)
bound = self.resolve_type_expr(bound_type)
if is_kw_true("covariant"):
variance = Variance.COVARIANT
if is_kw_true("contravariant"):
if variance == Variance.COVARIANT:
self.reporter.warning(
call.keywords["contravariant"].location,
"TypeVar cannot be covariant and contravariant at the same time. Marked as invariant",
)
variance = Variance.INVARIANT
else:
variance = Variance.CONTRAVARIANT
var: TypeVar = TypeVar(name=name, bound=bound, variance=variance)
self.types.define_type(name, var)
return var
case _:
self.reporter.warning(
call.location, "Invalid usage of 'TypeVar', skipping"
)
return None
def _parse_type_from_expr(self, expr: p.Expr) -> p.MidasType:
location: Location = expr.location
parser = PythonParser()
match expr:
case p.LiteralExpr(value=str() as value):
node: ast.Expression = ast.parse(value, mode="eval")
return parser._parse_type(node.body)
case p.VariableExpr(name=name):
return p.BaseType(location=location, base=name, param=None)
case _:
raise NotImplementedError
def _get_literal(self, expr: p.Expr) -> tuple[bool, Any]:
match expr:
case p.LiteralExpr(value=value):
return True, value
case p.ListExpr(items=items):
values: list[Any] = []
for item in items:
is_lit, value = self._get_literal(item)
if not is_lit:
return False, None
values.append(value)
return True, values
case p.DictExpr(keys=keys, values=values):
pairs: list[tuple[Any, Any]] = []
for key, value in zip(keys, values):
key_val = None
if key is not None:
is_lit, key_val = self._get_literal(key)
if not is_lit:
return False, None
is_lit, value_val = self._get_literal(value)
if not is_lit:
return False, None
if key is None:
# TODO: check that value is always a dict
assert isinstance(value_val, dict)
pairs.extend(value_val.items())
else:
pairs.append((key_val, value_val))
return True, dict(pairs)
case _:
return False, None
def _evaluate_cast_statically(
self, expr: p.CastExpr, subject_type: Type, target_type: Type, lit_value: Any
) -> bool:
match target_type:
case AliasType(type=base):
return self._evaluate_cast_statically(
expr, subject_type, base, lit_value
)
case AppliedType(body=body):
return self._evaluate_cast_statically(
expr, subject_type, body, lit_value
)
case ConstraintType(type=base, constraint=constraint):
evaluated: bool = True
if not self._evaluate_cast_statically(
expr, subject_type, base, lit_value
):
evaluated = False
evaluator = Evaluator(self.types)
evaluator.set_value("_", lit_value)
res = evaluator.evaluate(constraint)
if not res:
printer = MidasPrinter()
constraint_str: str = printer.print(constraint)
self.reporter.error(
expr.location,
f"Value {lit_value!r} does not fit constraint '{constraint_str}'",
)
evaluated = False
return evaluated
case BaseType():
# TODO: do we want to allow cast(float, int)? would require runtime conversion
if not self.types.is_subtype(
subject_type, target_type
) or not self.types.is_subtype(target_type, subject_type):
self.reporter.error(
expr.location,
f"Value {lit_value!r} of type {subject_type} cannot be cast as {target_type}",
)
return False
return True
case DataFrameType() | ColumnType():
self.reporter.error(
expr.location, f"Cannot cast {lit_value!r} to {target_type}"
)
return False
case _:
self.reporter.info(
expr.location, f"Cannot evaluate cast to {target_type} statically"
)
return False
def _visit_tuple_subscript(self, tup: TupleType, expr: p.SubscriptExpr) -> Type:
match expr.index:
case p.LiteralExpr(value=int() as index):
if index < 0 or index >= len(tup.items):
self.reporter.error(
expr.location, f"Index {index} out of range for tuple {tup}"
)
return UnknownType()
return tup.items[index]
case _:
self.reporter.error(
expr.location, f"Invalid index type {expr.index} on {tup}"
)
return UnknownType()
def _visit_frame_subscript(
self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)

View File

@@ -8,8 +8,10 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
@@ -130,6 +132,19 @@ class TypesRegistry:
case (_, TopType()):
return True
case (_, UnknownType()):
return True
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (_, TypeVar(bound=bound)):
if bound is None:
return True
return self.is_subtype(type1, bound)
case (AliasType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -144,14 +159,27 @@ class TypesRegistry:
return False
return True
case (DataFrameType(columns=columns1), DataFrameType(columns=columns2)):
# TODO: check order?
by_name1: dict[str, DataFrameType.Column] = {
col.name: col for col in columns1 if col.name is not None
}
for col2 in columns2:
if col2.name not in by_name1:
return False
if not self.is_subtype(by_name1[col2.name].type, col2.type):
return False
return True
case (ColumnType(type=inner1), ColumnType(type=inner2)):
# TODO: invariant, replace ColumnType with simple GenericType
if not self.are_equivalent(inner1, inner2):
return False
return True
case (Function(), Function()):
return self.is_func_subtype(type1, type2)
case (TypeVar(bound=bound), _):
if bound is None:
return False
return self.is_subtype(bound, type2)
case (ConstraintType(type=base1), _):
return self.is_subtype(base1, type2)
@@ -173,8 +201,15 @@ class TypesRegistry:
return False
return True
# TODO: verify legitimacy
case (AppliedType(body=body), _):
return self.is_subtype(body, type2)
return False
def are_equivalent(self, type1: Type, type2: Type) -> bool:
return self.is_subtype(type1, type2) and self.is_subtype(type2, type1)
# TODO: verify the logic in here
def is_func_subtype(self, func1: Function, func2: Function) -> bool:
"""Check whether a function is a subtype of another
@@ -389,6 +424,12 @@ class TypesRegistry:
)
return self.lookup_member(base, member_name)
case ConstraintType(type=base):
return self.lookup_member(base, member_name)
case TypeVar(bound=bound) if bound is not None:
return self.lookup_member(bound, member_name)
case UnknownType():
return UnknownType()

View File

@@ -61,3 +61,10 @@ class FileReporter:
location=location,
message=message,
)
def debug(self, location: Location, message: str):
self.report(
type=DiagnosticType.DEBUG,
location=location,
message=message,
)

View File

@@ -128,6 +128,10 @@ class Resolver(p.Stmt.Visitor[None], p.Expr.Visitor[None]):
case p.GetExpr():
target.accept(self)
case p.SubscriptExpr():
target.accept(self)
case _:
raise Exception(f"Unsupported assignment to {target}")

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Optional, assert_never
from typing import Optional, assert_never, cast
import midas.ast.midas as m
from midas.ast.printer import MidasPrinter
@@ -156,6 +156,37 @@ class ConstraintType:
return f"{self.type} where {printer.print(self.constraint)}"
@dataclass(frozen=True, kw_only=True)
class TupleType:
items: tuple[Type, ...]
def __str__(self) -> str:
return f"({', '.join(map(str, self.items))})"
@dataclass(frozen=True, kw_only=True)
class ColumnType:
type: Type
def __str__(self) -> str:
return f"Column[{self.type}]"
@dataclass(frozen=True, kw_only=True)
class DataFrameType:
columns: list[Column]
def __str__(self) -> str:
schema: list[str] = [f"{col.name}: {col.type}" for col in self.columns]
return f"Frame[{', '.join(schema)}]"
@dataclass(frozen=True, kw_only=True)
class Column:
index: int
name: Optional[str]
type: ColumnType
def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
def sub_argument(arg: Function.Argument):
return Function.Argument(
@@ -165,6 +196,13 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
required=arg.required,
)
def sub_column(col: DataFrameType.Column):
return DataFrameType.Column(
index=col.index,
name=col.name,
type=cast(ColumnType, substitute_typevars(col.type, substitutions)),
)
match type:
case TopType():
return type
@@ -250,10 +288,26 @@ def substitute_typevars(type: Type, substitutions: dict[str, Type]) -> Type:
body=substitute_typevars(body, substitutions),
)
case TupleType(items=items):
return TupleType(
items=tuple(substitute_typevars(item, substitutions) for item in items),
)
case ColumnType(type=items_type):
return ColumnType(
type=substitute_typevars(items_type, substitutions),
)
case DataFrameType(columns=columns):
return DataFrameType(
columns=list(map(sub_column, columns)),
)
case UnknownType() | UnitType():
return type
case TopType() | GenericType():
raise NotImplementedError(f"Unsupported type {type}")
# Ensure exhaustiveness
@@ -317,6 +371,15 @@ def to_annotation(type: Type) -> str:
case ConstraintType():
return str(type)
case TupleType(items=items):
return f"Tuple[{', '.join(map(to_annotation, items))}]"
case ColumnType():
return "pd.Series"
case DataFrameType():
return "pd.DataFrame"
case _:
assert_never(type)
@@ -342,4 +405,7 @@ Type = (
| GenericType
| AppliedType
| ConstraintType
| TupleType
| ColumnType
| DataFrameType
)

169
midas/checker/unifier.py Normal file
View File

@@ -0,0 +1,169 @@
import logging
from typing import Optional
from midas.checker.registry import TypesRegistry
from midas.checker.types import (
AppliedType,
Function,
GenericType,
TopType,
Type,
TypeVar,
)
class UnificationError(Exception): ...
class Unifier:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
self.logger: logging.Logger = logging.getLogger("Unifier")
def unify_call(
self,
type: GenericType,
positional: list[Type],
keywords: dict[str, Type],
) -> Optional[Type]:
concrete_func: Function = Function(
pos_args=[
Function.Argument(
pos=i,
name=str(i),
type=arg,
required=True,
)
for i, arg in enumerate(positional)
],
args=[],
kw_args=[
Function.Argument(
pos=len(positional) + i,
name=name,
type=arg,
required=True,
)
for i, (name, arg) in enumerate(keywords.items())
],
returns=TopType(), # TODO: use expected type
)
return self.unify_generic(type, concrete_func, match_return=False)
def unify_generic(
self,
template: GenericType,
concrete: Type,
match_return: bool = True,
) -> Optional[Type]:
substitutions: dict[str, Type]
try:
substitutions = self.match(template.body, concrete, match_return)
except UnificationError:
return None
args: list[Type] = []
for param in template.params:
if param.name not in substitutions:
return None
args.append(substitutions[param.name])
applied: Type = self.types.apply_generic(template, args)
return applied
def match(
self,
template: Type,
concrete: Type,
match_return: bool = True,
) -> dict[str, Type]:
# TODO: if concrete is Generic, record bound TypeVar. Then when merging
# substitutions, check that the constraint is respected
match (template, concrete):
case (TypeVar(name=name), _):
return {name: concrete}
case (
AppliedType(name=template_name, args=template_args),
AppliedType(name=concrete_name, args=concrete_args),
) if template_name == concrete_name and len(template_args) == len(
concrete_args
):
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in zip(template_args, concrete_args):
new_substistutions: dict[str, Type] = self.match(
template_arg, concrete_arg
)
substitutions = self.merge(substitutions, new_substistutions)
return substitutions
case (Function(), Function()):
mapped: list[tuple[Function.Argument, Function.Argument]] = (
self.map_params(template, concrete)
)
substitutions: dict[str, Type] = {}
for template_arg, concrete_arg in mapped:
arg_subs: dict[str, Type] = self.match(
template_arg.type, concrete_arg.type
)
substitutions = self.merge(substitutions, arg_subs)
if match_return:
return_subs: dict[str, Type] = self.match(
template.returns, concrete.returns
)
substitutions = self.merge(substitutions, return_subs)
return substitutions
case _:
self.logger.debug(f"Can't match {concrete!r} with {template!r}")
return {}
def merge(self, subs1: dict[str, Type], subs2: dict[str, Type]) -> dict[str, Type]:
merged: dict[str, Type] = subs1.copy()
for k, v in subs2.items():
if k in merged and merged[k] != v:
self.logger.debug(
f"Substitution already defined for {k} with type {merged[k]}, got {v}"
)
raise UnificationError
merged[k] = v
return merged
def map_params(
self, func1: Function, func2: Function
) -> list[tuple[Function.Argument, Function.Argument]]:
pos1: list[Function.Argument] = func1.pos_args
mixed1: list[Function.Argument] = func1.args
kw1: list[Function.Argument] = func1.kw_args
pos2: list[Function.Argument] = func2.pos_args
mixed2: list[Function.Argument] = func2.args
kw2: list[Function.Argument] = func2.kw_args
mapped: list[tuple[Function.Argument, Function.Argument]] = []
by_pos2: dict[int, Function.Argument] = {arg.pos: arg for arg in pos2 + mixed2}
by_name2: dict[str, Function.Argument] = {arg.name: arg for arg in mixed2 + kw2}
for arg1 in pos1:
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
for arg1 in mixed1:
# Match both positionally and by name, conflicts are caught
# when merging substitutions
if (arg2 := by_pos2.get(arg1.pos)) is not None:
mapped.append((arg1, arg2))
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
for arg1 in kw1:
if (arg2 := by_name2.get(arg1.name)) is not None:
mapped.append((arg1, arg2))
return mapped

View File

@@ -19,9 +19,11 @@ from midas.utils import TypedAST
@click.command(help="Compile source")
@click.argument("file", type=click.File("r"))
@click.option("-t", "--types", type=click.File("r"), multiple=True)
@click.option("--ignore-errors", is_flag=True)
def compile(
file: TextIO,
types: tuple[TextIO],
ignore_errors: bool,
):
source: str = file.read()
source_path: Path = Path(file.name).resolve()
@@ -35,7 +37,9 @@ def compile(
printer = DiagnosticPrinter()
printer.print_all(diagnostics)
if any(map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)):
if not ignore_errors and any(
map(lambda d: d.type == DiagnosticType.ERROR, diagnostics)
):
sys.exit(1)
generator = Generator(workdir=source_path.parent, types=checker.types)

View File

@@ -1,27 +1,64 @@
import ast
import time
from pathlib import Path
from typing import TextIO
import black
import click
from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
from midas.checker.checker import TypeChecker
from midas.generator.stubs import StubsGenerator
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
def stubs(
file: TextIO,
output: TextIO,
):
source_path: Path = Path(file.name).resolve()
def generate_stubs(in_path: Path, out_path: Path):
checker = TypeChecker()
checker.import_midas(source_path)
checker.import_midas(in_path)
generator = StubsGenerator(checker.types)
module: ast.Module = generator.generate_stubs()
module = ast.fix_missing_locations(module)
output.write(ast.unparse(module))
output: str = ast.unparse(module)
output = black.format_str(output, mode=black.Mode(is_pyi=True))
out_path.write_text(output)
class Handler(FileSystemEventHandler):
def __init__(self, in_path: Path, out_path: Path) -> None:
super().__init__()
self.in_path: Path = in_path
self.out_path: Path = out_path
def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
generate_stubs(self.in_path, self.out_path)
@click.command(help="Generate stubs from Midas definitions")
@click.argument("file", type=click.File("r"))
@click.option("-o", "--output", type=click.File("w"), default="-")
@click.option("-w", "--watch", is_flag=True)
def stubs(
file: TextIO,
output: TextIO,
watch: bool,
):
source_path: Path = Path(file.name).resolve()
out_path: Path = Path(output.name).resolve()
generate_stubs(source_path, out_path)
if watch:
print(f"Watching {source_path}...")
print("Press CTRL+C to stop")
handler = Handler(source_path, out_path)
observer = Observer()
observer.schedule(handler, str(source_path))
observer.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()

View File

@@ -41,6 +41,7 @@ def types(
message=f"Type: {type}",
)
)
diagnostics.extend(checker.diagnostics)
printer = DiagnosticPrinter()
printer.print_all(diagnostics)

View File

@@ -228,6 +228,13 @@ class PythonHighlighter(
for item in expr.items:
item.accept(self)
def visit_dict_expr(self, expr: p.DictExpr) -> None:
for key in expr.keys:
if key is not None:
key.accept(self)
for value in expr.values:
value.accept(self)
def visit_subscript_expr(self, expr: p.SubscriptExpr) -> None:
expr.object.accept(self)
expr.index.accept(self)
@@ -240,6 +247,10 @@ class PythonHighlighter(
if expr.step is not None:
expr.step.accept(self)
def visit_raw_expr(self, expr: p.RawExpr) -> None: ...
def visit_raw_stmt(self, stmt: p.RawStmt) -> None: ...
class MidasHighlighter(
Highlighter, m.Stmt.Visitor[None], m.Expr.Visitor[None], m.Type.Visitor[None]
@@ -266,8 +277,9 @@ class MidasHighlighter(
def visit_predicate_stmt(self, stmt: m.PredicateStmt) -> None:
self.wrap(stmt, "predicate")
self.wrap(LocatableToken(stmt.name), "predicate-name")
stmt.type.accept(self)
stmt.condition.accept(self)
for spec in stmt.params:
self._visit_param_spec(spec)
stmt.body.accept(self)
def visit_logical_expr(self, expr: m.LogicalExpr) -> None:
self.wrap(expr, "logical-expr")
@@ -283,6 +295,14 @@ class MidasHighlighter(
self.wrap(expr, "unary-expr")
expr.right.accept(self)
def visit_call_expr(self, expr: m.CallExpr) -> None:
self.wrap(expr, "call-expr")
expr.callee.accept(self)
for arg in expr.arguments:
arg.accept(self)
for arg in expr.keywords.values():
arg.accept(self)
def visit_get_expr(self, expr: m.GetExpr) -> None:
self.wrap(expr, "get-expr")
expr.expr.accept(self)
@@ -318,8 +338,7 @@ class MidasHighlighter(
def visit_function_type(self, type: m.FunctionType) -> None:
self.wrap(type, "function")
for arg in type.pos_args + type.args + type.kw_args:
arg.type.accept(self)
self._visit_param_spec(type.params)
type.returns.accept(self)
def visit_extension_type(self, type: m.ExtensionType) -> None:
@@ -327,6 +346,18 @@ class MidasHighlighter(
type.base.accept(self)
type.extension.accept(self)
def _visit_param_spec(self, spec: m.ParamSpec) -> None:
for param in spec.pos + spec.mixed + spec.kw:
param.type.accept(self)
def visit_frame_type(self, type: m.FrameType) -> None:
self.wrap(type, "frame")
for column in type.columns:
self._visit_frame_column(column)
def _visit_frame_column(self, column: m.FrameType.Column) -> None:
self.wrap(column, "column")
class DiagnosticsHighlighter(Highlighter):
EXTRA_CSS_PATH: Optional[Path] = Path(__file__).parent / "hl_diagnostic.css"

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
from pathlib import Path
from typing import Optional
@@ -7,6 +8,13 @@ from midas.cli.ansi import Ansi
class DiagnosticPrinter:
COLORS: dict[DiagnosticType, int] = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
DiagnosticType.DEBUG: Ansi.MAGENTA,
}
def __init__(self) -> None:
self.files: dict[Optional[str], list[str]] = {}
@@ -22,10 +30,25 @@ class DiagnosticPrinter:
return self.files[filename]
def print_all(self, diagnostics: list[Diagnostic], indent: int = 4):
by_type: dict[DiagnosticType, int] = defaultdict(int)
for diagnostic in diagnostics:
filename: Optional[str] = diagnostic.file_path
lines = self.get_lines(filename)
self.print(lines, diagnostic, indent=indent)
by_type[diagnostic.type] += 1
if len(diagnostics) == 0:
return
counts: list[str] = []
for type in DiagnosticType:
if type not in by_type:
continue
count: int = by_type[type]
color: int = self.COLORS.get(type, Ansi.WHITE)
counts.append(f"{Ansi.FG(color)}{type.value}s{Ansi.RESET}: {count}")
print(" ".join(counts))
def print(self, lines: list[str], diagnostic: Diagnostic, indent: int = 4):
"""Pretty-print a diagnostic, showing some context if possible
@@ -45,7 +68,7 @@ class DiagnosticPrinter:
loc: Location = diagnostic.location
if loc.lineno != loc.end_lineno:
print(diagnostic)
self.print_multiline(lines, diagnostic, indent)
return
start_offset: int = loc.col_offset
@@ -55,11 +78,7 @@ class DiagnosticPrinter:
before: str = line[:start_offset]
after: str = line[end_offset:]
color: int = {
DiagnosticType.ERROR: Ansi.RED,
DiagnosticType.WARNING: Ansi.YELLOW,
DiagnosticType.INFO: Ansi.CYAN,
}.get(diagnostic.type, Ansi.WHITE)
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
subject: str = Ansi.FG(color) + line[start_offset:end_offset] + Ansi.RESET
cursor: str = (
@@ -76,3 +95,27 @@ class DiagnosticPrinter:
print(indent_str + before + subject + after)
print(indent_str + cursor)
print()
def print_multiline(
self, all_lines: list[str], diagnostic: Diagnostic, indent: int = 4
):
loc: Location = diagnostic.location
lines: list[str] = all_lines[loc.lineno - 1 : loc.end_lineno]
start_offset: int = loc.col_offset
end_offset: int = loc.end_col_offset or (start_offset + 1)
indent_str: str = " " * indent
color: int = self.COLORS.get(diagnostic.type, Ansi.WHITE)
res: str = indent_str + lines[0][:start_offset]
res += Ansi.FG(color) + lines[0][start_offset:]
for line in lines[1:-1]:
res += "\n" + indent_str + line
res += "\n" + indent_str + lines[-1][:end_offset]
res += Ansi.RESET + lines[-1][end_offset:]
print(diagnostic.location_str + ":")
print(res)
print()
print(Ansi.FG(color) + diagnostic.message + Ansi.RESET)
print()

View File

@@ -1,4 +1,5 @@
import ast
import logging
import shutil
from dataclasses import dataclass, field
from pathlib import Path
@@ -13,13 +14,16 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
@@ -36,14 +40,19 @@ class Scope:
class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
IS_DATAFRAME_FUNC = "__midas_is_dataframe__"
IS_COLUMN_FUNC = "__midas_is_column__"
def __init__(self, workdir: Path, types: TypesRegistry) -> None:
self.workdir: Path = workdir.resolve()
self.build_dir: Path = self.workdir / "build" / "midas"
self.rel_src_path: Path = Path()
self.logger: logging.Logger = logging.getLogger("Generator")
self._typed_ast: TypedAST = TypedAST(
stmts=[],
judgements=[],
evaluated_casts=[],
)
self._alias_count: int = 0
self._predicate_count: int = 0
@@ -52,12 +61,24 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._constraint_generator: ConstraintGenerator = ConstraintGenerator(types)
self._constraints: list[tuple[m.Expr, ast.expr]] = []
self.define_is_dataframe: bool = False
self.define_is_column: bool = False
def generate_ast(self, typed_ast: TypedAST, src_path: Path) -> ast.AST:
self.rel_src_path = src_path.resolve().relative_to(self.workdir)
self._typed_ast = typed_ast
body: list[ast.stmt] = self._visit_body(typed_ast.stmts)
predicates: list[ast.stmt] = self._constraint_generator.get_definitions()
module = ast.Module(body=predicates + body, type_ignores=[])
body = predicates + body
if self.define_is_dataframe:
body = [self._is_dataframe_definition()] + body
if self.define_is_column:
body = [self._is_column_definition()] + body
module = ast.Module(body=body, type_ignores=[])
module = ast.fix_missing_locations(module)
return module
@@ -131,6 +152,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def visit_cast_expr(self, expr: p.CastExpr) -> ast.expr:
expr2: ast.expr = expr.expr.accept(self)
if expr in self._typed_ast.evaluated_casts or expr.unsafe:
return expr2
alias: ast.expr = self._make_alias(expr2)
type: Type = self._get_expr_type(expr)
@@ -322,8 +347,68 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
self._make_cast_asserts(src_location, expr, base)
self._make_constraint_assert(src_location, expr, constraint)
case TypeVar():
raise RuntimeError("Unexpected TypeVar")
case TypeVar(bound=bound):
# TODO: check with type from arguments / use call-site context
if bound is not None:
self._make_cast_asserts(src_location, expr, bound)
case TupleType(items=items):
self._add_assert(
ast.Call(
func=ast.Name(id="isinstance"),
args=[expr, ast.Name(id="tuple")],
keywords=[],
),
self._make_cast_assert_message(src_location, expr, type),
)
assert isinstance(expr, ast.Tuple)
for item, item_type in zip(expr.elts, items):
self._make_cast_asserts(src_location, item, item_type)
case DataFrameType(columns=columns):
self.define_is_dataframe = True
self._add_assert(
ast.Call(
func=ast.Name(id=self.IS_DATAFRAME_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a dataframe"
),
)
for column in columns:
self._add_assert(
ast.Compare(
left=ast.Constant(value=column.name),
ops=[ast.In()],
comparators=[expr],
),
self._make_cast_assert_message(
src_location, expr, type, f": Missing column {column.name}"
),
)
self._make_cast_asserts(
src_location,
ast.Subscript(
value=expr, slice=ast.Constant(value=column.name)
),
column.type,
)
case ColumnType(type=inner):
self.define_is_column = True
self._add_assert(
ast.Call(
func=ast.Name(id=self.IS_COLUMN_FUNC),
args=[expr],
keywords=[],
),
self._make_cast_assert_message(
src_location, expr, type, ": Not a column"
),
)
# TODO: check value type
case (
TopType()
@@ -333,14 +418,18 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
| ExtensionType()
| GenericType()
):
raise NotImplementedError(f"Can't make assertion for type {type}")
self.logger.warning(f"Can't make assertion for type {type}")
# Ensure exhaustiveness
case _:
assert_never(type)
def _make_cast_assert_message(
self, location: Location, expr: ast.expr, type: Type
self,
location: Location,
expr: ast.expr,
type: Type,
extra: Optional[str] = None,
) -> ast.expr:
loc_str: str = f"{self.rel_src_path}:L{location.lineno}:{location.col_offset+1}"
# f"file.py:L1:1: CastError: Cannot cast {type(expr).__name__} to Type"
@@ -358,7 +447,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
),
conversion=-1,
),
ast.Constant(f" to {type}"),
ast.Constant(f" to {type}{extra or ''}"),
]
)
@@ -394,3 +483,75 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
constraint: ast.expr = self._constraint_generator.generate(expr)
self._constraints.append((expr, constraint))
return constraint
def _is_dataframe_definition(self) -> ast.stmt:
"""
def IS_DATAFRAME_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.DataFrame)
"""
return ast.FunctionDef(
name=self.IS_DATAFRAME_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)
def _is_column_definition(self) -> ast.stmt:
"""
def IS_COLUMN_FUNC(obj) -> bool:
import pandas as pd
return isinstance(obj, pd.Series)
"""
return ast.FunctionDef(
name=self.IS_COLUMN_FUNC,
args=ast.arguments(
posonlyargs=[ast.arg(arg="obj")],
args=[],
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=[
ast.Import(names=[ast.alias(name="pandas", asname="pd")]),
ast.Return(
value=ast.Call(
func=ast.Name(id="isinstance"),
args=[
ast.Name(id="obj"),
ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
],
keywords=[],
)
),
],
decorator_list=[],
returns=ast.Name(id="bool"),
)

View File

@@ -1,5 +1,5 @@
import ast
from typing import Optional
from typing import Optional, assert_never
import midas.ast.midas as m
from midas.checker.registry import Member, TypesRegistry
@@ -7,16 +7,21 @@ from midas.checker.types import (
AliasType,
AppliedType,
BaseType,
ColumnType,
ComplexType,
ConstraintType,
DataFrameType,
ExtensionType,
Function,
GenericType,
OverloadedFunction,
TopType,
TupleType,
Type,
TypeVar,
UnitType,
UnknownType,
Variance,
substitute_typevars,
)
@@ -28,6 +33,7 @@ class StubsGenerator:
self.types: TypesRegistry = types
self.stubs: list[ast.stmt] = []
self.typing_imports: set[str] = set()
self.import_pandas: bool = False
self.protocol_idx: int = 0
self.stub_idx: int = 0
self.type_var_idx: int = 0
@@ -36,10 +42,23 @@ class StubsGenerator:
def generate_stubs(self) -> ast.Module:
self.stubs = []
self.typing_imports = set()
self.import_pandas = False
for name, type in self.types._types.items():
# Skip builtin types, not just based on name so the user can override
# TODO: check if added members on builtin type
match type:
case BaseType(name=name_) if name == name_:
continue
case GenericType(
name=name1,
body=BaseType(name=name2),
) if (
name == name1 == name2
):
continue
self.generate_stub(name, type)
imports = [
imports: list[ast.stmt] = [
ast.ImportFrom(
module="__future__",
names=[ast.alias(name="annotations")],
@@ -56,6 +75,17 @@ class StubsGenerator:
level=0,
)
)
if self.import_pandas:
imports.append(
ast.Import(
names=[
ast.alias(
name="pandas",
asname="pd",
)
],
)
)
return ast.Module(body=imports + self.stubs, type_ignores=[])
def generate_stub(self, name: str, type: Type):
@@ -84,6 +114,7 @@ class StubsGenerator:
match type:
case AliasType(type=base):
return [self.dump_type(base)], {}
case GenericType(params=params, body=body):
self.add_typing_import("Generic")
type_vars: ast.expr
@@ -111,6 +142,13 @@ class StubsGenerator:
],
body_subsitutions | substitutions,
)
case ConstraintType(type=base):
return self.get_bases(base)
case TypeVar(bound=bound) if bound is not None:
return [self.dump_type(bound)], {}
case _:
return [], {}
@@ -148,15 +186,20 @@ class StubsGenerator:
case TopType() | UnknownType():
self.add_typing_import("Any")
return ast.Name(id="Any")
case BaseType(name=name):
return ast.Name(id=name)
case AliasType(name=name):
return ast.Name(id=name)
case UnitType():
return ast.Constant(value=None)
case Function():
name: str = self.define_protocol(type)
return ast.Name(id=name)
case OverloadedFunction(overloads=overloads):
if len(overloads) == 1:
return self.dump_type(overloads[0])
@@ -176,6 +219,7 @@ class StubsGenerator:
case TypeVar():
return ast.Name(id=type.name)
case GenericType(name=name):
params: ast.expr
if len(type.params) == 1:
@@ -188,6 +232,7 @@ class StubsGenerator:
value=ast.Name(id=type.name),
slice=params,
)
case AppliedType():
args: ast.expr
if len(type.args) == 1:
@@ -199,6 +244,37 @@ class StubsGenerator:
slice=args,
)
case ConstraintType():
return self.dump_type(type.type)
case TupleType(items=items):
return ast.Subscript(
value=ast.Name(id="tuple"),
slice=ast.Tuple(
elts=[self.dump_type(item) for item in items],
),
)
case ColumnType(type=inner):
self.import_pandas = True
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="pd"),
attr="Series",
),
slice=self.dump_type(inner),
)
case DataFrameType():
self.import_pandas = True
return ast.Attribute(
value=ast.Name(id="pd"),
attr="DataFrame",
)
case _:
assert_never(type)
def dump_method(
self, name: str, method: Type, overloaded: bool = False
) -> list[ast.stmt]:
@@ -313,6 +389,29 @@ class StubsGenerator:
def define_type_var(self, var: TypeVar) -> TypeVar:
name: str = self.new_type_var_name()
self.add_typing_import("TypeVar")
kwargs: list[ast.keyword] = []
if var.bound is not None:
kwargs.append(
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
)
if var.variance == Variance.COVARIANT:
kwargs.append(
ast.keyword(
arg="covariant",
value=ast.Constant(value=True),
)
)
elif var.variance == Variance.CONTRAVARIANT:
kwargs.append(
ast.keyword(
arg="contravariant",
value=ast.Constant(value=True),
)
)
self.add_stub(
ast.Assign(
targets=[ast.Name(id=name)],
@@ -321,16 +420,7 @@ class StubsGenerator:
args=[
ast.Constant(value=name),
],
keywords=(
[]
if var.bound is None
else [
ast.keyword(
arg="bound",
value=self.dump_type(var.bound),
)
]
),
keywords=kwargs,
),
)
)

View File

@@ -9,6 +9,7 @@ from midas.ast.midas import (
Expr,
ExtendStmt,
ExtensionType,
FrameType,
FunctionType,
GenericType,
GetExpr,
@@ -204,8 +205,10 @@ class MidasParser(Parser):
return self.generic_type()
def generic_type(self) -> Type:
type: Type = self.named_type()
type: NamedType = self.named_type()
if self.check(TokenType.LEFT_BRACKET):
if type.name.lexeme == "Frame":
return self.frame_type()
args: list[Type] = self.type_args()
return GenericType(
location=Location.span(type.location, self.previous().get_location()),
@@ -224,7 +227,7 @@ class MidasParser(Parser):
self.consume(TokenType.RIGHT_BRACKET, "Missing ']' after generic arguments")
return args
def named_type(self) -> Type:
def named_type(self) -> NamedType:
name: Token = self.consume_identifier("Expected type name")
return NamedType(
location=name.get_location(),
@@ -259,6 +262,32 @@ class MidasParser(Parser):
members=members,
)
def frame_type(self) -> FrameType:
keyword: Token = self.previous()
self.consume(TokenType.LEFT_BRACKET, "Expected '[' to start frame schema")
columns: list[FrameType.Column] = []
while not self.check(TokenType.RIGHT_BRACKET) and not self.is_at_end():
name: Token = self.advance()
self.consume(TokenType.COLON, "Expected ':' between column name and type")
type: Type = self.type_expr()
columns.append(
FrameType.Column(
location=name.location_to(self.previous()),
name=name,
type=type,
)
)
if not self.match(TokenType.COMMA):
break
self.consume(TokenType.RIGHT_BRACKET, "Unclosed frame schema")
return FrameType(
location=keyword.location_to(self.previous()),
columns=columns,
)
def constraint(self) -> Expr:
"""Parse a constraint

View File

@@ -49,6 +49,7 @@ class UnsupportedSyntaxError(Exception):
class PythonParser:
CAST_FUNCTION = "cast"
UNSAFE_CAST_FUNCTION = "unsafe_cast"
def parse_module(self, node: ast.Module) -> list[Stmt]:
statements: list[Stmt] = []
@@ -423,6 +424,9 @@ class PythonParser:
case ast.Call(func=ast.Name(id=self.CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call(func=ast.Name(id=self.UNSAFE_CAST_FUNCTION)):
return self.parse_cast(node)
case ast.Call():
return self.parse_call(node)
@@ -527,16 +531,19 @@ class PythonParser:
return expr
def parse_cast(self, node: ast.Call) -> CastExpr:
assert isinstance(node.func, ast.Name)
func: str = node.func.id
match node:
case ast.Call(args=[type, expr], keywords=[]):
return CastExpr(
location=Location.from_ast(node),
type=self._parse_type(type),
expr=self.parse_expr(expr),
unsafe=func == self.UNSAFE_CAST_FUNCTION,
)
case _:
raise InvalidSyntaxError(
f"Invalid call to {self.CAST_FUNCTION}, expected type and expression"
f"Invalid call to {func}, expected type and expression"
)
def parse_call(self, node: ast.Call) -> CallExpr:

52
midas/typing.py Normal file
View File

@@ -0,0 +1,52 @@
from typing import Generic, TypeVar
from typing import cast as typing_cast
cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: generates assertions to ensure the value can be interpreted as the given type.
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""
unsafe_cast = typing_cast
"""### Midas documentation
Cast a value to a type.
- **Compile-time**: tells the type checker that the return value has the designated type.
- **Run-time**: -
This operation is unsound, use at your own risk!
---
<br>
<br>
<br>
_**Internal Python documentation**_
"""
T = TypeVar("T")
class Frame(Generic[T]):
"""A `Frame` is the abstract type implemented by `DataFrame`
A frame contains any number of named columns (see :class:`Column`)
"""
class Column(Generic[T]):
"""A `Column` is the abstract type implemented by `Series`
A column contains a any number of values of the same type
"""

View File

@@ -62,3 +62,4 @@ class UniversalJSONDumper:
class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]
evaluated_casts: list[p.CastExpr]

View File

@@ -8,7 +8,11 @@ authors = [
{ name = "Louis Heredero", email = "louis.heredero@students.hevs.ch" },
]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = ["click>=8.4.1"]
dependencies = [
"black>=26.5.1",
"click>=8.4.1",
"watchdog>=6.0.0",
]
[project.urls]
Homepage = "https://git.kbk28.ch/HEL/midas"

View File

@@ -4,7 +4,35 @@
"type": "Warning",
"location": {
"start": [
6,
8,
12
],
"end": [
8,
43
]
},
"message": "ConstraintType not yet supported"
},
{
"type": "Warning",
"location": {
"start": [
10,
10
],
"end": [
10,
18
]
},
"message": "Unknown type 'datetime'"
},
{
"type": "Warning",
"location": {
"start": [
13,
4
],
"end": [
@@ -12,7 +40,7 @@
5
]
},
"message": "FrameType not yet supported"
"message": "Unknown type '_'"
}
],
"judgments": []

View File

@@ -328,6 +328,19 @@
},
"type": {}
},
{
"location": {
"from": "L6:9",
"to": "L6:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:5",
@@ -373,19 +386,6 @@
}
}
},
{
"location": {
"from": "L6:9",
"to": "L6:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:5",
@@ -407,6 +407,32 @@
},
"type": {}
},
{
"location": {
"from": "L7:9",
"to": "L7:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L7:12",
"to": "L7:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L7:5",
@@ -452,32 +478,6 @@
}
}
},
{
"location": {
"from": "L7:9",
"to": "L7:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L7:12",
"to": "L7:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L7:5",
@@ -503,6 +503,32 @@
},
"type": {}
},
{
"location": {
"from": "L8:9",
"to": "L8:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:14",
"to": "L8:17"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L8:5",
@@ -548,32 +574,6 @@
}
}
},
{
"location": {
"from": "L8:9",
"to": "L8:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L8:14",
"to": "L8:17"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L8:5",
@@ -600,6 +600,45 @@
},
"type": {}
},
{
"location": {
"from": "L9:9",
"to": "L9:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L9:12",
"to": "L9:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:17",
"to": "L9:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L9:5",
@@ -645,45 +684,6 @@
}
}
},
{
"location": {
"from": "L9:9",
"to": "L9:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L9:12",
"to": "L9:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:17",
"to": "L9:23"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L9:5",
@@ -713,6 +713,45 @@
},
"type": {}
},
{
"location": {
"from": "L10:9",
"to": "L10:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:12",
"to": "L10:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:19",
"to": "L10:22"
},
"expr": {
"_type": "LiteralExpr",
"value": 3.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:5",
@@ -758,45 +797,6 @@
}
}
},
{
"location": {
"from": "L10:9",
"to": "L10:10"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:12",
"to": "L10:15"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:19",
"to": "L10:22"
},
"expr": {
"_type": "LiteralExpr",
"value": 3.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L10:5",
@@ -827,6 +827,19 @@
},
"type": {}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:5",
@@ -872,19 +885,6 @@
}
}
},
{
"location": {
"from": "L11:11",
"to": "L11:12"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:5",
@@ -906,6 +906,19 @@
},
"type": {}
},
{
"location": {
"from": "L12:11",
"to": "L12:17"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L12:5",
@@ -951,19 +964,6 @@
}
}
},
{
"location": {
"from": "L12:11",
"to": "L12:17"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L12:5",
@@ -985,6 +985,45 @@
},
"type": {}
},
{
"location": {
"from": "L14:10",
"to": "L14:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L14:13",
"to": "L14:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L14:20",
"to": "L14:26"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L14:6",
@@ -1030,45 +1069,6 @@
}
}
},
{
"location": {
"from": "L14:10",
"to": "L14:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L14:13",
"to": "L14:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L14:20",
"to": "L14:26"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L14:6",
@@ -1101,6 +1101,45 @@
"name": "bool"
}
},
{
"location": {
"from": "L15:10",
"to": "L15:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:15",
"to": "L15:18"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L15:22",
"to": "L15:28"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L15:6",
@@ -1146,45 +1185,6 @@
}
}
},
{
"location": {
"from": "L15:10",
"to": "L15:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L15:15",
"to": "L15:18"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L15:22",
"to": "L15:28"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L15:6",
@@ -1217,6 +1217,45 @@
"name": "bool"
}
},
{
"location": {
"from": "L16:10",
"to": "L16:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L16:25",
"to": "L16:28"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L16:6",
@@ -1262,45 +1301,6 @@
}
}
},
{
"location": {
"from": "L16:10",
"to": "L16:11"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L16:15",
"to": "L16:21"
},
"expr": {
"_type": "LiteralExpr",
"value": "test"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L16:25",
"to": "L16:28"
},
"expr": {
"_type": "LiteralExpr",
"value": 2.0
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L16:6",
@@ -1333,6 +1333,45 @@
"name": "bool"
}
},
{
"location": {
"from": "L18:10",
"to": "L18:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:20",
"to": "L18:25"
},
"expr": {
"_type": "LiteralExpr",
"value": false
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L18:6",
@@ -1378,45 +1417,6 @@
}
}
},
{
"location": {
"from": "L18:10",
"to": "L18:13"
},
"expr": {
"_type": "LiteralExpr",
"value": "a"
},
"type": {
"name": "str"
}
},
{
"location": {
"from": "L18:15",
"to": "L18:16"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L18:20",
"to": "L18:25"
},
"expr": {
"_type": "LiteralExpr",
"value": false
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L18:6",

View File

@@ -1,6 +1,19 @@
{
"diagnostics": [],
"judgments": [
{
"location": {
"from": "L4:30",
"to": "L4:36"
},
"expr": {
"_type": "LiteralExpr",
"value": 123.45
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L4:18",
@@ -16,7 +29,8 @@
"expr": {
"_type": "LiteralExpr",
"value": 123.45
}
},
"unsafe": false
},
"type": {
"name": "Meter",
@@ -25,6 +39,19 @@
}
}
},
{
"location": {
"from": "L5:28",
"to": "L5:31"
},
"expr": {
"_type": "LiteralExpr",
"value": 6.7
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L5:15",
@@ -40,7 +67,8 @@
"expr": {
"_type": "LiteralExpr",
"value": 6.7
}
},
"unsafe": false
},
"type": {
"name": "Second",

View File

@@ -100,6 +100,32 @@
"name": "float"
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",
@@ -135,32 +161,6 @@
}
}
},
{
"location": {
"from": "L11:13",
"to": "L11:15"
},
"expr": {
"_type": "VariableExpr",
"name": "v1"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L11:17",
"to": "L11:19"
},
"expr": {
"_type": "VariableExpr",
"name": "v2"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L11:5",

View File

@@ -72,29 +72,6 @@
}
],
"judgments": [
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{
"location": {
"from": "L27:4",
@@ -325,6 +302,29 @@
}
}
},
{
"location": {
"from": "L26:0",
"to": "L26:5"
},
"expr": {
"_type": "VariableExpr",
"name": "print"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {}
}
},
{
"location": {
"from": "L26:0",

View File

@@ -0,0 +1,14 @@
def double(value: float) -> float:
return value * 2
def is_odd(value: int) -> bool:
return bool(value % 2)
floats: list[float] = [0.2, 0.5, 0.1, 1.2]
ints: list[int] = [1, 2, 6, -3]
doubled_floats = map(double, floats)
doubled_ints = map(double, ints)
odd_ints = map(is_odd, ints)

View File

@@ -0,0 +1,874 @@
{
"diagnostics": [
{
"type": "Error",
"location": {
"start": [
13,
15
],
"end": [
13,
32
]
},
"message": "Could not unify map[T, U]=(transform: (v: T, /) -> U, iterable: list[T], /) -> list[U] with pos=[(value: float) -> float, list[int]] and kw={}"
}
],
"judgments": [
{
"location": {
"from": "L2:11",
"to": "L2:16"
},
"expr": {
"_type": "VariableExpr",
"name": "value"
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L2:19",
"to": "L2:20"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L2:11",
"to": "L2:20"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "*",
"right": {
"_type": "LiteralExpr",
"value": 2
}
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L6:16",
"to": "L6:21"
},
"expr": {
"_type": "VariableExpr",
"name": "value"
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:24",
"to": "L6:25"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:16",
"to": "L6:25"
},
"expr": {
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "%",
"right": {
"_type": "LiteralExpr",
"value": 2
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L6:11",
"to": "L6:15"
},
"expr": {
"_type": "VariableExpr",
"name": "bool"
},
"type": {
"pos_args": [
{
"pos": 0,
"name": "object",
"type": {},
"required": false
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L6:11",
"to": "L6:26"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "bool"
},
"arguments": [
{
"_type": "BinaryExpr",
"left": {
"_type": "VariableExpr",
"name": "value"
},
"operator": "%",
"right": {
"_type": "LiteralExpr",
"value": 2
}
}
],
"keywords": {}
},
"type": {
"name": "bool"
}
},
{
"location": {
"from": "L9:23",
"to": "L9:26"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.2
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:28",
"to": "L9:31"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.5
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:33",
"to": "L9:36"
},
"expr": {
"_type": "LiteralExpr",
"value": 0.1
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:38",
"to": "L9:41"
},
"expr": {
"_type": "LiteralExpr",
"value": 1.2
},
"type": {
"name": "float"
}
},
{
"location": {
"from": "L9:22",
"to": "L9:42"
},
"expr": {
"_type": "ListExpr",
"items": [
{
"_type": "LiteralExpr",
"value": 0.2
},
{
"_type": "LiteralExpr",
"value": 0.5
},
{
"_type": "LiteralExpr",
"value": 0.1
},
{
"_type": "LiteralExpr",
"value": 1.2
}
]
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L10:19",
"to": "L10:20"
},
"expr": {
"_type": "LiteralExpr",
"value": 1
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:22",
"to": "L10:23"
},
"expr": {
"_type": "LiteralExpr",
"value": 2
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:25",
"to": "L10:26"
},
"expr": {
"_type": "LiteralExpr",
"value": 6
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:29",
"to": "L10:30"
},
"expr": {
"_type": "LiteralExpr",
"value": 3
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:28",
"to": "L10:30"
},
"expr": {
"_type": "UnaryExpr",
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 3
}
},
"type": {
"name": "int"
}
},
{
"location": {
"from": "L10:18",
"to": "L10:31"
},
"expr": {
"_type": "ListExpr",
"items": [
{
"_type": "LiteralExpr",
"value": 1
},
{
"_type": "LiteralExpr",
"value": 2
},
{
"_type": "LiteralExpr",
"value": 6
},
{
"_type": "UnaryExpr",
"operator": "-",
"right": {
"_type": "LiteralExpr",
"value": 3
}
}
]
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:21",
"to": "L12:27"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L12:29",
"to": "L12:35"
},
"expr": {
"_type": "VariableExpr",
"name": "floats"
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L12:17",
"to": "L12:20"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L12:17",
"to": "L12:36"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "double"
},
{
"_type": "VariableExpr",
"name": "floats"
}
],
"keywords": {}
},
"type": {
"name": "list",
"args": [
{
"name": "float"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:19",
"to": "L13:25"
},
"expr": {
"_type": "VariableExpr",
"name": "double"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "float"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "float"
}
}
},
{
"location": {
"from": "L13:27",
"to": "L13:31"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L13:15",
"to": "L13:18"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L13:15",
"to": "L13:32"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "double"
},
{
"_type": "VariableExpr",
"name": "ints"
}
],
"keywords": {}
},
"type": {}
},
{
"location": {
"from": "L14:15",
"to": "L14:21"
},
"expr": {
"_type": "VariableExpr",
"name": "is_odd"
},
"type": {
"pos_args": [],
"args": [
{
"pos": 0,
"name": "value",
"type": {
"name": "int"
},
"required": true
}
],
"kw_args": [],
"returns": {
"name": "bool"
}
}
},
{
"location": {
"from": "L14:23",
"to": "L14:27"
},
"expr": {
"_type": "VariableExpr",
"name": "ints"
},
"type": {
"name": "list",
"args": [
{
"name": "int"
}
],
"body": {
"name": "list"
}
}
},
{
"location": {
"from": "L14:11",
"to": "L14:14"
},
"expr": {
"_type": "VariableExpr",
"name": "map"
},
"type": {
"name": "map",
"params": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"pos_args": [
{
"pos": 0,
"name": "transform",
"type": {
"pos_args": [
{
"pos": 0,
"name": "v",
"type": {
"name": "T",
"bound": null,
"variance": "INVARIANT"
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
},
"required": true
},
{
"pos": 1,
"name": "iterable",
"type": {
"name": "list",
"args": [
{
"name": "T",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
},
"required": true
}
],
"args": [],
"kw_args": [],
"returns": {
"name": "list",
"args": [
{
"name": "U",
"bound": null,
"variance": "INVARIANT"
}
],
"body": {
"name": "list"
}
}
}
}
},
{
"location": {
"from": "L14:11",
"to": "L14:28"
},
"expr": {
"_type": "CallExpr",
"callee": {
"_type": "VariableExpr",
"name": "map"
},
"arguments": [
{
"_type": "VariableExpr",
"name": "is_odd"
},
{
"_type": "VariableExpr",
"name": "ints"
}
],
"keywords": {}
},
"type": {
"name": "list",
"args": [
{
"name": "bool"
}
],
"body": {
"name": "list"
}
}
}
]
}

View File

@@ -7,68 +7,14 @@ Module(
alias(name='Meter'),
alias(name='Second')],
level=0),
Assign(
targets=[
Name(id='__midas_a0__')],
value=Constant(value=123.45)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a0__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L3:19: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a0__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
Assign(
targets=[
Name(id='distance')],
value=Name(id='__midas_a0__')),
Delete(
targets=[
Name(id='__midas_a0__')]),
Assign(
targets=[
Name(id='__midas_a1__')],
value=Constant(value=6.7)),
Assert(
test=Call(
func=Name(id='isinstance'),
args=[
Name(id='__midas_a1__'),
Name(id='float')],
keywords=[]),
msg=JoinedStr(
values=[
Constant(value='01_simple_types.py:L4:16: CastError: Cannot cast '),
FormattedValue(
value=Attribute(
value=Call(
func=Name(id='type'),
args=[
Name(id='__midas_a1__')],
keywords=[]),
attr='__name__'),
conversion=-1),
Constant(value=' to float')])),
value=Constant(value=123.45)),
Assign(
targets=[
Name(id='time')],
value=Name(id='__midas_a1__')),
Delete(
targets=[
Name(id='__midas_a1__')]),
value=Constant(value=6.7)),
Assign(
targets=[
Name(id='speed')],

View File

@@ -8,6 +8,7 @@ from midas.ast.midas import (
Expr,
ExtendStmt,
ExtensionType,
FrameType,
FunctionType,
GenericType,
GetExpr,
@@ -197,3 +198,15 @@ class MidasAstJsonSerializer(
"base": type.base.accept(self),
"extension": type.extension.accept(self),
}
def visit_frame_type(self, type: FrameType) -> dict:
return {
"_type": "FrameType",
"columns": [self._serialize_column(col) for col in type.columns],
}
def _serialize_column(self, column: FrameType.Column):
return {
"name": column.name.lexeme,
"type": column.type.accept(self),
}

View File

@@ -263,6 +263,7 @@ class PythonAstJsonSerializer(
"_type": "CastExpr",
"type": expr.type.accept(self),
"expr": expr.expr.accept(self),
"unsafe": expr.unsafe,
}
def visit_ternary_expr(self, expr: TernaryExpr) -> dict: