feat(checker): type check subscript on dataframes
This commit is contained in:
51
midas/checker/frames.py
Normal file
51
midas/checker/frames.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from midas.checker.types import ColumnType, DataFrameType
|
||||||
|
|
||||||
|
|
||||||
|
class FrameManager:
|
||||||
|
@classmethod
|
||||||
|
def set_column(
|
||||||
|
cls, frame: DataFrameType, name: str, column: ColumnType
|
||||||
|
) -> DataFrameType:
|
||||||
|
new_columns: list[DataFrameType.Column] = []
|
||||||
|
index: int = len(frame.columns)
|
||||||
|
replace: bool = False
|
||||||
|
for i, col in enumerate(frame.columns):
|
||||||
|
if col.name == name:
|
||||||
|
index = i
|
||||||
|
replace = True
|
||||||
|
new_columns.append(col)
|
||||||
|
|
||||||
|
new_col: DataFrameType.Column = DataFrameType.Column(
|
||||||
|
index=index,
|
||||||
|
name=name,
|
||||||
|
type=column,
|
||||||
|
)
|
||||||
|
if replace:
|
||||||
|
new_columns[index] = new_col
|
||||||
|
else:
|
||||||
|
new_columns.append(new_col)
|
||||||
|
|
||||||
|
return DataFrameType(columns=new_columns)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str], columns: list[ColumnType]
|
||||||
|
) -> DataFrameType:
|
||||||
|
for name, col in zip(names, columns):
|
||||||
|
frame = cls.set_column(frame, name, col)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_column(cls, frame: DataFrameType, name: str) -> Optional[ColumnType]:
|
||||||
|
for col in frame.columns:
|
||||||
|
if col.name == name:
|
||||||
|
return col.type
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_columns(
|
||||||
|
cls, frame: DataFrameType, names: list[str]
|
||||||
|
) -> list[Optional[ColumnType]]:
|
||||||
|
return [cls.get_column(frame, name) for name in names]
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
import midas.ast.python as p
|
import midas.ast.python as p
|
||||||
from midas.ast.location import Location
|
from midas.ast.location import Location
|
||||||
from midas.ast.printer import MidasPrinter
|
from midas.ast.printer import MidasPrinter
|
||||||
from midas.checker.environment import Environment
|
from midas.checker.environment import Environment
|
||||||
from midas.checker.evaluator import Evaluator
|
from midas.checker.evaluator import Evaluator
|
||||||
|
from midas.checker.frames import FrameManager
|
||||||
from midas.checker.operators import (
|
from midas.checker.operators import (
|
||||||
PY_COMPARATOR_METHODS,
|
PY_COMPARATOR_METHODS,
|
||||||
PY_OPERATOR_METHODS,
|
PY_OPERATOR_METHODS,
|
||||||
@@ -647,6 +648,8 @@ class PythonTyper(
|
|||||||
match unfolded:
|
match unfolded:
|
||||||
case TupleType():
|
case TupleType():
|
||||||
return self._visit_tuple_subscript(unfolded, expr)
|
return self._visit_tuple_subscript(unfolded, expr)
|
||||||
|
case DataFrameType():
|
||||||
|
return self._visit_frame_subscript(unfolded, expr)
|
||||||
|
|
||||||
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
operation: Optional[Type] = self.types.lookup_member(object, "__getitem__")
|
||||||
if operation is None:
|
if operation is None:
|
||||||
@@ -1252,3 +1255,39 @@ class PythonTyper(
|
|||||||
expr.location, f"Invalid index type {expr.index} on {tup}"
|
expr.location, f"Invalid index type {expr.index} on {tup}"
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
|
def _visit_frame_subscript(
|
||||||
|
self, frame: DataFrameType, expr: p.SubscriptExpr
|
||||||
|
) -> Type:
|
||||||
|
match expr.index:
|
||||||
|
case p.LiteralExpr(value=str() as name):
|
||||||
|
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown column '{name}' on {frame}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return column
|
||||||
|
|
||||||
|
case p.ListExpr(items=indices) if all(
|
||||||
|
isinstance(index, p.LiteralExpr) and isinstance(index.value, str)
|
||||||
|
for index in indices
|
||||||
|
):
|
||||||
|
indices = cast(list[p.LiteralExpr], indices)
|
||||||
|
names: list[str] = [cast(str, index.value) for index in indices]
|
||||||
|
columns: list[ColumnType] = []
|
||||||
|
for name in names:
|
||||||
|
column: Optional[ColumnType] = FrameManager.get_column(frame, name)
|
||||||
|
if column is None:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Unknown column '{name}' on {frame}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
columns.append(column)
|
||||||
|
return TupleType(items=tuple(columns))
|
||||||
|
|
||||||
|
case _:
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location, f"Invalid index type {expr.index} on {frame}"
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|||||||
Reference in New Issue
Block a user