refactor: restructure frame method registry in submodule
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
|||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.frame_methods import Call, MethodRegistry
|
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||||
from midas.checker.reporter import FileReporter
|
from midas.checker.reporter import FileReporter
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
ColumnGroupBy,
|
ColumnGroupBy,
|
||||||
@@ -27,7 +27,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
|||||||
class FrameManager:
|
class FrameManager:
|
||||||
def __init__(self, typer: PythonTyper) -> None:
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
self.typer: PythonTyper = typer
|
self.typer: PythonTyper = typer
|
||||||
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
|
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||||
|
|
||||||
def assign(
|
def assign(
|
||||||
self,
|
self,
|
||||||
@@ -2,13 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import ast
|
import ast
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
from midas.checker.dispatcher import CallResult
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.frames.utils import MethodRegistry, method
|
||||||
from midas.checker.reporter import FileReporter
|
|
||||||
from midas.checker.types import (
|
from midas.checker.types import (
|
||||||
ColumnType,
|
ColumnType,
|
||||||
DataFrameType,
|
DataFrameType,
|
||||||
@@ -20,22 +19,9 @@ from midas.checker.types import (
|
|||||||
UnknownType,
|
UnknownType,
|
||||||
unfold_type,
|
unfold_type,
|
||||||
)
|
)
|
||||||
from midas.generator.collector import AssertionCollector
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from midas.checker.python import PythonTyper, TypedExpr
|
from midas.checker.python import TypedExpr
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def frame_method(*names: str):
|
|
||||||
def wrapper(func):
|
|
||||||
names_: tuple[str, ...] = names
|
|
||||||
if len(names_) == 0:
|
|
||||||
names_ = (func.__name__,)
|
|
||||||
setattr(func, "__method_names__", names_)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
@@ -48,60 +34,16 @@ class Call:
|
|||||||
keywords: dict[str, TypedExpr]
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
|
||||||
class _MethodRegistryMeta(type):
|
class FrameMethodRegistry(MethodRegistry):
|
||||||
_methods: dict[str, Callable[..., Type]] = {}
|
def call(self, method: str, call: Call) -> Type:
|
||||||
|
|
||||||
def __new__(
|
|
||||||
cls,
|
|
||||||
name: str,
|
|
||||||
bases: tuple[type, ...],
|
|
||||||
namespace: dict[str, Any],
|
|
||||||
):
|
|
||||||
new_class = super().__new__(cls, name, bases, namespace)
|
|
||||||
new_class._methods = {}
|
|
||||||
for attr in namespace.values():
|
|
||||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
|
||||||
for name in attr.__method_names__: # type: ignore
|
|
||||||
new_class._methods[name] = attr # type: ignore
|
|
||||||
return new_class
|
|
||||||
|
|
||||||
|
|
||||||
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
|
||||||
def __init__(self, typer: PythonTyper) -> None:
|
|
||||||
self.typer: PythonTyper = typer
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reporter(self) -> FileReporter:
|
|
||||||
return self.typer.reporter
|
|
||||||
|
|
||||||
@property
|
|
||||||
def types(self) -> TypesRegistry:
|
|
||||||
return self.typer.types
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
|
||||||
return self.typer.dispatcher
|
|
||||||
|
|
||||||
@property
|
|
||||||
def assertions(self) -> AssertionCollector:
|
|
||||||
return self.typer.assertions
|
|
||||||
|
|
||||||
def call(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
call: Call,
|
|
||||||
) -> Type:
|
|
||||||
func: Optional[Callable[..., Type]] = self._methods.get(method)
|
func: Optional[Callable[..., Type]] = self._methods.get(method)
|
||||||
if func is None:
|
if func is None:
|
||||||
self.reporter.warning(call.location, f"Unknown method {method}")
|
self.reporter.warning(call.location, f"Unknown method {method}")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
return func(self, call)
|
return func(self, call)
|
||||||
|
|
||||||
@frame_method("add", "__add__")
|
@method("add", "__add__")
|
||||||
def add(
|
def add(self, call: Call) -> Type:
|
||||||
self,
|
|
||||||
call: Call,
|
|
||||||
) -> Type:
|
|
||||||
# TODO: support add with scalar, sequence, Series, dict
|
# TODO: support add with scalar, sequence, Series, dict
|
||||||
# TODO: check operation exists on inner column types
|
# TODO: check operation exists on inner column types
|
||||||
|
|
||||||
@@ -184,7 +126,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
|||||||
|
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
@frame_method()
|
@method()
|
||||||
def mean(self, call: Call) -> Type:
|
def mean(self, call: Call) -> Type:
|
||||||
with_axis = Function(
|
with_axis = Function(
|
||||||
kw_args=[
|
kw_args=[
|
||||||
@@ -223,7 +165,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
|||||||
)
|
)
|
||||||
return result.result
|
return result.result
|
||||||
|
|
||||||
@frame_method()
|
@method()
|
||||||
def groupby(self, call: Call) -> Type:
|
def groupby(self, call: Call) -> Type:
|
||||||
bool_: Type = self.types.get_type("bool")
|
bool_: Type = self.types.get_type("bool")
|
||||||
function: Function = Function(
|
function: Function = Function(
|
||||||
64
midas/checker/frames/utils.py
Normal file
64
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
|
import midas.ast.python as p
|
||||||
|
from midas.checker.dispatcher import CallDispatcher
|
||||||
|
from midas.checker.registry import TypesRegistry
|
||||||
|
from midas.checker.reporter import FileReporter
|
||||||
|
from midas.checker.types import Type
|
||||||
|
from midas.generator.collector import AssertionCollector
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from midas.checker.python import PythonTyper
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def method(*names: str):
|
||||||
|
def wrapper(func):
|
||||||
|
names_: tuple[str, ...] = names
|
||||||
|
if len(names_) == 0:
|
||||||
|
names_ = (func.__name__,)
|
||||||
|
setattr(func, "__method_names__", names_)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class _MethodRegistryMeta(type):
|
||||||
|
_methods: dict[str, Callable[..., Type]] = {}
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
bases: tuple[type, ...],
|
||||||
|
namespace: dict[str, Any],
|
||||||
|
):
|
||||||
|
new_class = super().__new__(cls, name, bases, namespace)
|
||||||
|
new_class._methods = {}
|
||||||
|
for attr in namespace.values():
|
||||||
|
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||||
|
for name in attr.__method_names__: # type: ignore
|
||||||
|
new_class._methods[name] = attr # type: ignore
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
|
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||||
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
|
self.typer: PythonTyper = typer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reporter(self) -> FileReporter:
|
||||||
|
return self.typer.reporter
|
||||||
|
|
||||||
|
@property
|
||||||
|
def types(self) -> TypesRegistry:
|
||||||
|
return self.typer.types
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||||
|
return self.typer.dispatcher
|
||||||
|
|
||||||
|
@property
|
||||||
|
def assertions(self) -> AssertionCollector:
|
||||||
|
return self.typer.assertions
|
||||||
@@ -9,7 +9,7 @@ from midas.ast.printer import MidasPrinter
|
|||||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.evaluator import Evaluator
|
from midas.checker.evaluator import Evaluator
|
||||||
from midas.checker.frames import FrameManager
|
from midas.checker.frames.frame_manager import FrameManager
|
||||||
from midas.checker.operators import (
|
from midas.checker.operators import (
|
||||||
PY_COMPARATOR_METHODS,
|
PY_COMPARATOR_METHODS,
|
||||||
PY_OPERATOR_METHODS,
|
PY_OPERATOR_METHODS,
|
||||||
|
|||||||
Reference in New Issue
Block a user