refactor: restructure frame method registry in submodule

This commit is contained in:
2026-07-02 18:19:41 +02:00
parent 3c92f0867d
commit 0d5840a4ce
4 changed files with 77 additions and 71 deletions

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.frames.frame_methods import Call, FrameMethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnGroupBy,
@@ -27,7 +27,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
self.method_resolver: FrameMethodRegistry = FrameMethodRegistry(self.typer)
def assign(
self,

View File

@@ -2,13 +2,12 @@ from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.dispatcher import CallResult
from midas.checker.frames.utils import MethodRegistry, method
from midas.checker.types import (
ColumnType,
DataFrameType,
@@ -20,22 +19,9 @@ from midas.checker.types import (
UnknownType,
unfold_type,
)
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
@@ -48,60 +34,16 @@ class Call:
keywords: dict[str, TypedExpr]
class _MethodRegistryMeta(type):
_methods: dict[str, Callable[..., Type]] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr # type: ignore
return new_class
class MethodRegistry(metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
@property
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions
def call(
self,
method: str,
call: Call,
) -> Type:
class FrameMethodRegistry(MethodRegistry):
def call(self, method: str, call: Call) -> Type:
func: Optional[Callable[..., Type]] = self._methods.get(method)
if func is None:
self.reporter.warning(call.location, f"Unknown method {method}")
return UnknownType()
return func(self, call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
@method("add", "__add__")
def add(self, call: Call) -> Type:
# TODO: support add with scalar, sequence, Series, dict
# TODO: check operation exists on inner column types
@@ -184,7 +126,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
return result.result
@frame_method()
@method()
def mean(self, call: Call) -> Type:
with_axis = Function(
kw_args=[
@@ -223,7 +165,7 @@ class MethodRegistry(metaclass=_MethodRegistryMeta):
)
return result.result
@frame_method()
@method()
def groupby(self, call: Call) -> Type:
bool_: Type = self.types.get_type("bool")
function: Function = Function(

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
import midas.ast.python as p
from midas.checker.dispatcher import CallDispatcher
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import Type
from midas.generator.collector import AssertionCollector
if TYPE_CHECKING:
from midas.checker.python import PythonTyper
@staticmethod
def method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
class _MethodRegistryMeta(type):
_methods: dict[str, Callable[..., Type]] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr # type: ignore
return new_class
class MethodRegistry(metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
@property
def dispatcher(self) -> CallDispatcher[p.Expr]:
return self.typer.dispatcher
@property
def assertions(self) -> AssertionCollector:
return self.typer.assertions

View File

@@ -9,7 +9,7 @@ from midas.ast.printer import MidasPrinter
from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.frames import FrameManager
from midas.checker.frames.frame_manager import FrameManager
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS,