From cc5e7af1432e1d309927621acdd735854964b1d6 Mon Sep 17 00:00:00 2001 From: LordBaryhobal Date: Tue, 23 Jun 2026 14:45:19 +0200 Subject: [PATCH] feat(gen): add support for tuples and dataframes --- midas/generator/generator.py | 22 +++++++++++++++++- midas/generator/stubs.py | 43 +++++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/midas/generator/generator.py b/midas/generator/generator.py index 0af3fcd..1792dcd 100644 --- a/midas/generator/generator.py +++ b/midas/generator/generator.py @@ -1,4 +1,5 @@ import ast +import logging import shutil from dataclasses import dataclass, field from pathlib import Path @@ -13,13 +14,16 @@ from midas.checker.types import ( AliasType, AppliedType, BaseType, + ColumnType, ComplexType, ConstraintType, + DataFrameType, ExtensionType, Function, GenericType, OverloadedFunction, TopType, + TupleType, Type, TypeVar, UnitType, @@ -40,6 +44,7 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): self.workdir: Path = workdir.resolve() self.build_dir: Path = self.workdir / "build" / "midas" self.rel_src_path: Path = Path() + self.logger: logging.Logger = logging.getLogger("Generator") self._typed_ast: TypedAST = TypedAST( stmts=[], @@ -332,6 +337,19 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): if bound is not None: self._make_cast_asserts(src_location, expr, bound) + case TupleType(items=items): + self._add_assert( + ast.Call( + func=ast.Name(id="isinstance"), + args=[expr, ast.Name(id="tuple")], + keywords=[], + ), + self._make_cast_assert_message(src_location, expr, type), + ) + assert isinstance(expr, ast.Tuple) + for item, item_type in zip(expr.elts, items): + self._make_cast_asserts(src_location, item, item_type) + case ( TopType() | Function() @@ -339,8 +357,10 @@ class Generator(p.Stmt.Visitor[ast.stmt], p.Expr.Visitor[ast.expr]): | ComplexType() | ExtensionType() | GenericType() + | ColumnType() + | DataFrameType() ): - raise NotImplementedError(f"Can't make assertion for type {type}") + self.logger.warning(f"Can't make assertion for type {type}") # Ensure exhaustiveness case _: diff --git a/midas/generator/stubs.py b/midas/generator/stubs.py index c9a3804..3088c76 100644 --- a/midas/generator/stubs.py +++ b/midas/generator/stubs.py @@ -7,13 +7,16 @@ from midas.checker.types import ( AliasType, AppliedType, BaseType, + ColumnType, ComplexType, ConstraintType, + DataFrameType, ExtensionType, Function, GenericType, OverloadedFunction, TopType, + TupleType, Type, TypeVar, UnitType, @@ -30,6 +33,7 @@ class StubsGenerator: self.types: TypesRegistry = types self.stubs: list[ast.stmt] = [] self.typing_imports: set[str] = set() + self.import_pandas: bool = False self.protocol_idx: int = 0 self.stub_idx: int = 0 self.type_var_idx: int = 0 @@ -38,6 +42,7 @@ class StubsGenerator: def generate_stubs(self) -> ast.Module: self.stubs = [] self.typing_imports = set() + self.import_pandas = False for name, type in self.types._types.items(): # Skip builtin types, not just based on name so the user can override # TODO: check if added members on builtin type @@ -53,7 +58,7 @@ class StubsGenerator: continue self.generate_stub(name, type) - imports = [ + imports: list[ast.stmt] = [ ast.ImportFrom( module="__future__", names=[ast.alias(name="annotations")], @@ -70,6 +75,17 @@ class StubsGenerator: level=0, ) ) + if self.import_pandas: + imports.append( + ast.Import( + names=[ + ast.alias( + name="pandas", + asname="pd", + ) + ], + ) + ) return ast.Module(body=imports + self.stubs, type_ignores=[]) def generate_stub(self, name: str, type: Type): @@ -231,6 +247,31 @@ class StubsGenerator: case ConstraintType(): return self.dump_type(type.type) + case TupleType(items=items): + return ast.Subscript( + value=ast.Name(id="tuple"), + slice=ast.Tuple( + elts=[self.dump_type(item) for item in items], + ), + ) + + case ColumnType(type=inner): + self.import_pandas = True + return ast.Subscript( + value=ast.Attribute( + value=ast.Name(id="pd"), + attr="Series", + ), + slice=self.dump_type(inner), + ) + + case DataFrameType(): + self.import_pandas = True + return ast.Attribute( + value=ast.Name(id="pd"), + attr="DataFrame", + ) + case _: assert_never(type)