refactor: add MethodResolver class
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Protocol
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import (
|
||||
ColumnType,
|
||||
@@ -14,63 +16,86 @@ from midas.checker.types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from midas.checker.frames import FrameManager
|
||||
from midas.checker.python import PythonTyper, TypedExpr
|
||||
|
||||
|
||||
class FrameMethod(Protocol):
|
||||
@property
|
||||
def __name__(self) -> str: ...
|
||||
def __call__(
|
||||
self,
|
||||
typer: PythonTyper,
|
||||
manager: FrameManager,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type: ...
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Call:
|
||||
location: Location
|
||||
frame: DataFrameType
|
||||
positional: list[TypedExpr]
|
||||
keywords: dict[str, TypedExpr]
|
||||
|
||||
|
||||
FRAME_METHODS: dict[str, FrameMethod] = {}
|
||||
def _is_method(obj: object, method: str) -> bool:
|
||||
if not callable(obj):
|
||||
return False
|
||||
if not hasattr(obj, "__method_names__"):
|
||||
return False
|
||||
return method in obj.__method_names__ # type: ignore
|
||||
|
||||
|
||||
def frame_method(*names: str):
|
||||
def wrapper(func: FrameMethod):
|
||||
class MethodResolver:
|
||||
@staticmethod
|
||||
def frame_method(*names: str):
|
||||
def wrapper(func):
|
||||
names_: tuple[str, ...] = names
|
||||
if len(names_) == 0:
|
||||
names_ = (func.__name__,)
|
||||
for name in names_:
|
||||
FRAME_METHODS[name] = func
|
||||
setattr(func, "__method_names__", names_)
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
|
||||
@frame_method("add", "__add__")
|
||||
def add(
|
||||
typer: PythonTyper,
|
||||
manager: FrameManager,
|
||||
reporter: FileReporter,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
@property
|
||||
def reporter(self) -> FileReporter:
|
||||
return self.typer.reporter
|
||||
|
||||
@property
|
||||
def types(self) -> TypesRegistry:
|
||||
return self.typer.types
|
||||
|
||||
def _get_method_by_name(self, method: str) -> Optional[Callable]:
|
||||
for name in dir(self):
|
||||
attr = getattr(self, name)
|
||||
if _is_method(attr, method):
|
||||
return attr
|
||||
return None
|
||||
|
||||
def call(
|
||||
self,
|
||||
method: str,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
func: Optional[Callable] = self._get_method_by_name(method)
|
||||
if func is None:
|
||||
self.reporter.error(call.location, f"Unknown method {method}")
|
||||
return UnknownType()
|
||||
return func(call)
|
||||
|
||||
@frame_method("add", "__add__")
|
||||
def add(
|
||||
self,
|
||||
call: Call,
|
||||
) -> Type:
|
||||
new_columns: list[DataFrameType.Column] = []
|
||||
|
||||
by_name: dict[str, DataFrameType.Column] = {}
|
||||
frame2: Optional[DataFrameType] = None
|
||||
if len(positional) != 0:
|
||||
other: Type = positional[0][1]
|
||||
if len(call.positional) != 0:
|
||||
other: Type = call.positional[0][1]
|
||||
unfolded_other: Type = unfold_type(other)
|
||||
if isinstance(unfolded_other, DataFrameType):
|
||||
frame2 = unfolded_other
|
||||
by_name = {col.name: col for col in frame2.columns if col.name is not None}
|
||||
by_name = {
|
||||
col.name: col for col in frame2.columns if col.name is not None
|
||||
}
|
||||
|
||||
in_frame1: set[str] = set()
|
||||
for column in frame.columns:
|
||||
for column in call.frame.columns:
|
||||
if column.name is not None:
|
||||
in_frame1.add(column.name)
|
||||
|
||||
@@ -79,7 +104,7 @@ def add(
|
||||
if column.name in by_name:
|
||||
column2 = by_name[column.name]
|
||||
col_type2: Type = column2.type
|
||||
if manager.types.are_equivalent(col_type2, col_type1):
|
||||
if self.types.are_equivalent(col_type2, col_type1):
|
||||
col_type = col_type1
|
||||
|
||||
new_column = DataFrameType.Column(
|
||||
@@ -114,11 +139,11 @@ def add(
|
||||
)
|
||||
|
||||
return (
|
||||
typer._get_call_result(
|
||||
location=location,
|
||||
self.typer._get_call_result(
|
||||
location=call.location,
|
||||
callee=signature,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
positional=call.positional,
|
||||
keywords=call.keywords,
|
||||
)
|
||||
or UnknownType()
|
||||
)
|
||||
|
||||
@@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
|
||||
|
||||
import midas.ast.python as p
|
||||
from midas.ast.location import Location
|
||||
from midas.checker.frame_methods import FRAME_METHODS, FrameMethod
|
||||
from midas.checker.registry import TypesRegistry
|
||||
from midas.checker.frame_methods import Call, MethodResolver
|
||||
from midas.checker.reporter import FileReporter
|
||||
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
|
||||
|
||||
@@ -18,8 +17,9 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
|
||||
|
||||
|
||||
class FrameManager:
|
||||
def __init__(self, types: TypesRegistry) -> None:
|
||||
self.types: TypesRegistry = types
|
||||
def __init__(self, typer: PythonTyper) -> None:
|
||||
self.typer: PythonTyper = typer
|
||||
self.method_resolver: MethodResolver = MethodResolver(self.typer)
|
||||
|
||||
def assign(
|
||||
self,
|
||||
@@ -137,15 +137,18 @@ class FrameManager:
|
||||
) -> list[Optional[ColumnType]]:
|
||||
return [cls._get_column(frame, name) for name in names]
|
||||
|
||||
def call_method(
|
||||
def call(
|
||||
self,
|
||||
typer: PythonTyper,
|
||||
reporter: FileReporter,
|
||||
method: str,
|
||||
location: Location,
|
||||
frame: DataFrameType,
|
||||
method: str,
|
||||
positional: list[TypedExpr],
|
||||
keywords: dict[str, TypedExpr],
|
||||
) -> Type:
|
||||
function: FrameMethod = FRAME_METHODS[method]
|
||||
return function(typer, self, reporter, location, frame, positional, keywords)
|
||||
call: Call = Call(
|
||||
location=location,
|
||||
frame=frame,
|
||||
positional=positional,
|
||||
keywords=keywords,
|
||||
)
|
||||
return self.method_resolver.call(method, call)
|
||||
|
||||
@@ -8,7 +8,6 @@ from midas.ast.location import Location
|
||||
from midas.ast.printer import MidasPrinter
|
||||
from midas.checker.environment import Environment
|
||||
from midas.checker.evaluator import Evaluator
|
||||
from midas.checker.frame_methods import FRAME_METHODS
|
||||
from midas.checker.frames import FrameManager
|
||||
from midas.checker.operators import (
|
||||
PY_COMPARATOR_METHODS,
|
||||
@@ -76,7 +75,7 @@ class PythonTyper(
|
||||
self.logger: logging.Logger = logging.getLogger("PythonTyper")
|
||||
self.reporter: FileReporter = reporter.for_file(None)
|
||||
self.types: TypesRegistry = types
|
||||
self.frame_mgr: FrameManager = FrameManager(self.types)
|
||||
self.frame_mgr: FrameManager = FrameManager(self)
|
||||
self.global_env: Environment = Preamble(self.types)
|
||||
self.env: Environment = self.global_env
|
||||
self.locals: dict[p.Expr, int] = {}
|
||||
@@ -527,15 +526,11 @@ class PythonTyper(
|
||||
case p.GetExpr(object=obj, name=method):
|
||||
obj_type: Type = self.type_of(obj)
|
||||
unfolded: Type = unfold_type(obj_type)
|
||||
if isinstance(unfolded, DataFrameType) and self._is_frame_method(
|
||||
method
|
||||
):
|
||||
return self.frame_mgr.call_method(
|
||||
self,
|
||||
self.reporter,
|
||||
if isinstance(unfolded, DataFrameType):
|
||||
return self.frame_mgr.call(
|
||||
method,
|
||||
expr.location,
|
||||
unfolded,
|
||||
method,
|
||||
positional,
|
||||
keywords,
|
||||
)
|
||||
@@ -1307,6 +1302,3 @@ class PythonTyper(
|
||||
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||
) -> Type:
|
||||
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
|
||||
|
||||
def _is_frame_method(self, method: str) -> bool:
|
||||
return method in FRAME_METHODS
|
||||
|
||||
Reference in New Issue
Block a user