diff --git a/midas/checker/checker.py b/midas/checker/checker.py index dcc7188..b14c406 100644 --- a/midas/checker/checker.py +++ b/midas/checker/checker.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import logging from pathlib import Path from typing import Optional @@ -19,6 +20,13 @@ class ReturnException(Exception): pass +@dataclass(frozen=True, kw_only=True) +class MappedArgument: + expr: p.Expr + type: Type + argument: Function.Argument + + class Checker( p.Stmt.Visitor[None], p.Expr.Visitor[Type], @@ -126,15 +134,18 @@ class Checker( kw_args: list[Function.Argument] = [] def eval_arg_type(arg: p.Function.Argument) -> Type: - if arg.type is None: - return UnknownType() - return arg.type.accept(self) + if arg.type is not None: + return arg.type.accept(self) + if arg.default is not None: + return arg.default.accept(self) + return UnknownType() for arg in stmt.posonlyargs: pos_args.append( Function.Argument( name=arg.name, type=eval_arg_type(arg), + required=arg.default is None, ) ) for arg in stmt.args: @@ -142,6 +153,7 @@ class Checker( Function.Argument( name=arg.name, type=eval_arg_type(arg), + required=arg.default is None, ) ) for arg in stmt.kwonlyargs: @@ -149,6 +161,7 @@ class Checker( Function.Argument( name=arg.name, type=eval_arg_type(arg), + required=arg.default is None, ) ) @@ -175,7 +188,9 @@ class Checker( else: returns = inferred_return + # TODO: handle *args and **kwargs sinks function: Function = Function( + name=stmt.name, pos_args=pos_args, args=args, kw_args=kw_args, @@ -240,11 +255,18 @@ class Checker( self.import_midas(path) return UnknownType() callee: Type = self.evaluate(expr.callee) - arguments: list[Type] = [self.evaluate(arg) for arg in expr.arguments] - keywords: dict[str, Type] = { - name: self.evaluate(arg) for name, arg in expr.keywords.items() - } - return UnknownType() + if not isinstance(callee, Function): + self.error(expr.callee.location, "Callee is not a function") + return UnknownType() + function: Function = callee + mapped: list[MappedArgument] = self.map_call_arguments(function, expr) + for arg in mapped: + if arg.type != arg.argument.type: + self.error( + arg.expr.location, + f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}", + ) + return function.returns def visit_get_expr(self, expr: p.GetExpr) -> Type: ... @@ -277,3 +299,87 @@ class Checker( def visit_frame_column(self, node: p.FrameColumn) -> Type: ... def visit_frame_type(self, node: p.FrameType) -> Type: ... + + def map_call_arguments( + self, function: Function, call: p.CallExpr + ) -> list[MappedArgument]: + positional: list[tuple[p.Expr, Type]] = [ + (arg, self.evaluate(arg)) for arg in call.arguments + ] + keywords: dict[str, tuple[p.Expr, Type]] = { + name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items() + } + set_args: set[str] = set() + + required_positional: set[str] = { + arg.name for arg in function.pos_args + function.args if arg.required + } + required_keyword: set[str] = { + arg.name for arg in function.kw_args if arg.required + } + + mapped: list[MappedArgument] = [] + + pos_params: list[Function.Argument] = list(function.pos_args) + mixed_params: list[Function.Argument] = list(function.args) + kw_params: dict[str, Function.Argument] = { + arg.name: arg for arg in function.kw_args + } + + # TODO: handle *args and **kwargs sinks + for arg in positional: + param: Function.Argument + if len(pos_params) != 0: + param = pos_params.pop(0) + elif len(mixed_params) != 0: + param = mixed_params.pop(0) + else: + self.error(arg[0].location, "Too many positional arguments") + break + required_positional.discard(param.name) + required_keyword.discard(param.name) + set_args.add(param.name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + kw_params.update({arg.name: arg for arg in mixed_params}) + for name, arg in keywords.items(): + param: Function.Argument + if name not in kw_params: + if name in set_args: + self.error( + arg[0].location, f"Multiple values for argument '{name}'" + ) + else: + self.error(arg[0].location, f"Unknown keyword argument '{name}'") + continue + param = kw_params.pop(name) + required_positional.discard(name) + required_keyword.discard(name) + set_args.add(name) + mapped.append( + MappedArgument( + expr=arg[0], + type=arg[1], + argument=param, + ) + ) + + if len(required_positional) != 0: + self.error( + call.location, + f"Missing required positional arguments: {required_positional}", + ) + + if len(required_keyword) != 0: + self.error( + call.location, + f"Missing required keyword arguments: {required_keyword}", + ) + + return mapped diff --git a/midas/checker/types.py b/midas/checker/types.py index 2afd0ae..1e2f149 100644 --- a/midas/checker/types.py +++ b/midas/checker/types.py @@ -26,6 +26,7 @@ class UnitType: @dataclass(frozen=True, kw_only=True) class Function: + name: str pos_args: list[Argument] args: list[Argument] kw_args: list[Argument] @@ -35,6 +36,7 @@ class Function: class Argument: name: str type: Type + required: bool Type = BaseType | SimpleType | UnknownType | UnitType | Function