refactor: add MethodResolver class

This commit is contained in:
2026-06-25 22:14:25 +02:00
parent 894d5a7196
commit 5b3e87afcb
3 changed files with 137 additions and 117 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Protocol
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from midas.ast.location import Location
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import FileReporter
from midas.checker.types import (
ColumnType,
@@ -14,63 +16,86 @@ from midas.checker.types import (
)
if TYPE_CHECKING:
from midas.checker.frames import FrameManager
from midas.checker.python import PythonTyper, TypedExpr
class FrameMethod(Protocol):
@property
def __name__(self) -> str: ...
def __call__(
self,
typer: PythonTyper,
manager: FrameManager,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type: ...
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
frame: DataFrameType
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
FRAME_METHODS: dict[str, FrameMethod] = {}
def _is_method(obj: object, method: str) -> bool:
if not callable(obj):
return False
if not hasattr(obj, "__method_names__"):
return False
return method in obj.__method_names__ # type: ignore
def frame_method(*names: str):
def wrapper(func: FrameMethod):
class MethodResolver:
@staticmethod
def frame_method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
for name in names_:
FRAME_METHODS[name] = func
setattr(func, "__method_names__", names_)
return func
return wrapper
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
@frame_method("add", "__add__")
def add(
typer: PythonTyper,
manager: FrameManager,
reporter: FileReporter,
location: Location,
frame: DataFrameType,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
@property
def reporter(self) -> FileReporter:
return self.typer.reporter
@property
def types(self) -> TypesRegistry:
return self.typer.types
def _get_method_by_name(self, method: str) -> Optional[Callable]:
for name in dir(self):
attr = getattr(self, name)
if _is_method(attr, method):
return attr
return None
def call(
self,
method: str,
call: Call,
) -> Type:
func: Optional[Callable] = self._get_method_by_name(method)
if func is None:
self.reporter.error(call.location, f"Unknown method {method}")
return UnknownType()
return func(call)
@frame_method("add", "__add__")
def add(
self,
call: Call,
) -> Type:
new_columns: list[DataFrameType.Column] = []
by_name: dict[str, DataFrameType.Column] = {}
frame2: Optional[DataFrameType] = None
if len(positional) != 0:
other: Type = positional[0][1]
if len(call.positional) != 0:
other: Type = call.positional[0][1]
unfolded_other: Type = unfold_type(other)
if isinstance(unfolded_other, DataFrameType):
frame2 = unfolded_other
by_name = {col.name: col for col in frame2.columns if col.name is not None}
by_name = {
col.name: col for col in frame2.columns if col.name is not None
}
in_frame1: set[str] = set()
for column in frame.columns:
for column in call.frame.columns:
if column.name is not None:
in_frame1.add(column.name)
@@ -79,7 +104,7 @@ def add(
if column.name in by_name:
column2 = by_name[column.name]
col_type2: Type = column2.type
if manager.types.are_equivalent(col_type2, col_type1):
if self.types.are_equivalent(col_type2, col_type1):
col_type = col_type1
new_column = DataFrameType.Column(
@@ -114,11 +139,11 @@ def add(
)
return (
typer._get_call_result(
location=location,
self.typer._get_call_result(
location=call.location,
callee=signature,
positional=positional,
keywords=keywords,
positional=call.positional,
keywords=call.keywords,
)
or UnknownType()
)

View File

@@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frame_methods import FRAME_METHODS, FrameMethod
from midas.checker.registry import TypesRegistry
from midas.checker.frame_methods import Call, MethodResolver
from midas.checker.reporter import FileReporter
from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType
@@ -18,8 +17,9 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]:
class FrameManager:
def __init__(self, types: TypesRegistry) -> None:
self.types: TypesRegistry = types
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: MethodResolver = MethodResolver(self.typer)
def assign(
self,
@@ -137,15 +137,18 @@ class FrameManager:
) -> list[Optional[ColumnType]]:
return [cls._get_column(frame, name) for name in names]
def call_method(
def call(
self,
typer: PythonTyper,
reporter: FileReporter,
method: str,
location: Location,
frame: DataFrameType,
method: str,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
function: FrameMethod = FRAME_METHODS[method]
return function(typer, self, reporter, location, frame, positional, keywords)
call: Call = Call(
location=location,
frame=frame,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -8,7 +8,6 @@ from midas.ast.location import Location
from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator
from midas.checker.frame_methods import FRAME_METHODS
from midas.checker.frames import FrameManager
from midas.checker.operators import (
PY_COMPARATOR_METHODS,
@@ -76,7 +75,7 @@ class PythonTyper(
self.logger: logging.Logger = logging.getLogger("PythonTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self.types)
self.frame_mgr: FrameManager = FrameManager(self)
self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {}
@@ -527,15 +526,11 @@ class PythonTyper(
case p.GetExpr(object=obj, name=method):
obj_type: Type = self.type_of(obj)
unfolded: Type = unfold_type(obj_type)
if isinstance(unfolded, DataFrameType) and self._is_frame_method(
method
):
return self.frame_mgr.call_method(
self,
self.reporter,
if isinstance(unfolded, DataFrameType):
return self.frame_mgr.call(
method,
expr.location,
unfolded,
method,
positional,
keywords,
)
@@ -1307,6 +1302,3 @@ class PythonTyper(
self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index)
def _is_frame_method(self, method: str) -> bool:
return method in FRAME_METHODS