refactor: use metaclass to collect frame methods

This commit is contained in:
2026-06-25 22:31:10 +02:00
parent 5b3e87afcb
commit d1c217a335
2 changed files with 34 additions and 30 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
@@ -19,6 +19,18 @@ if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
@@ -27,26 +39,25 @@ class Call:
keywords: dict[str, TypedExpr]
def _is_method(obj: object, method: str) -> bool:
if not callable(obj):
return False
if not hasattr(obj, "__method_names__"):
return False
return method in obj.__method_names__ # type: ignore
class _MethodRegistryMeta(type):
_methods: dict[str, Callable] = {}
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
):
new_class = super().__new__(cls, name, bases, namespace)
new_class._methods = {}
for attr in namespace.values():
if callable(attr) and hasattr(attr, "__method_names__"):
for name in attr.__method_names__: # type: ignore
new_class._methods[name] = attr
return new_class
class MethodResolver:
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
class MethodRegistry(metaclass=_MethodRegistryMeta):
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@@ -58,23 +69,16 @@ class MethodResolver:
def types(self) -> TypesRegistry:
return self.typer.types
def _get_method_by_name(self, method: str) -> Optional[Callable]:
for name in dir(self):
attr = getattr(self, name)
if _is_method(attr, method):
return attr
return None
def call(
self,
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._get_method_by_name(method)
func: Optional[Callable] = self._methods.get(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
return UnknownType()
return func(call)
return func(self, call)
@frame_method("add", "__add__")
def add(

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import Call, MethodResolver
from midas.checker.frame_methods import Call, MethodRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
@@ -19,7 +19,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
class FrameManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodResolver = MethodResolver(self.typer)
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
def assign(
self,