feat(checker): make checker return TypedAST

This commit is contained in:
2026-06-15 14:16:10 +02:00
parent f25341b1e7
commit 74f51f361a
3 changed files with 18 additions and 4 deletions

View File

@@ -6,6 +6,7 @@ from midas.checker.midas import MidasTyper
from midas.checker.python import PythonTyper
from midas.checker.registry import TypesRegistry
from midas.checker.reporter import Reporter
from midas.utils import TypedAST
class TypeChecker:
@@ -23,12 +24,12 @@ class TypeChecker:
def import_midas_source(self, source: str, path: Optional[str] = None):
self.midas_typer.process(source, path)
def type_check(self, path: Path):
def type_check(self, path: Path) -> TypedAST:
source: str = path.read_text()
return self.type_check_source(source, path=str(path))
def type_check_source(self, source: str, path: Optional[str] = None):
self.python_typer.process(source, path)
def type_check_source(self, source: str, path: Optional[str] = None) -> TypedAST:
return self.python_typer.process(source, path)
@property
def diagnostics(self) -> list[Diagnostic]:

View File

@@ -19,6 +19,7 @@ from midas.checker.types import (
unfold_type,
)
from midas.parser.python import PythonParser
from midas.utils import TypedAST
TypedExpr = tuple[p.Expr, Type]
@@ -60,7 +61,7 @@ class PythonTyper(
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
def process(self, source: str, path: Optional[str]):
def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path)
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
@@ -75,6 +76,8 @@ class PythonTyper(
self.check(stmts)
return TypedAST(stmts=stmts, judgements=self.judgements)
def type_of(self, expr: p.Expr) -> Type:
"""Evaluate the type of an expression

View File

@@ -1,5 +1,9 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional
import midas.ast.python as p
from midas.checker.types import Type
AllowRepeat = Callable[[object], bool]
@@ -52,3 +56,9 @@ class UniversalJSONDumper:
}
case _:
raise ValueError(f"Unsupported value: {obj}")
@dataclass(frozen=True, kw_only=True)
class TypedAST:
stmts: list[p.Stmt]
judgements: list[tuple[p.Expr, Type]]