Compare commits
6 Commits
f796f4c6fa
...
78eba39ae3
| Author | SHA1 | Date | |
|---|---|---|---|
|
78eba39ae3
|
|||
|
3b78b37306
|
|||
|
9e14b30bc9
|
|||
|
a6a1075f91
|
|||
|
11be47fce3
|
|||
|
2eeede9826
|
@@ -178,4 +178,100 @@ extend dict[K, V] {
|
||||
// def __ior__: fn(value: SupportsKeysAndGetItem[K, V], /) -> dict[K, V]
|
||||
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
extend str {
|
||||
def capitalize: fn() -> str
|
||||
def casefold: fn() -> str
|
||||
def center: fn(width: int, fillchar: str?, /) -> str
|
||||
def count: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def count: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def count: fn(sub: str, start: None, end: int, /) -> int
|
||||
def count: fn(sub: str, start: int, end: int, /) -> int
|
||||
def encode: fn(encoding: str?, errors: str?) -> bytes
|
||||
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
|
||||
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
|
||||
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
|
||||
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
|
||||
def expandtabs: fn(tabsize: int?) -> str
|
||||
def find: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def find: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def find: fn(sub: str, start: None, end: int, /) -> int
|
||||
def find: fn(sub: str, start: int, end: int, /) -> int
|
||||
// def format: fn(*args: object, **kwargs: object) -> str
|
||||
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
|
||||
def index: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def index: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def index: fn(sub: str, start: None, end: int, /) -> int
|
||||
def index: fn(sub: str, start: int, end: int, /) -> int
|
||||
def isalnum: fn() -> bool
|
||||
def isalpha: fn() -> bool
|
||||
def isascii: fn() -> bool
|
||||
def isdecimal: fn() -> bool
|
||||
def isdigit: fn() -> bool
|
||||
def isidentifier: fn() -> bool
|
||||
def islower: fn() -> bool
|
||||
def isnumeric: fn() -> bool
|
||||
def isprintable: fn() -> bool
|
||||
def isspace: fn() -> bool
|
||||
def istitle: fn() -> bool
|
||||
def isupper: fn() -> bool
|
||||
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
|
||||
def ljust: fn(width: int, fillchar: str?, /) -> str
|
||||
def lower: fn() -> str
|
||||
def lstrip: fn(chars: None?, /) -> str
|
||||
def lstrip: fn(chars: str, /) -> str
|
||||
def partition: fn(sep: str, /) -> tuple[str, str, str]
|
||||
|
||||
def replace: fn(old: str, new: str, count: int?, /) -> str
|
||||
|
||||
def removeprefix: fn(prefix: str, /) -> str
|
||||
def removesuffix: fn(suffix: str, /) -> str
|
||||
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def rfind: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def rfind: fn(sub: str, start: None, end: int, /) -> int
|
||||
def rfind: fn(sub: str, start: int, end: int, /) -> int
|
||||
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
|
||||
def rindex: fn(sub: str, start: int, end: None?, /) -> int
|
||||
def rindex: fn(sub: str, start: None, end: int, /) -> int
|
||||
def rindex: fn(sub: str, start: int, end: int, /) -> int
|
||||
def rjust: fn(width: int, fillchar: str?, /) -> str
|
||||
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
|
||||
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
|
||||
def rstrip: fn(chars: None?, /) -> str
|
||||
def rstrip: fn(chars: str, /) -> str
|
||||
def split: fn(sep: None?, maxsplit: int?) -> list[str]
|
||||
def split: fn(sep: str, maxsplit: int?) -> list[str]
|
||||
def splitlines: fn(keepends: bool?) -> list[str]
|
||||
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
|
||||
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
|
||||
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
|
||||
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
|
||||
def strip: fn(chars: None?, /) -> str
|
||||
def strip: fn(chars: str, /) -> str
|
||||
def swapcase: fn() -> str
|
||||
def title: fn() -> str
|
||||
// def translate: fn(table: _TranslateTable, /) -> str
|
||||
def upper: fn() -> str
|
||||
def zfill: fn(width: int, /) -> str
|
||||
def __add__: fn(value: str, /) -> str
|
||||
// Incompatible with Sequence.__contains__
|
||||
def __contains__: fn(key: str, /) -> bool
|
||||
def __eq__: fn(value: object, /) -> bool
|
||||
def __ge__: fn(value: str, /) -> bool
|
||||
def __getitem__: fn(key: slice, /) -> str
|
||||
def __getitem__: fn(key: int, /) -> str
|
||||
def __gt__: fn(value: str, /) -> bool
|
||||
def __hash__: fn() -> int
|
||||
// def __iter__: fn() -> Iterator[str]
|
||||
def __le__: fn(value: str, /) -> bool
|
||||
def __len__: fn() -> int
|
||||
def __lt__: fn(value: str, /) -> bool
|
||||
def __mod__: fn(value: Any, /) -> str
|
||||
def __mul__: fn(value: int, /) -> str
|
||||
def __ne__: fn(value: object, /) -> bool
|
||||
def __rmul__: fn(value: int, /) -> str
|
||||
def __getnewargs__: fn() -> tuple[str]
|
||||
def __format__: fn(format_spec: str, /) -> str
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
BUILTIN_SUBTYPES: dict[str, set[str]] = {
|
||||
"object": {"float", "list", "dict", "str"},
|
||||
"object": {"float", "list", "dict", "str", "bytes", "tuple"},
|
||||
"float": {"int"},
|
||||
"int": {"bool"},
|
||||
}
|
||||
@@ -26,12 +26,15 @@ def define_builtins(reg: TypesRegistry):
|
||||
any = reg.define_type("Any", TopType())
|
||||
unit = reg.define_type("None", UnitType())
|
||||
object = reg.define_type("object", BaseType(name="object"))
|
||||
bytes = reg.define_type("bytes", BaseType(name="bytes"))
|
||||
bool = reg.define_type("bool", BaseType(name="bool"))
|
||||
int = reg.define_type("int", BaseType(name="int"))
|
||||
float = reg.define_type("float", BaseType(name="float"))
|
||||
str = reg.define_type("str", BaseType(name="str"))
|
||||
slice = reg.define_type("slice", BaseType(name="slice"))
|
||||
|
||||
tuple = reg.define_type("tuple", BaseType(name="tuple"))
|
||||
|
||||
list = reg.define_type(
|
||||
"list",
|
||||
GenericType(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.registry import TypesRegistry
|
||||
@@ -17,7 +17,7 @@ class Preamble(Environment):
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
super().__init__()
|
||||
self._types: TypesRegistry = types
|
||||
self._python_funcs: dict[str, Callable] = {}
|
||||
self._python_funcs: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
self._def_type_constructor("object", object)
|
||||
self._def_type_constructor("float", float)
|
||||
@@ -34,7 +34,7 @@ class Preamble(Environment):
|
||||
# TODO: use sink
|
||||
self._def_function(
|
||||
name="print",
|
||||
pos=[Param("object", TopType())],
|
||||
pos=[Param("object", TopType(), required=False)],
|
||||
returns=UnitType(),
|
||||
py_function=print,
|
||||
)
|
||||
@@ -64,11 +64,18 @@ class Preamble(Environment):
|
||||
pos=[Param("prompt", TopType(), required=False)],
|
||||
returns=self._types.get_type("str"),
|
||||
)
|
||||
self._def_function(
|
||||
name="len",
|
||||
pos=[Param("object", TopType())],
|
||||
returns=self._types.get_type("int"),
|
||||
)
|
||||
|
||||
def _list_of(self, item_type: Type) -> Type:
|
||||
return self._types.apply_generic(self._types.get_type("list"), [item_type])
|
||||
|
||||
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None):
|
||||
def _def_type_constructor(
|
||||
self, name: str, py_function: Optional[Callable[..., Any]] = None
|
||||
):
|
||||
# TODO: more specific arg types
|
||||
self._def_function(
|
||||
name=name,
|
||||
@@ -121,7 +128,7 @@ class Preamble(Environment):
|
||||
kw: list[Param] = [],
|
||||
returns: Type = UnitType(),
|
||||
type_vars: list[TypeVar] = [],
|
||||
py_function: Optional[Callable] = None,
|
||||
py_function: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
function: Type = self._make_function(
|
||||
name=name,
|
||||
@@ -135,5 +142,5 @@ class Preamble(Environment):
|
||||
if py_function is not None:
|
||||
self._python_funcs[name] = py_function
|
||||
|
||||
def get_py_func(self, name: str) -> Optional[Callable]:
|
||||
def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
|
||||
return self._python_funcs.get(name)
|
||||
|
||||
@@ -413,13 +413,16 @@ class PythonTyper(
|
||||
value_type: Type,
|
||||
):
|
||||
var_type: Type = self.type_of(var)
|
||||
unfolded_type: Type = unfold_type(var_type)
|
||||
# TODO: what happens if type is an alias of a dataframe type
|
||||
match var_type:
|
||||
match unfolded_type:
|
||||
case DataFrameType() as frame:
|
||||
new_type: Type = self.frame_mgr.assign(
|
||||
self.reporter, location, frame, index, value_type
|
||||
)
|
||||
self.env.assign(var.name, new_type)
|
||||
case UnknownType():
|
||||
return
|
||||
case _:
|
||||
self.reporter.error(
|
||||
location,
|
||||
@@ -582,7 +585,7 @@ class PythonTyper(
|
||||
object: Type = self.type_of(expr.object)
|
||||
member: Optional[Type] = self.types.lookup_member(object, expr.name)
|
||||
if member is None:
|
||||
self.reporter.error(
|
||||
self.reporter.warning(
|
||||
expr.location, f"Unknown member '{expr.name}' of {object}"
|
||||
)
|
||||
return UnknownType()
|
||||
|
||||
@@ -18,6 +18,7 @@ from midas.checker.types import (
|
||||
OverloadedFunction,
|
||||
Predicate,
|
||||
TopType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypeVar,
|
||||
UnknownType,
|
||||
@@ -346,6 +347,9 @@ class TypesRegistry:
|
||||
body=substitute_typevars(body, substitutions),
|
||||
)
|
||||
|
||||
case BaseType(name="tuple"):
|
||||
return TupleType(items=tuple(args))
|
||||
|
||||
case _:
|
||||
raise ValueError(f"{type} is not a generic type")
|
||||
|
||||
|
||||
@@ -440,7 +440,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
),
|
||||
),
|
||||
]
|
||||
asserts.append(self._make_column_inner_assert(src_location, expr, type))
|
||||
inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
|
||||
src_location, expr, type
|
||||
)
|
||||
if inner_assert is not None:
|
||||
asserts.append(inner_assert)
|
||||
return asserts
|
||||
|
||||
case (
|
||||
@@ -592,12 +596,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
|
||||
|
||||
def _make_column_inner_assert(
|
||||
self, src_location: Location, column: ast.expr, type: ColumnType
|
||||
) -> ast.stmt:
|
||||
) -> Optional[ast.stmt]:
|
||||
# TODO: improve message, maybe chain contexts
|
||||
col: ast.expr = ast.Name(id="col")
|
||||
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
|
||||
if len(body) == 0:
|
||||
return None
|
||||
return ast.For(
|
||||
target=col,
|
||||
iter=column,
|
||||
body=self._make_cast_asserts(src_location, col, type.type),
|
||||
body=body,
|
||||
orelse=[],
|
||||
)
|
||||
|
||||
@@ -377,7 +377,7 @@ class MidasParser(Parser):
|
||||
pos_args: list[Expr] = []
|
||||
kw_args: dict[str, Expr] = {}
|
||||
keywords: bool = False
|
||||
while not self.match(TokenType.RIGHT_PAREN):
|
||||
while not self.check(TokenType.RIGHT_PAREN):
|
||||
if self.check_identifier() and self.check_next(TokenType.EQUAL):
|
||||
keywords = True
|
||||
keyword: Token = self.advance()
|
||||
|
||||
Reference in New Issue
Block a user