fix(checker): update binary operation lookup
This commit is contained in:
@@ -324,45 +324,34 @@ class PythonTyper(
|
|||||||
left: Type = self.type_of(expr.left)
|
left: Type = self.type_of(expr.left)
|
||||||
right: Type = self.type_of(expr.right)
|
right: Type = self.type_of(expr.right)
|
||||||
|
|
||||||
operations: list[Operation] = self.types.get_operations_by_name(method)
|
operation: Optional[Type] = self.types.lookup_member(left, method)
|
||||||
valid_operations: list[Operation] = []
|
if operation is None:
|
||||||
for op in operations:
|
|
||||||
sig: Operation.CallSignature = op.signature
|
|
||||||
if self.is_subtype(left, sig.left) and self.is_subtype(right, sig.right):
|
|
||||||
valid_operations.append(op)
|
|
||||||
|
|
||||||
if len(valid_operations) == 0:
|
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Undefined operation {method} between {left} and {right}",
|
f"Undefined operation {method} between {left} and {right}",
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
elif len(valid_operations) == 1:
|
|
||||||
self.logger.debug(f"Unique operation {method} between {left} and {right}")
|
|
||||||
return valid_operations[0].result
|
|
||||||
|
|
||||||
for i, op1 in enumerate(valid_operations):
|
|
||||||
sig1: Operation.CallSignature = op1.signature
|
|
||||||
best_match: bool = True
|
|
||||||
for j, op2 in enumerate(valid_operations):
|
|
||||||
if i == j:
|
|
||||||
continue
|
|
||||||
sig2: Operation.CallSignature = op2.signature
|
|
||||||
|
|
||||||
# If op1 is not a full overload of op2 (i.e. operands of op1 are subtypes of op2's)
|
|
||||||
# ambiguity -> not best match
|
|
||||||
if not self.is_subtype(sig1.left, sig2.left) or not self.is_subtype(
|
|
||||||
sig1.right, sig2.right
|
|
||||||
):
|
|
||||||
best_match = False
|
|
||||||
break
|
|
||||||
self.logger.debug(f"{op1} is a full overload of {op2}")
|
|
||||||
if best_match:
|
|
||||||
return op1.result
|
|
||||||
|
|
||||||
|
match operation:
|
||||||
|
case Function() as function:
|
||||||
|
if not self._is_binary_function(function):
|
||||||
self.reporter.error(
|
self.reporter.error(
|
||||||
expr.location,
|
expr.location,
|
||||||
f"Ambiguous operation {method} between {left} and {right}, multiple matching overloads: {', '.join(map(str, valid_operations))}",
|
f"Wrong definition of binary operation. Expected function with 2 positional-only parameters, got {function}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
|
||||||
|
rhs: Function.Argument = function.pos_args[0]
|
||||||
|
if not self.is_subtype(right, rhs.type):
|
||||||
|
self.reporter.error(
|
||||||
|
expr.location,
|
||||||
|
f"Wrong type for right-hand side, expected {rhs.type}, got {right}",
|
||||||
|
)
|
||||||
|
return UnknownType()
|
||||||
|
return function.returns
|
||||||
|
case _:
|
||||||
|
self.reporter.warning(
|
||||||
|
expr.location, f"Unsupported operation {operation}"
|
||||||
)
|
)
|
||||||
return UnknownType()
|
return UnknownType()
|
||||||
|
|
||||||
@@ -617,3 +606,12 @@ class PythonTyper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return mapped
|
return mapped
|
||||||
|
|
||||||
|
def _is_binary_function(self, function: Function) -> bool:
|
||||||
|
if len(function.pos_args) != 1:
|
||||||
|
return False
|
||||||
|
if len(function.args) != 0:
|
||||||
|
return False
|
||||||
|
if len(function.kw_args) != 0:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|||||||
Reference in New Issue
Block a user