class FrameManager:
    """Helpers for deriving and querying ``DataFrameType`` values.

    Every method is a classmethod and the class holds no state. The
    setters build a new ``DataFrameType`` from a fresh column list; they
    only read ``frame.columns`` and never assign to the frame passed in.
    """

    @classmethod
    def set_column(
        cls, frame: DataFrameType, name: str, column: ColumnType
    ) -> DataFrameType:
        """Return a frame type in which column *name* has type *column*.

        If *frame* already has a column called *name*, that entry is
        replaced in place and keeps its position. Otherwise a new entry is
        appended, with ``index`` equal to the current number of columns.
        All other column entries are reused as they are.

        Args:
            frame: The frame type to derive from.
            name: Name of the column to add or replace.
            column: Type to give that column.

        Returns:
            A new ``DataFrameType`` built from the updated column list.
        """
        new_columns: list[DataFrameType.Column] = list(frame.columns)

        # Scan every entry rather than stopping at the first hit: should a
        # name occur more than once, the last occurrence is the one replaced.
        position: Optional[int] = None
        for i, col in enumerate(new_columns):
            if col.name == name:
                position = i

        new_col: DataFrameType.Column = DataFrameType.Column(
            index=len(new_columns) if position is None else position,
            name=name,
            type=column,
        )
        if position is None:
            new_columns.append(new_col)
        else:
            new_columns[position] = new_col

        return DataFrameType(columns=new_columns)

    @classmethod
    def set_columns(
        cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
    ) -> DataFrameType:
        """Apply :meth:`set_column` for each ``(name, column)`` pair in order.

        Args:
            frame: The frame type to start from.
            names: Column names, parallel to *columns*.
            columns: Column types, parallel to *names*.

        Returns:
            The frame type obtained after all pairs have been applied;
            *frame* itself when both lists are empty.

        Raises:
            ValueError: If *names* and *columns* differ in length. A bare
                ``zip`` would otherwise silently drop the surplus entries.
        """
        if len(names) != len(columns):
            raise ValueError(
                f"Expected as many column types as names, "
                f"got {len(names)} names and {len(columns)} types"
            )
        for name, col in zip(names, columns):
            frame = cls.set_column(frame, name, col)
        return frame

    @classmethod
    def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
        """Return the type of the first column called *name*, or ``None``.

        Args:
            frame: The frame type to search.
            name: Name of the column to look up.
        """
        for col in frame.columns:
            if col.name == name:
                return col.type
        return None

    @classmethod
    def get_columns(
        cls, frame: DataFrameType, names: list[str]
    ) -> list[Optional[ColumnType]]:
        """Look up several columns at once.

        Returns:
            One entry per element of *names*, in the same order; an entry
            is ``None`` when *frame* has no column with that name.
        """
        return [cls.get_column(frame, name) for name in names]
b/midas/checker/python.py @@ -1,13 +1,14 @@ import ast import logging from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, cast import midas.ast.python as p from midas.ast.location import Location from midas.ast.printer import MidasPrinter from midas.checker.environment import Environment from midas.checker.evaluator import Evaluator +from midas.checker.frames import FrameManager from midas.checker.operators import ( PY_COMPARATOR_METHODS, PY_OPERATOR_METHODS, @@ -647,6 +648,8 @@ class PythonTyper( match unfolded: case TupleType(): return self._visit_tuple_subscript(unfolded, expr) + case DataFrameType(): + return self._visit_frame_subscript(unfolded, expr) operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") if operation is None: @@ -1252,3 +1255,39 @@ class PythonTyper( expr.location, f"Invalid index type {expr.index} on {tup}" ) return UnknownType() + + def _visit_frame_subscript( + self, frame: DataFrameType, expr: p.SubscriptExpr + ) -> Type: + match expr.index: + case p.LiteralExpr(value=str() as name): + column: Optional[ColumnType] = FrameManager.get_column(frame, name) + if column is None: + self.reporter.error( + expr.location, f"Unknown column '{name}' on {frame}" + ) + return UnknownType() + return column + + case p.ListExpr(items=indices) if all( + isinstance(index, p.LiteralExpr) and isinstance(index.value, str) + for index in indices + ): + indices = cast(list[p.LiteralExpr], indices) + names: list[str] = [cast(str, index.value) for index in indices] + columns: list[ColumnType] = [] + for name in names: + column: Optional[ColumnType] = FrameManager.get_column(frame, name) + if column is None: + self.reporter.error( + expr.location, f"Unknown column '{name}' on {frame}" + ) + return UnknownType() + columns.append(column) + return TupleType(items=tuple(columns)) + + case _: + self.reporter.error( + expr.location, f"Invalid index type {expr.index} on {frame}" + ) + return 
UnknownType()