From 5b3e87afcb0fdf54ec3e9f27f576dd4959565e86 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Thu, 25 Jun 2026 22:14:25 +0200 Subject: [PATCH] refactor: add MethodResolver class --- midas/checker/frame_methods.py | 215 ++++++++++++++++++--------------- midas/checker/frames.py | 23 ++-- midas/checker/python.py | 16 +-- 3 files changed, 137 insertions(+), 117 deletions(-) diff --git a/midas/checker/frame_methods.py b/midas/checker/frame_methods.py index 60c6eb1..2bbc016 100644 --- a/midas/checker/frame_methods.py +++ b/midas/checker/frame_methods.py @@ -1,8 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Protocol +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional from midas.ast.location import Location +from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter from midas.checker.types import ( ColumnType, @@ -14,111 +16,134 @@ from midas.checker.types import ( ) if TYPE_CHECKING: - from midas.checker.frames import FrameManager from midas.checker.python import PythonTyper, TypedExpr -class FrameMethod(Protocol): +@dataclass(frozen=True, kw_only=True) +class Call: + location: Location + frame: DataFrameType + positional: list[TypedExpr] + keywords: dict[str, TypedExpr] + + +def _is_method(obj: object, method: str) -> bool: + if not callable(obj): + return False + if not hasattr(obj, "__method_names__"): + return False + return method in obj.__method_names__ # type: ignore + + +class MethodResolver: + @staticmethod + def frame_method(*names: str): + def wrapper(func): + names_: tuple[str, ...] = names + if len(names_) == 0: + names_ = (func.__name__,) + setattr(func, "__method_names__", names_) + return func + + return wrapper + + def __init__(self, typer: PythonTyper) -> None: + self.typer: PythonTyper = typer + @property - def __name__(self) -> str: ... - def __call__( + def reporter(self) -> FileReporter: + return self.typer.reporter + + @property + def types(self) -> TypesRegistry: + return self.typer.types + + def _get_method_by_name(self, method: str) -> Optional[Callable]: + for name in dir(self): + attr = getattr(self, name) + if _is_method(attr, method): + return attr + return None + + def call( self, - typer: PythonTyper, - manager: FrameManager, - reporter: FileReporter, - location: Location, - frame: DataFrameType, - positional: list[TypedExpr], - keywords: dict[str, TypedExpr], - ) -> Type: ... + method: str, + call: Call, + ) -> Type: + func: Optional[Callable] = self._get_method_by_name(method) + if func is None: + self.reporter.error(call.location, f"Unknown method {method}") + return UnknownType() + return func(call) + @frame_method("add", "__add__") + def add( + self, + call: Call, + ) -> Type: + new_columns: list[DataFrameType.Column] = [] -FRAME_METHODS: dict[str, FrameMethod] = {} + by_name: dict[str, DataFrameType.Column] = {} + frame2: Optional[DataFrameType] = None + if len(call.positional) != 0: + other: Type = call.positional[0][1] + unfolded_other: Type = unfold_type(other) + if isinstance(unfolded_other, DataFrameType): + frame2 = unfolded_other + by_name = { + col.name: col for col in frame2.columns if col.name is not None + } + in_frame1: set[str] = set() + for column in call.frame.columns: + if column.name is not None: + in_frame1.add(column.name) -def frame_method(*names: str): - def wrapper(func: FrameMethod): - names_: tuple[str, ...] = names - if len(names_) == 0: - names_ = (func.__name__,) - for name in names_: - FRAME_METHODS[name] = func - return func + col_type1: Type = column.type + col_type: Type = ColumnType(type=UnknownType()) + if column.name in by_name: + column2 = by_name[column.name] + col_type2: Type = column2.type + if self.types.are_equivalent(col_type2, col_type1): + col_type = col_type1 - return wrapper - - -@frame_method("add", "__add__") -def add( - typer: PythonTyper, - manager: FrameManager, - reporter: FileReporter, - location: Location, - frame: DataFrameType, - positional: list[TypedExpr], - keywords: dict[str, TypedExpr], -) -> Type: - new_columns: list[DataFrameType.Column] = [] - - by_name: dict[str, DataFrameType.Column] = {} - frame2: Optional[DataFrameType] = None - if len(positional) != 0: - other: Type = positional[0][1] - unfolded_other: Type = unfold_type(other) - if isinstance(unfolded_other, DataFrameType): - frame2 = unfolded_other - by_name = {col.name: col for col in frame2.columns if col.name is not None} - - in_frame1: set[str] = set() - for column in frame.columns: - if column.name is not None: - in_frame1.add(column.name) - - col_type1: Type = column.type - col_type: Type = ColumnType(type=UnknownType()) - if column.name in by_name: - column2 = by_name[column.name] - col_type2: Type = column2.type - if manager.types.are_equivalent(col_type2, col_type1): - col_type = col_type1 - - new_column = DataFrameType.Column( - index=column.index, - name=column.name, - type=col_type, - ) - new_columns.append(new_column) - - if frame2 is not None: - for column in frame2.columns: - if column.name in in_frame1: - continue - new_columns.append( - DataFrameType.Column( - index=len(new_columns), - name=column.name, - type=ColumnType(type=UnknownType()), - ) + new_column = DataFrameType.Column( + index=column.index, + name=column.name, + type=col_type, ) + new_columns.append(new_column) - signature = Function( - args=[ - Function.Argument( - pos=0, - name="other", - type=DataFrameType(columns=[]), - required=True, - ), - ], - returns=DataFrameType(columns=new_columns), - ) + if frame2 is not None: + for column in frame2.columns: + if column.name in in_frame1: + continue + new_columns.append( + DataFrameType.Column( + index=len(new_columns), + name=column.name, + type=ColumnType(type=UnknownType()), + ) + ) - return ( - typer._get_call_result( - location=location, - callee=signature, - positional=positional, - keywords=keywords, + signature = Function( + args=[ + Function.Argument( + pos=0, + name="other", + type=DataFrameType(columns=[]), + required=True, + ), + ], + returns=DataFrameType(columns=new_columns), + ) + + return ( + self.typer._get_call_result( + location=call.location, + callee=signature, + positional=call.positional, + keywords=call.keywords, + ) + or UnknownType() ) - or UnknownType() - ) diff --git a/midas/checker/frames.py b/midas/checker/frames.py index 50ae080..dff5f71 100644 --- a/midas/checker/frames.py +++ b/midas/checker/frames.py @@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Optional, TypeGuard, cast import midas.ast.python as p from midas.ast.location import Location -from midas.checker.frame_methods import FRAME_METHODS, FrameMethod -from midas.checker.registry import TypesRegistry +from midas.checker.frame_methods import Call, MethodResolver from midas.checker.reporter import FileReporter from midas.checker.types import ColumnType, DataFrameType, TupleType, Type, UnknownType @@ -18,8 +17,9 @@ def is_list_of_literals(exprs: list[p.Expr]) -> TypeGuard[list[p.LiteralExpr]]: class FrameManager: - def __init__(self, types: TypesRegistry) -> None: - self.types: TypesRegistry = types + def __init__(self, typer: PythonTyper) -> None: + self.typer: PythonTyper = typer + self.method_resolver: MethodResolver = MethodResolver(self.typer) def assign( self, @@ -137,15 +137,18 @@ class FrameManager: ) -> list[Optional[ColumnType]]: return [cls._get_column(frame, name) for name in names] - def call_method( + def call( self, - typer: PythonTyper, - reporter: FileReporter, + method: str, location: Location, frame: DataFrameType, - method: str, positional: list[TypedExpr], keywords: dict[str, TypedExpr], ) -> Type: - function: FrameMethod = FRAME_METHODS[method] - return function(typer, self, reporter, location, frame, positional, keywords) + call: Call = Call( + location=location, + frame=frame, + positional=positional, + keywords=keywords, + ) + return self.method_resolver.call(method, call) diff --git a/midas/checker/python.py b/midas/checker/python.py index f7b59d8..e733032 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -8,7 +8,6 @@ from midas.ast.location import Location from midas.ast.printer import MidasPrinter from midas.checker.environment import Environment from midas.checker.evaluator import Evaluator -from midas.checker.frame_methods import FRAME_METHODS from midas.checker.frames import FrameManager from midas.checker.operators import ( PY_COMPARATOR_METHODS, @@ -76,7 +75,7 @@ class PythonTyper( self.logger: logging.Logger = logging.getLogger("PythonTyper") self.reporter: FileReporter = reporter.for_file(None) self.types: TypesRegistry = types - self.frame_mgr: FrameManager = FrameManager(self.types) + self.frame_mgr: FrameManager = FrameManager(self) self.global_env: Environment = Preamble(self.types) self.env: Environment = self.global_env self.locals: dict[p.Expr, int] = {} @@ -527,15 +526,11 @@ class PythonTyper( case p.GetExpr(object=obj, name=method): obj_type: Type = self.type_of(obj) unfolded: Type = unfold_type(obj_type) - if isinstance(unfolded, DataFrameType) and self._is_frame_method( - method - ): - return self.frame_mgr.call_method( - self, - self.reporter, + if isinstance(unfolded, DataFrameType): + return self.frame_mgr.call( + method, expr.location, unfolded, - method, positional, keywords, ) @@ -1307,6 +1302,3 @@ class PythonTyper( self, frame: DataFrameType, expr: p.SubscriptExpr ) -> Type: return self.frame_mgr.get(self.reporter, expr.location, frame, expr.index) - - def _is_frame_method(self, method: str) -> bool: - return method in FRAME_METHODS