refactor: use metaclass to collect frame methods
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.registry import TypesRegistry
|
||||
@@ -19,25 +19,8 @@ if TYPE_CHECKING:
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
frame: DataFrameType
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
def _is_method(obj: object, method: str) -> bool:
|
||||
if not callable(obj):
|
||||
return False
|
||||
if not hasattr(obj, "__method_names__"):
|
||||
return False
|
||||
return method in obj.__method_names__ # type: ignore
|
||||
|
||||
|
||||
class MethodResolver:
|
||||
@staticmethod
|
||||
def frame_method(*names: str):
|
||||
@staticmethod
|
||||
def frame_method(*names: str):
|
||||
def wrapper(func):
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
@@ -47,6 +30,34 @@ class MethodResolver:
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
frame: DataFrameType
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
class _MethodRegistryMeta(type):
|
||||
_methods: dict[str, Callable] = {}
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
):
|
||||
new_class = super().__new__(cls, name, bases, namespace)
|
||||
new_class._methods = {}
|
||||
for attr in namespace.values():
|
||||
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||
for name in attr.__method_names__: # type: ignore
|
||||
new_class._methods[name] = attr
|
||||
return new_class
|
||||
|
||||
|
||||
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@@ -58,23 +69,16 @@ class MethodResolver:
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
def _get_method_by_name(self, method: str) -> Optional[Callable]:
|
||||
for name in dir(self):
|
||||
attr = getattr(self, name)
|
||||
if _is_method(attr, method):
|
||||
return attr
|
||||
return None
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
func: Optional[Callable] = self._get_method_by_name(method)
|
||||
func: Optional[Callable] = self._methods.get(method)
|
||||
if func is None:
|
||||
self.reporter.error(call.location, f"Unknown method {method}")
|
||||
return UnknownType()
|
||||
return func(call)
|
||||
return func(self, call)
|
||||
|
||||
@frame_method("add", "__add__")
|
||||
def add(
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frame_methods import Call, MethodResolver
|
||||
from midas.checker.frame_methods import Call, MethodRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||
|
||||
@@ -19,7 +19,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||
class FrameManager:
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: MethodResolver = MethodResolver(self.typer)
|
||||
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user