feat(checker): add column method registry

This commit is contained in:
2026-07-02 19:23:23 +02:00
parent 640f2d1771
commit 2fce2f4bfc
4 changed files with 92 additions and 12 deletions

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.column_methods import Call, ColumnMethodRegistry
from midas.checker.types import ColumnType, Type
if TYPE_CHECKING:
from midas.checker.python import PythonTyper, TypedExpr
class ColumnManager:
def __init__(self, typer: PythonTyper) -> None:
self.typer: PythonTyper = typer
self.method_resolver: ColumnMethodRegistry = ColumnMethodRegistry(self.typer)
def call(
self,
method: str,
location: Location,
call_expr: p.Expr,
column: ColumnType,
column_expr: p.Expr,
positional: list[TypedExpr],
keywords: dict[str, TypedExpr],
) -> Type:
call: Call = Call(
location=location,
call_expr=call_expr,
column=column,
column_expr=column_expr,
positional=positional,
keywords=keywords,
)
return self.method_resolver.call(method, call)

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import midas.ast.python as p
from midas.ast.location import Location
from midas.checker.frames.utils import MethodRegistry
from midas.checker.types import (
ColumnType,
)
if TYPE_CHECKING:
from midas.checker.python import TypedExpr
@dataclass(frozen=True, kw_only=True)
class Call:
location: Location
call_expr: p.Expr
column: ColumnType
column_expr: p.Expr
positional: list[TypedExpr]
keywords: dict[str, TypedExpr]
class ColumnMethodRegistry(MethodRegistry[Call]): ...

View File

@@ -23,18 +23,6 @@ if TYPE_CHECKING:
from midas.checker.python import PythonTyper from midas.checker.python import PythonTyper
@staticmethod
def method(*names: str):
def wrapper(func):
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper
class _MethodRegistryMeta(type): class _MethodRegistryMeta(type):
_methods: dict[str, Callable[..., Type]] = {} _methods: dict[str, Callable[..., Type]] = {}
@@ -87,3 +75,18 @@ class MethodRegistry(Generic[T], metaclass=_MethodRegistryMeta):
self.reporter.warning(call.location, f"Unknown method {method}") self.reporter.warning(call.location, f"Unknown method {method}")
return UnknownType() return UnknownType()
return func(self, call) return func(self, call)
_Self = TypeVar("_Self", bound=MethodRegistry[Any])
Method = Callable[[_Self, T], Type]
def method(*names: str) -> Callable[[Method[_Self, T]], Method[_Self, T]]:
def wrapper(func: Method[_Self, T]) -> Method[_Self, T]:
names_: tuple[str, ...] = names
if len(names_) == 0:
names_ = (func.__name__,)
setattr(func, "__method_names__", names_)
return func
return wrapper

View File

@@ -9,6 +9,7 @@ from midas.ast.printer import MidasPrinter
from midas.checker.dispatcher import CallDispatcher, CallResult from midas.checker.dispatcher import CallDispatcher, CallResult
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator from midas.checker.evaluator import Evaluator
from midas.checker.frames.column_manager import ColumnManager
from midas.checker.frames.frame_manager import FrameManager from midas.checker.frames.frame_manager import FrameManager
from midas.checker.operators import ( from midas.checker.operators import (
PY_COMPARATOR_METHODS, PY_COMPARATOR_METHODS,
@@ -82,6 +83,7 @@ class PythonTyper(
self.reporter: FileReporter = reporter.for_file(None) self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types self.types: TypesRegistry = types
self.frame_mgr: FrameManager = FrameManager(self) self.frame_mgr: FrameManager = FrameManager(self)
self.column_mgr: ColumnManager = ColumnManager(self)
self.global_env: Environment = Preamble(self.types) self.global_env: Environment = Preamble(self.types)
self.env: Environment = self.global_env self.env: Environment = self.global_env
self.locals: dict[p.Expr, int] = {} self.locals: dict[p.Expr, int] = {}
@@ -233,6 +235,17 @@ class PythonTyper(
keywords=keywords, keywords=keywords,
) )
case ColumnType():
return self.column_mgr.call(
method=method_name,
location=location,
call_expr=call_expr,
column=unfolded,
column_expr=obj[0],
positional=positional,
keywords=keywords,
)
method: Optional[Type] = self.types.lookup_member(obj[1], method_name) method: Optional[Type] = self.types.lookup_member(obj[1], method_name)
if method is None: if method is None:
raise UndefinedMethodException raise UndefinedMethodException