feat(checker): map and check function call arguments
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -19,6 +20,13 @@ class ReturnException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MappedArgument:
|
||||
expr: p.Expr
|
||||
type: Type
|
||||
argument: Function.Argument
|
||||
|
||||
|
||||
class Checker(
|
||||
p.Stmt.Visitor[None],
|
||||
p.Expr.Visitor[Type],
|
||||
@@ -126,15 +134,18 @@ class Checker(
|
||||
kw_args: list[Function.Argument] = []
|
||||
|
||||
def eval_arg_type(arg: p.Function.Argument) -> Type:
|
||||
if arg.type is None:
|
||||
return UnknownType()
|
||||
return arg.type.accept(self)
|
||||
if arg.type is not None:
|
||||
return arg.type.accept(self)
|
||||
if arg.default is not None:
|
||||
return arg.default.accept(self)
|
||||
return UnknownType()
|
||||
|
||||
for arg in stmt.posonlyargs:
|
||||
pos_args.append(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.args:
|
||||
@@ -142,6 +153,7 @@ class Checker(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
for arg in stmt.kwonlyargs:
|
||||
@@ -149,6 +161,7 @@ class Checker(
|
||||
Function.Argument(
|
||||
name=arg.name,
|
||||
type=eval_arg_type(arg),
|
||||
required=arg.default is None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -175,7 +188,9 @@ class Checker(
|
||||
else:
|
||||
returns = inferred_return
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
function: Function = Function(
|
||||
name=stmt.name,
|
||||
pos_args=pos_args,
|
||||
args=args,
|
||||
kw_args=kw_args,
|
||||
@@ -240,11 +255,18 @@ class Checker(
|
||||
self.import_midas(path)
|
||||
return UnknownType()
|
||||
callee: Type = self.evaluate(expr.callee)
|
||||
arguments: list[Type] = [self.evaluate(arg) for arg in expr.arguments]
|
||||
keywords: dict[str, Type] = {
|
||||
name: self.evaluate(arg) for name, arg in expr.keywords.items()
|
||||
}
|
||||
return UnknownType()
|
||||
if not isinstance(callee, Function):
|
||||
self.error(expr.callee.location, "Callee is not a function")
|
||||
return UnknownType()
|
||||
function: Function = callee
|
||||
mapped: list[MappedArgument] = self.map_call_arguments(function, expr)
|
||||
for arg in mapped:
|
||||
if arg.type != arg.argument.type:
|
||||
self.error(
|
||||
arg.expr.location,
|
||||
f"Wrong type for argument '{arg.argument.name}', expected {arg.argument.type}, got {arg.type}",
|
||||
)
|
||||
return function.returns
|
||||
|
||||
def visit_get_expr(self, expr: p.GetExpr) -> Type: ...
|
||||
|
||||
@@ -277,3 +299,87 @@ class Checker(
|
||||
def visit_frame_column(self, node: p.FrameColumn) -> Type: ...
|
||||
|
||||
def visit_frame_type(self, node: p.FrameType) -> Type: ...
|
||||
|
||||
def map_call_arguments(
|
||||
self, function: Function, call: p.CallExpr
|
||||
) -> list[MappedArgument]:
|
||||
positional: list[tuple[p.Expr, Type]] = [
|
||||
(arg, self.evaluate(arg)) for arg in call.arguments
|
||||
]
|
||||
keywords: dict[str, tuple[p.Expr, Type]] = {
|
||||
name: (arg, self.evaluate(arg)) for name, arg in call.keywords.items()
|
||||
}
|
||||
set_args: set[str] = set()
|
||||
|
||||
required_positional: set[str] = {
|
||||
arg.name for arg in function.pos_args + function.args if arg.required
|
||||
}
|
||||
required_keyword: set[str] = {
|
||||
arg.name for arg in function.kw_args if arg.required
|
||||
}
|
||||
|
||||
mapped: list[MappedArgument] = []
|
||||
|
||||
pos_params: list[Function.Argument] = list(function.pos_args)
|
||||
mixed_params: list[Function.Argument] = list(function.args)
|
||||
kw_params: dict[str, Function.Argument] = {
|
||||
arg.name: arg for arg in function.kw_args
|
||||
}
|
||||
|
||||
# TODO: handle *args and **kwargs sinks
|
||||
for arg in positional:
|
||||
param: Function.Argument
|
||||
if len(pos_params) != 0:
|
||||
param = pos_params.pop(0)
|
||||
elif len(mixed_params) != 0:
|
||||
param = mixed_params.pop(0)
|
||||
else:
|
||||
self.error(arg[0].location, "Too many positional arguments")
|
||||
break
|
||||
required_positional.discard(param.name)
|
||||
required_keyword.discard(param.name)
|
||||
set_args.add(param.name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
kw_params.update({arg.name: arg for arg in mixed_params})
|
||||
for name, arg in keywords.items():
|
||||
param: Function.Argument
|
||||
if name not in kw_params:
|
||||
if name in set_args:
|
||||
self.error(
|
||||
arg[0].location, f"Multiple values for argument '{name}'"
|
||||
)
|
||||
else:
|
||||
self.error(arg[0].location, f"Unknown keyword argument '{name}'")
|
||||
continue
|
||||
param = kw_params.pop(name)
|
||||
required_positional.discard(name)
|
||||
required_keyword.discard(name)
|
||||
set_args.add(name)
|
||||
mapped.append(
|
||||
MappedArgument(
|
||||
expr=arg[0],
|
||||
type=arg[1],
|
||||
argument=param,
|
||||
)
|
||||
)
|
||||
|
||||
if len(required_positional) != 0:
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required positional arguments: {required_positional}",
|
||||
)
|
||||
|
||||
if len(required_keyword) != 0:
|
||||
self.error(
|
||||
call.location,
|
||||
f"Missing required keyword arguments: {required_keyword}",
|
||||
)
|
||||
|
||||
return mapped
|
||||
|
||||
@@ -26,6 +26,7 @@ class UnitType:
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Function:
|
||||
name: str
|
||||
pos_args: list[Argument]
|
||||
args: list[Argument]
|
||||
kw_args: list[Argument]
|
||||
@@ -35,6 +36,7 @@ class Function:
|
||||
class Argument:
|
||||
name: str
|
||||
type: Type
|
||||
required: bool
|
||||
|
||||
|
||||
Type = BaseType | SimpleType | UnknownType | UnitType | Function
|
||||
|
||||
Reference in New Issue
Block a user