feat(checker): handle calls on group-bys
This commit is contained in:
66
midas/checker/frames/column_groupby_methods.py
Normal file
66
midas/checker/frames/column_groupby_methods.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import ColumnGroupBy, Function, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: ColumnGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
class ColumnGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.column,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.column_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.column_groupby_methods import ColumnGroupByMethodRegistry
|
||||
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
|
||||
from midas.checker.types import ColumnType, Type
|
||||
from midas.checker.types import ColumnGroupBy, ColumnType, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
@@ -15,6 +17,9 @@ class ColumnManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: ColumnGroupByMethodRegistry = (
|
||||
ColumnGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
@@ -35,3 +40,23 @@ class ColumnManager:
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: ColumnGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
66
midas/checker/frames/frame_groupby_methods.py
Normal file
66
midas/checker/frames/frame_groupby_methods.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import FrameGroupBy, Function, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
call_expr: p.Expr
|
||||
groupby: FrameGroupBy
|
||||
groupby_expr: p.Expr
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
class FrameGroupByMethodRegistry(MethodRegistry[Call]):
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
bool_ = self.types.get_type("bool")
|
||||
signature = Function(
|
||||
args=[
|
||||
Function.Argument(
|
||||
pos=0,
|
||||
name="numeric_only",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=1,
|
||||
name="skipna",
|
||||
type=bool_,
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=2,
|
||||
name="engine",
|
||||
type=self.types.get_type("str"),
|
||||
required=False,
|
||||
),
|
||||
Function.Argument(
|
||||
pos=3,
|
||||
name="engine_kwargs",
|
||||
type=self.types.get_type("dict"),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
returns=call.groupby.frame,
|
||||
)
|
||||
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
return result.result
|
||||
@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frames.frame_groupby_methods import Call as GroupByCall
|
||||
from midas.checker.frames.frame_groupby_methods import FrameGroupByMethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
@@ -28,6 +30,9 @@ class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||
self.groupby_method_resolver: FrameGroupByMethodRegistry = (
|
||||
FrameGroupByMethodRegistry(self.typer)
|
||||
)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
@@ -184,3 +189,23 @@ class FrameManager:
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
def groupby_call(
|
||||
self,
|
||||
method: str,
|
||||
location: Location,
|
||||
call_expr: p.Expr,
|
||||
groupby: FrameGroupBy,
|
||||
groupby_expr: p.Expr,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
call: GroupByCall = GroupByCall(
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=groupby,
|
||||
groupby_expr=groupby_expr,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.groupby_method_resolver.call(method, call)
|
||||
|
||||
@@ -23,6 +23,7 @@ from midas.checker.resolver import Resolver
|
||||
from midas.checker.types import (
|
||||
AppliedType,
|
||||
BaseType,
|
||||
ColumnGroupBy,
|
||||
ColumnType,
|
||||
ConstraintType,
|
||||
DataFrameType,
|
||||
@@ -235,6 +236,17 @@ class PythonTyper(
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
case FrameGroupBy():
|
||||
return self.frame_mgr.groupby_call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=unfolded,
|
||||
groupby_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
case ColumnType():
|
||||
return self.column_mgr.call(
|
||||
method=method_name,
|
||||
@@ -246,6 +258,17 @@ class PythonTyper(
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
case ColumnGroupBy():
|
||||
return self.column_mgr.groupby_call(
|
||||
method=method_name,
|
||||
location=location,
|
||||
call_expr=call_expr,
|
||||
groupby=unfolded,
|
||||
groupby_expr=obj[0],
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
|
||||
if method is None:
|
||||
raise UndefinedMethodException
|
||||
@@ -612,17 +635,17 @@ class PythonTyper(
|
||||
match expr.callee:
|
||||
case p.GetExpr(object=obj, name=method):
|
||||
obj_type: Type = self.type_of(obj)
|
||||
unfolded: Type = unfold_type(obj_type)
|
||||
if isinstance(unfolded, DataFrameType):
|
||||
return self.frame_mgr.call(
|
||||
method=method,
|
||||
return (
|
||||
self.call_method(
|
||||
location=expr.location,
|
||||
call_expr=expr,
|
||||
frame=unfolded,
|
||||
frame_expr=obj,
|
||||
obj=(obj, obj_type),
|
||||
method_name=method,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
)
|
||||
|
||||
callee: Type = self.type_of(expr.callee)
|
||||
result: CallResult = self.dispatcher.get_result(
|
||||
|
||||
Reference in New Issue
Block a user