Skip to content

Commit 246c740

Browse files
PTNobeldschult
authored andcommitted
BUG: sparse: fix selecting wrong dtype for coo coords (scipy#22353)
* Fixes bug with selecting wrong dtype for coo coords * add test for reshape having wrong dtype * test that smaller dtype is maintained even if intermediate values are big --------- Co-authored-by: Dan Schult <dschult@colgate.edu>
1 parent c9a6140 commit 246c740

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

scipy/sparse/_coo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def reshape(self, *args, **kwargs):
148148
else:
149149
new_coords = np.unravel_index(flat_coords, shape, order=order)
150150

151-
idx_dtype = self._get_index_dtype(self.coords, maxval=max(self.shape))
151+
idx_dtype = self._get_index_dtype(self.coords, maxval=max(shape))
152152
new_coords = tuple(np.asarray(co, dtype=idx_dtype) for co in new_coords)
153153

154154
# Handle copy here rather than passing on to the constructor so that no

scipy/sparse/tests/test_coo.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,22 @@ def test_non_subscriptability():
129129
match="'coo_array' object is not subscriptable"):
130130
coo_2d[0, :]
131131

132+
def test_reshape_overflow():
133+
# see gh-22353 : new idx_dtype can need to be int64 instead of int32
134+
M, N = (1045507, 523266)
135+
coords = (np.array([M - 1], dtype='int32'), np.array([N - 1], dtype='int32'))
136+
A = coo_array(([3.3], coords), shape=(M, N))
137+
138+
# need new idx_dtype to not overflow
139+
B = A.reshape((M * N, 1))
140+
assert B.coords[0].dtype == np.dtype('int64')
141+
assert B.coords[0][0] == (M * N) - 1
142+
143+
# need idx_dtype to stay int32 if before and after can be int32
144+
C = A.reshape(N, M)
145+
assert C.coords[0].dtype == np.dtype('int32')
146+
assert C.coords[0][0] == N - 1
147+
132148
def test_reshape():
133149
arr1d = coo_array([1, 0, 3])
134150
assert arr1d.shape == (3,)

0 commit comments

Comments
 (0)