Compare commits

...

6 Commits

7 changed files with 134 additions and 14 deletions

View File

@@ -179,3 +179,99 @@ extend dict[K, V] {
// def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V] // def __ior__: fn(value: Iterable[tuple[K, V]], /) -> dict[K, V]
} }
extend str {
// Method signatures attached to the builtin `str` type.
// Each `def` binds a method name to a function type. A name that is
// declared several times with different parameter types appears to form an
// overload set (TODO confirm against the parser/registry).
// NOTE(review): `T?` presumably marks a parameter that may be omitted, and a
// trailing `/` presumably makes the preceding parameters positional-only, as
// in Python signatures - confirm against the grammar.

// --- Regular methods ---
def capitalize: fn() -> str
def casefold: fn() -> str
def center: fn(width: int, fillchar: str?, /) -> str
// Searching methods (count, endswith, find, index, rfind, rindex,
// startswith) are each declared four times: one signature per combination
// of `start` and `end` being either an `int` or `None`.
def count: fn(sub: str, start: None?, end: None?, /) -> int
def count: fn(sub: str, start: int, end: None?, /) -> int
def count: fn(sub: str, start: None, end: int, /) -> int
def count: fn(sub: str, start: int, end: int, /) -> int
def encode: fn(encoding: str?, errors: str?) -> bytes
def endswith: fn(suffix: str, start: None?, end: None?, /) -> bool
def endswith: fn(suffix: str, start: int, end: None?, /) -> bool
def endswith: fn(suffix: str, start: None, end: int, /) -> bool
def endswith: fn(suffix: str, start: int, end: int, /) -> bool
def expandtabs: fn(tabsize: int?) -> str
def find: fn(sub: str, start: None?, end: None?, /) -> int
def find: fn(sub: str, start: int, end: None?, /) -> int
def find: fn(sub: str, start: None, end: int, /) -> int
def find: fn(sub: str, start: int, end: int, /) -> int
// Disabled: these signatures use `*args`/`**kwargs` and the
// `_FormatMapMapping` type, neither of which appears elsewhere in this block.
// NOTE(review): presumably not yet expressible in the stub language - confirm.
// def format: fn(*args: object, **kwargs: object) -> str
// def format_map: fn(mapping: _FormatMapMapping, /) -> str
def index: fn(sub: str, start: None?, end: None?, /) -> int
def index: fn(sub: str, start: int, end: None?, /) -> int
def index: fn(sub: str, start: None, end: int, /) -> int
def index: fn(sub: str, start: int, end: int, /) -> int
// Character-class predicates: no arguments, return bool.
def isalnum: fn() -> bool
def isalpha: fn() -> bool
def isascii: fn() -> bool
def isdecimal: fn() -> bool
def isdigit: fn() -> bool
def isidentifier: fn() -> bool
def islower: fn() -> bool
def isnumeric: fn() -> bool
def isprintable: fn() -> bool
def isspace: fn() -> bool
def istitle: fn() -> bool
def isupper: fn() -> bool
// Only `list[str]` is accepted for now; see the TODO about widening it.
def join: fn(iterable: list[str], /) -> str // TODO: use Iterable
def ljust: fn(width: int, fillchar: str?, /) -> str
def lower: fn() -> str
// lstrip/rstrip/strip are declared twice: `chars` absent or None, and
// `chars` given as a str.
def lstrip: fn(chars: None?, /) -> str
def lstrip: fn(chars: str, /) -> str
def partition: fn(sep: str, /) -> tuple[str, str, str]
def replace: fn(old: str, new: str, count: int?, /) -> str
def removeprefix: fn(prefix: str, /) -> str
def removesuffix: fn(suffix: str, /) -> str
def rfind: fn(sub: str, start: None?, end: None?, /) -> int
def rfind: fn(sub: str, start: int, end: None?, /) -> int
def rfind: fn(sub: str, start: None, end: int, /) -> int
def rfind: fn(sub: str, start: int, end: int, /) -> int
def rindex: fn(sub: str, start: None?, end: None?, /) -> int
def rindex: fn(sub: str, start: int, end: None?, /) -> int
def rindex: fn(sub: str, start: None, end: int, /) -> int
def rindex: fn(sub: str, start: int, end: int, /) -> int
def rjust: fn(width: int, fillchar: str?, /) -> str
def rpartition: fn(sep: str, /) -> tuple[str, str, str]
// split/rsplit are declared twice: `sep` absent or None, and `sep` given as
// a str. Unlike most entries here they carry no trailing `/`.
def rsplit: fn(sep: None?, maxsplit: int?) -> list[str]
def rsplit: fn(sep: str, maxsplit: int?) -> list[str]
def rstrip: fn(chars: None?, /) -> str
def rstrip: fn(chars: str, /) -> str
def split: fn(sep: None?, maxsplit: int?) -> list[str]
def split: fn(sep: str, maxsplit: int?) -> list[str]
def splitlines: fn(keepends: bool?) -> list[str]
def startswith: fn(prefix: str, start: None?, end: None?, /) -> bool
def startswith: fn(prefix: str, start: int, end: None?, /) -> bool
def startswith: fn(prefix: str, start: None, end: int, /) -> bool
def startswith: fn(prefix: str, start: int, end: int, /) -> bool
def strip: fn(chars: None?, /) -> str
def strip: fn(chars: str, /) -> str
def swapcase: fn() -> str
def title: fn() -> str
// Disabled: the signature uses the `_TranslateTable` type, which appears
// nowhere else in this block.
// NOTE(review): presumably that type is not declared yet - confirm.
// def translate: fn(table: _TranslateTable, /) -> str
def upper: fn() -> str
def zfill: fn(width: int, /) -> str

// --- Operator and protocol methods ---
def __add__: fn(value: str, /) -> str
// Incompatible with Sequence.__contains__
def __contains__: fn(key: str, /) -> bool
// Equality takes any `object`; the ordering comparisons only take `str`.
def __eq__: fn(value: object, /) -> bool
def __ge__: fn(value: str, /) -> bool
// Indexing by a slice or by an int both yield a `str`.
def __getitem__: fn(key: slice, /) -> str
def __getitem__: fn(key: int, /) -> str
def __gt__: fn(value: str, /) -> bool
def __hash__: fn() -> int
// Disabled. NOTE(review): presumably waiting on an `Iterator` type, like the
// `Iterable` TODO on `join` - confirm.
// def __iter__: fn() -> Iterator[str]
def __le__: fn(value: str, /) -> bool
def __len__: fn() -> int
def __lt__: fn(value: str, /) -> bool
// The right-hand operand of `%` is typed as `Any`.
def __mod__: fn(value: Any, /) -> str
def __mul__: fn(value: int, /) -> str
def __ne__: fn(value: object, /) -> bool
def __rmul__: fn(value: int, /) -> str
def __getnewargs__: fn() -> tuple[str]
def __format__: fn(format_spec: str, /) -> str
}

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
BUILTIN_SUBTYPES: dict[str, set[str]] = { BUILTIN_SUBTYPES: dict[str, set[str]] = {
"object": {"float", "list", "dict", "str"}, "object": {"float", "list", "dict", "str", "bytes", "tuple"},
"float": {"int"}, "float": {"int"},
"int": {"bool"}, "int": {"bool"},
} }
@@ -26,12 +26,15 @@ def define_builtins(reg: TypesRegistry):
any = reg.define_type("Any", TopType()) any = reg.define_type("Any", TopType())
unit = reg.define_type("None", UnitType()) unit = reg.define_type("None", UnitType())
object = reg.define_type("object", BaseType(name="object")) object = reg.define_type("object", BaseType(name="object"))
bytes = reg.define_type("bytes", BaseType(name="bytes"))
bool = reg.define_type("bool", BaseType(name="bool")) bool = reg.define_type("bool", BaseType(name="bool"))
int = reg.define_type("int", BaseType(name="int")) int = reg.define_type("int", BaseType(name="int"))
float = reg.define_type("float", BaseType(name="float")) float = reg.define_type("float", BaseType(name="float"))
str = reg.define_type("str", BaseType(name="str")) str = reg.define_type("str", BaseType(name="str"))
slice = reg.define_type("slice", BaseType(name="slice")) slice = reg.define_type("slice", BaseType(name="slice"))
tuple = reg.define_type("tuple", BaseType(name="tuple"))
list = reg.define_type( list = reg.define_type(
"list", "list",
GenericType( GenericType(

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional from typing import Any, Callable, Optional
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.registry import TypesRegistry from midas.checker.registry import TypesRegistry
@@ -17,7 +17,7 @@ class Preamble(Environment):
def __init__(self, types: TypesRegistry) -> None: def __init__(self, types: TypesRegistry) -> None:
super().__init__() super().__init__()
self._types: TypesRegistry = types self._types: TypesRegistry = types
self._python_funcs: dict[str, Callable] = {} self._python_funcs: dict[str, Callable[..., Any]] = {}
self._def_type_constructor("object", object) self._def_type_constructor("object", object)
self._def_type_constructor("float", float) self._def_type_constructor("float", float)
@@ -34,7 +34,7 @@ class Preamble(Environment):
# TODO: use sink # TODO: use sink
self._def_function( self._def_function(
name="print", name="print",
pos=[Param("object", TopType())], pos=[Param("object", TopType(), required=False)],
returns=UnitType(), returns=UnitType(),
py_function=print, py_function=print,
) )
@@ -64,11 +64,18 @@ class Preamble(Environment):
pos=[Param("prompt", TopType(), required=False)], pos=[Param("prompt", TopType(), required=False)],
returns=self._types.get_type("str"), returns=self._types.get_type("str"),
) )
self._def_function(
name="len",
pos=[Param("object", TopType())],
returns=self._types.get_type("int"),
)
def _list_of(self, item_type: Type) -> Type: def _list_of(self, item_type: Type) -> Type:
return self._types.apply_generic(self._types.get_type("list"), [item_type]) return self._types.apply_generic(self._types.get_type("list"), [item_type])
def _def_type_constructor(self, name: str, py_function: Optional[Callable] = None): def _def_type_constructor(
self, name: str, py_function: Optional[Callable[..., Any]] = None
):
# TODO: more specific arg types # TODO: more specific arg types
self._def_function( self._def_function(
name=name, name=name,
@@ -121,7 +128,7 @@ class Preamble(Environment):
kw: list[Param] = [], kw: list[Param] = [],
returns: Type = UnitType(), returns: Type = UnitType(),
type_vars: list[TypeVar] = [], type_vars: list[TypeVar] = [],
py_function: Optional[Callable] = None, py_function: Optional[Callable[..., Any]] = None,
): ):
function: Type = self._make_function( function: Type = self._make_function(
name=name, name=name,
@@ -135,5 +142,5 @@ class Preamble(Environment):
if py_function is not None: if py_function is not None:
self._python_funcs[name] = py_function self._python_funcs[name] = py_function
def get_py_func(self, name: str) -> Optional[Callable]: def get_py_func(self, name: str) -> Optional[Callable[..., Any]]:
return self._python_funcs.get(name) return self._python_funcs.get(name)

View File

@@ -413,13 +413,16 @@ class PythonTyper(
value_type: Type, value_type: Type,
): ):
var_type: Type = self.type_of(var) var_type: Type = self.type_of(var)
unfolded_type: Type = unfold_type(var_type)
# TODO: what happens if type is an alias of a dataframe type # TODO: what happens if type is an alias of a dataframe type
match var_type: match unfolded_type:
case DataFrameType() as frame: case DataFrameType() as frame:
new_type: Type = self.frame_mgr.assign( new_type: Type = self.frame_mgr.assign(
self.reporter, location, frame, index, value_type self.reporter, location, frame, index, value_type
) )
self.env.assign(var.name, new_type) self.env.assign(var.name, new_type)
case UnknownType():
return
case _: case _:
self.reporter.error( self.reporter.error(
location, location,
@@ -582,7 +585,7 @@ class PythonTyper(
object: Type = self.type_of(expr.object) object: Type = self.type_of(expr.object)
member: Optional[Type] = self.types.lookup_member(object, expr.name) member: Optional[Type] = self.types.lookup_member(object, expr.name)
if member is None: if member is None:
self.reporter.error( self.reporter.warning(
expr.location, f"Unknown member '{expr.name}' of {object}" expr.location, f"Unknown member '{expr.name}' of {object}"
) )
return UnknownType() return UnknownType()

View File

@@ -18,6 +18,7 @@ from midas.checker.types import (
OverloadedFunction, OverloadedFunction,
Predicate, Predicate,
TopType, TopType,
TupleType,
Type, Type,
TypeVar, TypeVar,
UnknownType, UnknownType,
@@ -346,6 +347,9 @@ class TypesRegistry:
body=substitute_typevars(body, substitutions), body=substitute_typevars(body, substitutions),
) )
case BaseType(name="tuple"):
return TupleType(items=tuple(args))
case _: case _:
raise ValueError(f"{type} is not a generic type") raise ValueError(f"{type} is not a generic type")

View File

@@ -440,7 +440,11 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
), ),
), ),
] ]
asserts.append(self._make_column_inner_assert(src_location, expr, type)) inner_assert: Optional[ast.stmt] = self._make_column_inner_assert(
src_location, expr, type
)
if inner_assert is not None:
asserts.append(inner_assert)
return asserts return asserts
case ( case (
@@ -592,12 +596,15 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]):
def _make_column_inner_assert( def _make_column_inner_assert(
self, src_location: Location, column: ast.expr, type: ColumnType self, src_location: Location, column: ast.expr, type: ColumnType
) -> ast.stmt: ) -> Optional[ast.stmt]:
# TODO: improve message, maybe chain contexts # TODO: improve message, maybe chain contexts
col: ast.expr = ast.Name(id="col") col: ast.expr = ast.Name(id="col")
body: list[ast.stmt] = self._make_cast_asserts(src_location, col, type.type)
if len(body) == 0:
return None
return ast.For( return ast.For(
target=col, target=col,
iter=column, iter=column,
body=self._make_cast_asserts(src_location, col, type.type), body=body,
orelse=[], orelse=[],
) )

View File

@@ -377,7 +377,7 @@ class MidasParser(Parser):
pos_args: list[Expr] = [] pos_args: list[Expr] = []
kw_args: dict[str, Expr] = {} kw_args: dict[str, Expr] = {}
keywords: bool = False keywords: bool = False
while not self.match(TokenType.RIGHT_PAREN): while not self.check(TokenType.RIGHT_PAREN):
if self.check_identifier() and self.check_next(TokenType.EQUAL): if self.check_identifier() and self.check_next(TokenType.EQUAL):
keywords = True keywords = True
keyword: Token = self.advance() keyword: Token = self.advance()