From d1c217a335a565b93ac6ecb66a5e3ce74b9c6d27 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 25 Jun 2026 22:31:10 +0200 Subject: [PATCH] refactor: use metaclass to collect frame methods --- midas/checker/frame_methods.py | 60 ++++++++++++++++++---------------- midas/checker/frames.py | 4 +-- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/midas/checker/frame_methods.py b/midas/checker/frame_methods.py index 2bbc016..96c868c 100644 --- a/midas/checker/frame_methods.py +++ b/midas/checker/frame_methods.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from midas.ast.location import Location from midas.checker.registry import TypesRegistry @@ -19,6 +19,18 @@ if TYPE_CHECKING: from midas.checker.python import PythonTyper, TypedExpr +@staticmethod +def frame_method(*names: str): + def wrapper(func): + names_: tuple[str, ...] = names + if len(names_) == 0: + names_ = (func.__name__,) + setattr(func, "__method_names__", names_) + return func + + return wrapper + + @dataclass(frozen=True, kw_only=True) class Call: location: Location @@ -27,26 +39,25 @@ class Call: keywords: dict[str, TypedExpr] -def _is_method(obj: object, method: str) -> bool: - if not callable(obj): - return False - if not hasattr(obj, "__method_names__"): - return False - return method in obj.__method_names__ # type: ignore +class _MethodRegistryMeta(type): + _methods: dict[str, Callable] = {} + + def __new__( + cls, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + ): + new_class = super().__new__(cls, name, bases, namespace) + new_class._methods = {} + for attr in namespace.values(): + if callable(attr) and hasattr(attr, "__method_names__"): + for name in attr.__method_names__: # type: ignore + new_class._methods[name] = attr + return new_class -class MethodResolver: - @staticmethod - def frame_method(*names: str): - def wrapper(func): - names_: tuple[str, ...] = names - if len(names_) == 0: - names_ = (func.__name__,) - setattr(func, "__method_names__", names_) - return func - - return wrapper - +class MethodRegistry(metaclass=_MethodRegistryMeta): def __init__(self, typer: PythonTyper) -> None: self.typer: PythonTyper = typer @@ -58,23 +69,16 @@ class MethodResolver: def types(self) -> TypesRegistry: return self.typer.types - def _get_method_by_name(self, method: str) -> Optional[Callable]: - for name in dir(self): - attr = getattr(self, name) - if _is_method(attr, method): - return attr - return None - def call( self, method: str, call: Call, ) -> Type: - func: Optional[Callable] = self._get_method_by_name(method) + func: Optional[Callable] = self._methods.get(method) if func is None: self.reporter.error(call.location, f"Unknown method {method}") return UnknownType() - return func(call) + return func(self, call) @frame_method("add", "__add__") def add( diff --git a/midas/checker/frames.py b/midas/checker/frames.py index dff5f71..da9eb0e 100644 --- a/midas/checker/frames.py +++ b/midas/checker/frames.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast import midas.ast.python as p from midas.ast.location import Location -from midas.checker.frame_methods import Call, MethodResolver +from midas.checker.frame_methods import Call, MethodRegistry from midas.checker.reporter import FileReporter from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType @@ -19,7 +19,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]: class FrameManager: def __init__(self, typer: PythonTyper) -> None: self.typer: PythonTyper = typer - self.method_resolver: MethodResolver = MethodResolver(self.typer) + self.method_resolver: MethodRegistry = MethodRegistry(self.typer) def assign( self,