refactor: reuse CallDispatcher

This commit is contained in:
2026-07-01 11:32:41 +02:00
parent 6e717a3f9e
commit 9a276c34c7
3 changed files with 30 additions and 17 deletions

View File

@@ -74,6 +74,9 @@ class CallDispatcher(Generic[E]):
self.reporter: FileReporter = reporter
self.logger: logging.Logger = logging.getLogger("CallDispatcher")
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
def get_result(
self,
location: Location,

View File

@@ -62,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
def __init__(self, types: TypesRegistry, reporter: Reporter) -> None:
self.logger: logging.Logger = logging.getLogger("MidasTyper")
self.reporter: FileReporter = reporter.for_file(None)
self.types: TypesRegistry = types
self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr](
self.types, self.reporter
)
self._local_variables: dict[str, TypeVar] = {}
self._predicate_params: dict[str, Type] = {}
@@ -78,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
self._preamble: Environment = Preamble(self.types)
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
self.dispatcher.set_reporter(reporter)
def process(self, source: str, path: Optional[str]):
self.reporter = self.reporter.for_file(path)
reporter: FileReporter = self.reporter.for_file(path)
self.set_reporter(reporter)
lexer: MidasLexer = MidasLexer(source)
tokens: list[Token] = lexer.process()
parser: MidasParser = MidasParser(tokens)
@@ -254,8 +263,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
)
return UnknownType()
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=location,
callee=operation,
positional=[(right_expr, right)],
@@ -281,8 +289,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
)
return UnknownType()
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=operation,
positional=[],
@@ -298,8 +305,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type
keywords: dict[str, tuple[m.Expr, Type]] = {
name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items()
}
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=callee,
positional=positional,

View File

@@ -84,9 +84,17 @@ class PythonTyper(
self.locals: dict[p.Expr, int] = {}
self.judgements: list[tuple[p.Expr, Type]] = []
self.evaluated_casts: list[p.CastExpr] = []
self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr](
self.types, self.reporter
)
def set_reporter(self, reporter: FileReporter):
self.reporter = reporter
self.dispatcher.set_reporter(self.reporter)
def process(self, source: str, path: Optional[str]) -> TypedAST:
self.reporter = self.reporter.for_file(path)
reporter: FileReporter = self.reporter.for_file(path)
self.set_reporter(reporter)
tree: ast.Module = ast.parse(source, filename=path or "<unknown>")
parser = PythonParser()
@@ -221,8 +229,7 @@ class PythonTyper(
if method is None:
raise UndefinedMethodException
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=location,
callee=method,
positional=positional,
@@ -572,8 +579,7 @@ class PythonTyper(
)
callee: Type = self.type_of(expr.callee)
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=callee,
positional=positional,
@@ -742,8 +748,7 @@ class PythonTyper(
return UnknownType()
index: Type = self.type_of(expr.index)
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=operation,
positional=[(expr.index, index)],
@@ -808,8 +813,7 @@ class PythonTyper(
index: p.Expr = p.LiteralExpr(location=expr.location, value=0)
index_type: Type = self.compute_type(index)
dispatcher = CallDispatcher(self.types, self.reporter)
result: CallResult = dispatcher.get_result(
result: CallResult = self.dispatcher.get_result(
location=expr.location,
callee=getitem,
positional=[(index, index_type)],