feat(checker): type check subscript on dataframes

This commit is contained in:
2026-06-23 12:27:31 +02:00
parent 5e9ccd4e13
commit c1b5284f72
2 changed files with 91 additions and 1 deletions

51
midas/checker/frames.py Normal file
View File

@@ -0,0 +1,51 @@
from typing import Optional
from midas.checker.types import ColumnType, DataFrameType
class FrameManager:
    """Helpers for reading and updating the columns of a ``DataFrameType``.

    The update helpers never modify the frame they are given: they build a
    fresh column list and wrap it in a new ``DataFrameType``.
    """

    @classmethod
    def set_column(
        cls, frame: DataFrameType, name: str, column: ColumnType
    ) -> DataFrameType:
        """Return a copy of *frame* in which column *name* has type *column*.

        When *frame* already has a column called *name*, that entry is
        replaced at its current position (if several share the name, the
        last one is replaced). Otherwise the new column is appended at the
        end. The ``index`` stored on the new entry is its position in the
        resulting column list.
        """
        entries: list[DataFrameType.Column] = list(frame.columns)
        matches: list[int] = [
            position
            for position, existing in enumerate(entries)
            if existing.name == name
        ]
        # Keep the position of the last match, or use the next free slot.
        slot: int = matches[-1] if matches else len(entries)
        entry: DataFrameType.Column = DataFrameType.Column(
            index=slot,
            name=name,
            type=column,
        )
        if matches:
            entries[slot] = entry
        else:
            entries.append(entry)
        return DataFrameType(columns=entries)

    @classmethod
    def set_columns(
        cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
    ) -> DataFrameType:
        """Apply :meth:`set_column` for each ``(name, column)`` pair in order.

        *names* and *columns* are paired with ``zip``, so pairing stops at
        the end of the shorter list.
        """
        updated: DataFrameType = frame
        for column_name, column_type in zip(names, columns):
            updated = cls.set_column(updated, column_name, column_type)
        return updated

    @classmethod
    def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
        """Return the type of the first column called *name*, else ``None``."""
        return next(
            (entry.type for entry in frame.columns if entry.name == name),
            None,
        )

    @classmethod
    def get_columns(
        cls, frame: DataFrameType, names: list[str]
    ) -> list[Optional[ColumnType]]:
        """Look up each of *names* with :meth:`get_column`.

        The result has one entry per requested name, in the same order;
        names that are not columns of *frame* yield ``None``.
        """
        return [cls.get_column(frame, wanted) for wanted in names]

View File

@@ -1,13 +1,14 @@
import ast import ast
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional, cast
import midas.ast.python as p import midas.ast.python as p
from midas.ast.location import Location from midas.ast.location import Location
from midas.ast.printer import MidasPrinter from midas.ast.printer import MidasPrinter
from midas.checker.environment import Environment from midas.checker.environment import Environment
from midas.checker.evaluator import Evaluator from midas.checker.evaluator import Evaluator
from midas.checker.frames import FrameManager
from midas.checker.operators import ( from midas.checker.operators import (
PY_COMPARATOR_METHODS, PY_COMPARATOR_METHODS,
PY_OPERATOR_METHODS, PY_OPERATOR_METHODS,
@@ -647,6 +648,8 @@ class PythonTyper(
match unfolded: match unfolded:
case TupleType(): case TupleType():
return self._visit_tuple_subscript(unfolded, expr) return self._visit_tuple_subscript(unfolded, expr)
case DataFrameType():
return self._visit_frame_subscript(unfolded, expr)
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__") operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
if operation is None: if operation is None:
@@ -1252,3 +1255,39 @@ class PythonTyper(
expr.location, f"Invalid index type {expr.index} on {tup}" expr.location, f"Invalid index type {expr.index} on {tup}"
) )
return UnknownType() return UnknownType()
def _visit_frame_subscript(
    self, frame: DataFrameType, expr: p.SubscriptExpr
) -> Type:
    """Type a subscript expression whose subscripted value is a dataframe.

    Supported index forms:

    * a string literal -- yields the type of that column;
    * a list made up only of string literals -- yields a ``TupleType``
      holding the column types in the order they were requested.

    Any other index, or a name that is not a column of *frame*, is
    reported through ``self.reporter`` at ``expr.location`` and typed as
    ``UnknownType``. In the list form, only the first unknown column is
    reported.
    """
    key = expr.index
    found: Optional[ColumnType]

    # Single column: the index is one string literal.
    if isinstance(key, p.LiteralExpr) and isinstance(key.value, str):
        name: str = key.value
        found = FrameManager.get_column(frame, name)
        if found is not None:
            return found
        self.reporter.error(
            expr.location, f"Unknown column '{name}' on {frame}"
        )
        return UnknownType()

    # Several columns: the index is a list of string literals.
    if isinstance(key, p.ListExpr) and all(
        isinstance(item, p.LiteralExpr) and isinstance(item.value, str)
        for item in key.items
    ):
        literals = cast(list[p.LiteralExpr], key.items)
        resolved: list[ColumnType] = []
        for literal in literals:
            name = cast(str, literal.value)
            found = FrameManager.get_column(frame, name)
            if found is None:
                # Stop at the first missing column.
                self.reporter.error(
                    expr.location, f"Unknown column '{name}' on {frame}"
                )
                return UnknownType()
            resolved.append(found)
        return TupleType(items=tuple(resolved))

    # Anything else is not a supported dataframe index.
    self.reporter.error(
        expr.location, f"Invalid index type {expr.index} on {frame}"
    )
    return UnknownType()