diff --git a/midas/checker/dispatcher.py b/midas/checker/dispatcher.py index b352ebc..32ae31d 100644 --- a/midas/checker/dispatcher.py +++ b/midas/checker/dispatcher.py @@ -74,6 +74,9 @@ class CallDispatcher(Generic[E]): self.reporter: FileReporter = reporter self.logger: logging.Logger = logging.getLogger("CallDispatcher") + def set_reporter(self, reporter: FileReporter): + self.reporter = reporter + def get_result( self, location: Location, diff --git a/midas/checker/midas.py b/midas/checker/midas.py index b6048cb..32716fc 100644 --- a/midas/checker/midas.py +++ b/midas/checker/midas.py @@ -62,8 +62,11 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type def __init__(self, types: TypesRegistry, reporter: Reporter) -> None: self.logger: logging.Logger = logging.getLogger("MidasTyper") self.reporter: FileReporter = reporter.for_file(None) - self.types: TypesRegistry = types + self.dispatcher: CallDispatcher[m.Expr] = CallDispatcher[m.Expr]( + self.types, self.reporter + ) + self._local_variables: dict[str, TypeVar] = {} self._predicate_params: dict[str, Type] = {} @@ -78,8 +81,14 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type self._preamble: Environment = Preamble(self.types) + def set_reporter(self, reporter: FileReporter): + self.reporter = reporter + self.dispatcher.set_reporter(reporter) + def process(self, source: str, path: Optional[str]): - self.reporter = self.reporter.for_file(path) + reporter: FileReporter = self.reporter.for_file(path) + self.set_reporter(reporter) + lexer: MidasLexer = MidasLexer(source) tokens: list[Token] = lexer.process() parser: MidasParser = MidasParser(tokens) @@ -254,8 +263,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type ) return UnknownType() - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=location, callee=operation, positional=[(right_expr, right)], @@ -281,8 +289,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type ) return UnknownType() - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=expr.location, callee=operation, positional=[], @@ -298,8 +305,7 @@ class MidasTyper(m.Stmt.Visitor[None], m.Expr.Visitor[Type], m.Type.Visitor[Type keywords: dict[str, tuple[m.Expr, Type]] = { name: (arg, self.type_of(arg)) for name, arg in expr.keywords.items() } - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=expr.location, callee=callee, positional=positional, diff --git a/midas/checker/python.py b/midas/checker/python.py index 4538988..3b9f3ce 100644 --- a/midas/checker/python.py +++ b/midas/checker/python.py @@ -84,9 +84,17 @@ class PythonTyper( self.locals: dict[p.Expr, int] = {} self.judgements: list[tuple[p.Expr, Type]] = [] self.evaluated_casts: list[p.CastExpr] = [] + self.dispatcher: CallDispatcher[p.Expr] = CallDispatcher[p.Expr]( + self.types, self.reporter + ) + + def set_reporter(self, reporter: FileReporter): + self.reporter = reporter + self.dispatcher.set_reporter(self.reporter) def process(self, source: str, path: Optional[str]) -> TypedAST: - self.reporter = self.reporter.for_file(path) + reporter: FileReporter = self.reporter.for_file(path) + self.set_reporter(reporter) tree: ast.Module = ast.parse(source, filename=path or "") parser = PythonParser() @@ -221,8 +229,7 @@ class PythonTyper( if method is None: raise UndefinedMethodException - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=location, callee=method, positional=positional, @@ -572,8 +579,7 @@ class PythonTyper( ) callee: Type = self.type_of(expr.callee) - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=expr.location, callee=callee, positional=positional, @@ -742,8 +748,7 @@ class PythonTyper( return UnknownType() index: Type = self.type_of(expr.index) - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=expr.location, callee=operation, positional=[(expr.index, index)], @@ -808,8 +813,7 @@ class PythonTyper( index: p.Expr = p.LiteralExpr(location=expr.location, value=0) index_type: Type = self.compute_type(index) - dispatcher = CallDispatcher(self.types, self.reporter) - result: CallResult = dispatcher.get_result( + result: CallResult = self.dispatcher.get_result( location=expr.location, callee=getitem, positional=[(index, index_type)],