Skip to content

Commit 5f2d99e

Browse files
dschulttylerjereddy
authored andcommitted
revert NotImplemented return values in dot & others (scipy#22373)
1 parent 1761a4b commit 5f2d99e

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

scipy/sparse/_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,8 @@ def _matmul_dispatch(self, other):
685685

686686
return result
687687

688-
return NotImplemented
688+
else:
689+
raise ValueError('could not interpret dimensions')
689690

690691
def __mul__(self, other):
691692
return self.multiply(other)

scipy/sparse/_coo.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -866,10 +866,7 @@ def dot(self, other):
866866
o_array = np.asanyarray(other)
867867

868868
if o_array.ndim == 0 and o_array.dtype == np.object_:
869-
# Not interpretable as an array; return NotImplemented so that
870-
# other's __rmatmul__ can kick in if that's implemented.
871-
return NotImplemented
872-
869+
raise TypeError(f"dot argument not supported type: '{type(other)}'")
873870
try:
874871
other.shape
875872
except AttributeError:
@@ -1082,9 +1079,7 @@ def tensordot(self, other, axes=2):
10821079
other_array = np.asanyarray(other)
10831080

10841081
if other_array.ndim == 0 and other_array.dtype == np.object_:
1085-
# Not interpretable as an array; return NotImplemented
1086-
return NotImplemented
1087-
1082+
raise TypeError(f"tensordot arg not supported type: '{type(other)}'")
10881083
try:
10891084
other.shape
10901085
except AttributeError:

scipy/sparse/tests/test_coo.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,18 @@ def test_dot_with_inconsistent_shapes():
719719
arr_a.dot(arr_b)
720720

721721

722+
def test_matmul_dot_not_implemented():
723+
arr_a = coo_array([[1, 2], [3, 4]])
724+
with pytest.raises(TypeError, match="argument not supported type"):
725+
arr_a.dot(None)
726+
with pytest.raises(TypeError, match="arg not supported type"):
727+
arr_a.tensordot(None)
728+
with pytest.raises(TypeError, match="unsupported operand type"):
729+
arr_a @ None
730+
with pytest.raises(TypeError, match="unsupported operand type"):
731+
None @ arr_a
732+
733+
722734
dot_shapes = [
723735
((3,3), (3,3)), ((4,6), (6,7)), ((1,4), (4,1)), # matrix multiplication 2-D
724736
((3,2,4,7), (7,)), ((5,), (6,3,5,2)), # dot of n-D and 1-D arrays

0 commit comments

Comments
 (0)