diff --git a/midas/checker/frame_methods.py b/midas/checker/frame_methods.py index 8a7dfbc..0f85493 100644 --- a/midas/checker/frame_methods.py +++ b/midas/checker/frame_methods.py @@ -3,7 +3,9 @@ from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional +import midas.ast.python as p from midas.ast.location import Location +from midas.checker.dispatcher import CallDispatcher, CallResult from midas.checker.registry import TypesRegistry from midas.checker.reporter import FileReporter from midas.checker.types import ( @@ -71,6 +73,10 @@ class MethodRegistry(metaclass=_MethodRegistryMeta): def types(self) -> TypesRegistry: return self.typer.types + @property + def dispatcher(self) -> CallDispatcher[p.Expr]: + return self.typer.dispatcher + def call( self, method: str, @@ -147,15 +153,13 @@ class MethodRegistry(metaclass=_MethodRegistryMeta): returns=DataFrameType(columns=new_columns), ) - return ( - self.typer._get_call_result( - location=call.location, - callee=signature, - positional=call.positional, - keywords=call.keywords, - ) - or UnknownType() + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=signature, + positional=call.positional, + keywords=call.keywords, ) + return result.result @frame_method() def mean(self, call: Call) -> Type: @@ -187,12 +191,11 @@ class MethodRegistry(metaclass=_MethodRegistryMeta): without_axis, ] ) - return ( - self.typer._get_call_result( - location=call.location, - callee=overload, - positional=call.positional, - keywords=call.keywords, - ) - or UnknownType() + + result: CallResult = self.dispatcher.get_result( + location=call.location, + callee=overload, + positional=call.positional, + keywords=call.keywords, ) + return result.result