refactor: restructure frame method registry in submodule
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frame_methods import Call, MethodRegistry
|
||||
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnGroupBy,
|
||||
@@ -27,7 +27,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||
class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
|
||||
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
@@ -2,13 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.dispatcher import CallResult
|
||||
from midas.checker.frames.utils import MethodRegistry, method
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
DataFrameType,
|
||||
@@ -20,22 +19,9 @@ from midas.checker.types import (
|
||||
UnknownType,
|
||||
unfold_type,
|
||||
)
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
@staticmethod
|
||||
def frame_method(*names: str):
|
||||
def wrapper(func):
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
from midas.checker.python import TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -48,60 +34,16 @@ class Call:
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable[..., Type]] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[name] = attr # type: ignore
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
class FrameMethodRegistry(MethodRegistry):
|
||||
def call(self, method: str, call: Call) -> Type:
|
||||
func: Optional[Callable[..., Type]] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.warning(call.location, f"Unknown method {method}")
|
||||
return UnknownType()
|
||||
return func(self, call)
|
||||
|
||||
@frame_method("add", "__add__")
|
||||
def add(
|
||||
self,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
@method("add", "__add__")
|
||||
def add(self, call: Call) -> Type:
|
||||
# TODO: support add with scalar, sequence, Series, dict
|
||||
# TODO: check operation exists on inner column types
|
||||
|
||||
@@ -184,7 +126,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
|
||||
return result.result
|
||||
|
||||
@frame_method()
|
||||
@method()
|
||||
def mean(self, call: Call) -> Type:
|
||||
with_axis = Function(
|
||||
kw_args=[
|
||||
@@ -223,7 +165,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
)
|
||||
return result.result
|
||||
|
||||
@frame_method()
|
||||
@method()
|
||||
def groupby(self, call: Call) -> Type:
|
||||
bool_: Type = self.types.get_type("bool")
|
||||
function: Function = Function(
|
||||
64
midas/checker/frames/utils.py
Normal file
64
midas/checker/frames/utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.checker.dispatcher import CallDispatcher
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import Type
|
||||
from midas.generator.collector import AssertionCollector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper
|
||||
|
||||
|
||||
@staticmethod
|
||||
def method(*names: str):
|
||||
def wrapper(func):
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable[..., Type]] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[name] = attr # type: ignore
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
@property
|
||||
def dispatcher(self) -> CallDispatcher[p.Expr]:
|
||||
return self.typer.dispatcher
|
||||
|
||||
@property
|
||||
def assertions(self) -> AssertionCollector:
|
||||
return self.typer.assertions
|
||||
@@ -9,7 +9,7 @@ from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.dispatcher import CallDispatcher, CallResult
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.evaluator import Evaluator
|
||||
from midas.checker.frames import FrameManager
|
||||
from midas.checker.frames.frame_manager import FrameManager
|
||||
from midas.checker.operators import (
|
||||
PY_COMPARATOR_METHODS,
|
||||
PY_OPERATOR_METHODS,
|
||||
|
||||
Reference in New Issue
Block a user