refactor: use metaclass to collect frame methods
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||||
|
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.registry import TypesRegistry
|
from midas.checker.registry import TypesRegistry
|
||||||
@@ -19,25 +19,8 @@ if TYPE_CHECKING:
|
|||||||
from midas.checker.python import PythonTyper, TypedExpr
|
from midas.checker.python import PythonTyper, TypedExpr
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@staticmethod
|
||||||
class Call:
|
def frame_method(*names: str):
|
||||||
location: Location
|
|
||||||
frame: DataFrameType
|
|
||||||
positional: list[TypedExpr]
|
|
||||||
keywords: dict[str, TypedExpr]
|
|
||||||
|
|
||||||
|
|
||||||
def _is_method(obj: object, method: str) -> bool:
|
|
||||||
if not callable(obj):
|
|
||||||
return False
|
|
||||||
if not hasattr(obj, "__method_names__"):
|
|
||||||
return False
|
|
||||||
return method in obj.__method_names__ # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class MethodResolver:
|
|
||||||
@staticmethod
|
|
||||||
def frame_method(*names: str):
|
|
||||||
def wrapper(func):
|
def wrapper(func):
|
||||||
names_: tuple[str, ...] = names
|
names_: tuple[str, ...] = names
|
||||||
if len(names_) == 0:
|
if len(names_) == 0:
|
||||||
@@ -47,6 +30,34 @@ class MethodResolver:
|
|||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class Call:
|
||||||
|
location: Location
|
||||||
|
frame: DataFrameType
|
||||||
|
positional: list[TypedExpr]
|
||||||
|
keywords: dict[str, TypedExpr]
|
||||||
|
|
||||||
|
|
||||||
|
class _MethodRegistryMeta(type):
|
||||||
|
_methods: dict[str, Callable] = {}
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
bases: tuple[type, ...],
|
||||||
|
namespace: dict[str, Any],
|
||||||
|
):
|
||||||
|
new_class = super().__new__(cls, name, bases, namespace)
|
||||||
|
new_class._methods = {}
|
||||||
|
for attr in namespace.values():
|
||||||
|
if callable(attr) and hasattr(attr, "__method_names__"):
|
||||||
|
for name in attr.__method_names__: # type: ignore
|
||||||
|
new_class._methods[name] = attr
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
|
class MethodRegistry(metaclass=_MethodRegistryMeta):
|
||||||
def __init__(self, typer: PythonTyper) -> None:
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
self.typer: PythonTyper = typer
|
self.typer: PythonTyper = typer
|
||||||
|
|
||||||
@@ -58,23 +69,16 @@ class MethodResolver:
|
|||||||
def types(self) -> TypesRegistry:
|
def types(self) -> TypesRegistry:
|
||||||
return self.typer.types
|
return self.typer.types
|
||||||
|
|
||||||
def _get_method_by_name(self, method: str) -> Optional[Callable]:
|
|
||||||
for name in dir(self):
|
|
||||||
attr = getattr(self, name)
|
|
||||||
if _is_method(attr, method):
|
|
||||||
return attr
|
|
||||||
return None
|
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
call: Call,
|
call: Call,
|
||||||
) -> Type:
|
) -> Type:
|
||||||
func: Optional[Callable] = self._get_method_by_name(method)
|
func: Optional[Callable] = self._methods.get(method)
|
||||||
if func is None:
|
if func is None:
|
||||||
self.reporter.error(call.location, f"Unknown method {method}")
|
self.reporter.error(call.location, f"Unknown method {method}")
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
return func(call)
|
return func(self, call)
|
||||||
|
|
||||||
@frame_method("add", "__add__")
|
@frame_method("add", "__add__")
|
||||||
def add(
|
def add(
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
|||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.checker.frame_methods import Call, MethodResolver
|
from midas.checker.frame_methods import Call, MethodRegistry
|
||||||
from midas.checker.reporter import FileReporter
|
from midas.checker.reporter import FileReporter
|
||||||
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
|||||||
class FrameManager:
|
class FrameManager:
|
||||||
def __init__(self, typer: PythonTyper) -> None:
|
def __init__(self, typer: PythonTyper) -> None:
|
||||||
self.typer: PythonTyper = typer
|
self.typer: PythonTyper = typer
|
||||||
self.method_resolver: MethodResolver = MethodResolver(self.typer)
|
self.method_resolver: MethodRegistry = MethodRegistry(self.typer)
|
||||||
|
|
||||||
def assign(
|
def assign(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user