From 7c69d172d253ccb87d8c6dedc0cf31e40d77c225 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 6 Jan 2025 20:09:55 +0000 Subject: [PATCH] EHN: cluster: JAX support (non-jitted) --- scipy/cluster/tests/test_hierarchy.py | 69 +++++++++++---------------- scipy/cluster/tests/test_vq.py | 10 ---- scipy/cluster/vq.py | 11 +++-- 3 files changed, 33 insertions(+), 57 deletions(-) diff --git a/scipy/cluster/tests/test_hierarchy.py b/scipy/cluster/tests/test_hierarchy.py index b8f568625980..5a0a0b9bfd53 100644 --- a/scipy/cluster/tests/test_hierarchy.py +++ b/scipy/cluster/tests/test_hierarchy.py @@ -50,6 +50,7 @@ from scipy.cluster._hierarchy import Heap from scipy.conftest import array_api_compatible from scipy._lib._array_api import xp_assert_close, xp_assert_equal +import scipy._lib.array_api_extra as xpx from threading import Lock @@ -445,8 +446,6 @@ def test_is_valid_linkage_4_and_up(self, xp): Z = linkage(y) assert_(is_valid_linkage(Z) is True) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_is_valid_linkage_4_and_up_neg_index_left(self, xp): # Tests is_valid_linkage(Z) on linkage on observation sets between # sizes 4 and 15 (step size 3) with negative indices (left). @@ -454,12 +453,10 @@ def test_is_valid_linkage_4_and_up_neg_index_left(self, xp): y = np.random.rand(i*(i-1)//2) y = xp.asarray(y) Z = linkage(y) - Z[i//2,0] = -2 + Z = xpx.at(Z)[i//2, 0].set(-2) assert_(is_valid_linkage(Z) is False) assert_raises(ValueError, is_valid_linkage, Z, throw=True) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_is_valid_linkage_4_and_up_neg_index_right(self, xp): # Tests is_valid_linkage(Z) on linkage on observation sets between # sizes 4 and 15 (step size 3) with negative indices (right). @@ -467,12 +464,10 @@ def test_is_valid_linkage_4_and_up_neg_index_right(self, xp): y = np.random.rand(i*(i-1)//2) y = xp.asarray(y) Z = linkage(y) - Z[i//2,1] = -2 + Z = xpx.at(Z)[i//2, 1].set(-2) assert_(is_valid_linkage(Z) is False) assert_raises(ValueError, is_valid_linkage, Z, throw=True) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_is_valid_linkage_4_and_up_neg_dist(self, xp): # Tests is_valid_linkage(Z) on linkage on observation sets between # sizes 4 and 15 (step size 3) with negative distances. @@ -480,12 +475,10 @@ def test_is_valid_linkage_4_and_up_neg_dist(self, xp): y = np.random.rand(i*(i-1)//2) y = xp.asarray(y) Z = linkage(y) - Z[i//2,2] = -0.5 + Z = xpx.at(Z)[i//2, 2].set(-0.5) assert_(is_valid_linkage(Z) is False) assert_raises(ValueError, is_valid_linkage, Z, throw=True) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_is_valid_linkage_4_and_up_neg_counts(self, xp): # Tests is_valid_linkage(Z) on linkage on observation sets between # sizes 4 and 15 (step size 3) with negative counts. @@ -493,7 +486,7 @@ def test_is_valid_linkage_4_and_up_neg_counts(self, xp): y = np.random.rand(i*(i-1)//2) y = xp.asarray(y) Z = linkage(y) - Z[i//2,3] = -2 + Z = xpx.at(Z)[i//2, 3].set(-2) assert_(is_valid_linkage(Z) is False) assert_raises(ValueError, is_valid_linkage, Z, throw=True) @@ -538,7 +531,6 @@ def test_is_valid_im_4_and_up(self, xp): R = inconsistent(Z) assert_(is_valid_im(R) is True) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment') def test_is_valid_im_4_and_up_neg_index_left(self, xp): # Tests is_valid_im(R) on im on observation sets between sizes 4 and 15 # (step size 3) with negative link height means. @@ -547,11 +539,10 @@ def test_is_valid_im_4_and_up_neg_index_left(self, xp): y = xp.asarray(y) Z = linkage(y) R = inconsistent(Z) - R[i//2,0] = -2.0 + R = xpx.at(R)[i//2 , 0].set(-2.0) assert_(is_valid_im(R) is False) assert_raises(ValueError, is_valid_im, R, throw=True) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment') def test_is_valid_im_4_and_up_neg_index_right(self, xp): # Tests is_valid_im(R) on im on observation sets between sizes 4 and 15 # (step size 3) with negative link height standard deviations. @@ -560,11 +551,10 @@ def test_is_valid_im_4_and_up_neg_index_right(self, xp): y = xp.asarray(y) Z = linkage(y) R = inconsistent(Z) - R[i//2,1] = -2.0 + R = xpx.at(R)[i//2 , 1].set(-2.0) assert_(is_valid_im(R) is False) assert_raises(ValueError, is_valid_im, R, throw=True) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment') def test_is_valid_im_4_and_up_neg_dist(self, xp): # Tests is_valid_im(R) on im on observation sets between sizes 4 and 15 # (step size 3) with negative link counts. @@ -573,7 +563,7 @@ def test_is_valid_im_4_and_up_neg_dist(self, xp): y = xp.asarray(y) Z = linkage(y) R = inconsistent(Z) - R[i//2,2] = -0.5 + R = xpx.at(R)[i//2, 2].set(-0.5) assert_(is_valid_im(R) is False) assert_raises(ValueError, is_valid_im, R, throw=True) @@ -766,12 +756,11 @@ def test_is_monotonic_tdist_linkage1(self, xp): Z = linkage(xp.asarray(hierarchy_test_data.ytdist), 'single') assert is_monotonic(Z) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment') def test_is_monotonic_tdist_linkage2(self, xp): # Tests is_monotonic(Z) on clustering generated by single linkage on # tdist data set. Perturbing. Expecting False. Z = linkage(xp.asarray(hierarchy_test_data.ytdist), 'single') - Z[2,2] = 0.0 + Z = xpx.at(Z)[2, 2].set(0.0) assert not is_monotonic(Z) def test_is_monotonic_Q_linkage(self, xp): @@ -790,7 +779,6 @@ def test_maxdists_empty_linkage(self, xp): Z = xp.zeros((0, 4), dtype=xp.float64) assert_raises(ValueError, maxdists, Z) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment') def test_maxdists_one_cluster_linkage(self, xp): # Tests maxdists(Z) on linkage with one cluster. Z = xp.asarray([[0, 1, 0.3, 4]], dtype=xp.float64) @@ -798,7 +786,6 @@ def test_maxdists_one_cluster_linkage(self, xp): expectedMD = calculate_maximum_distances(Z, xp) xp_assert_close(MD, expectedMD, atol=1e-15) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment') def test_maxdists_Q_linkage(self, xp): for method in ['single', 'complete', 'ward', 'centroid', 'median']: self.check_maxdists_Q_linkage(method, xp) @@ -829,8 +816,7 @@ def test_maxinconsts_difrow_linkage(self, xp): R = xp.asarray(R) assert_raises(ValueError, maxinconsts, Z, R) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment', - cpu_only=True) + @skip_xp_backends(cpu_only=True, reason="implicit device->host transfer") def test_maxinconsts_one_cluster_linkage(self, xp): # Tests maxinconsts(Z, R) on linkage with one cluster. Z = xp.asarray([[0, 1, 0.3, 4]], dtype=xp.float64) @@ -839,8 +825,7 @@ def test_maxinconsts_one_cluster_linkage(self, xp): expectedMD = calculate_maximum_inconsistencies(Z, R, xp=xp) xp_assert_close(MD, expectedMD, atol=1e-15) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment', - cpu_only=True) + @skip_xp_backends(cpu_only=True, reason="implicit device->host transfer") def test_maxinconsts_Q_linkage(self, xp): for method in ['single', 'complete', 'ward', 'centroid', 'median']: self.check_maxinconsts_Q_linkage(method, xp) @@ -893,8 +878,7 @@ def check_maxRstat_difrow_linkage(self, i, xp): R = xp.asarray(R) assert_raises(ValueError, maxRstat, Z, R, i) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment', - cpu_only=True) + @skip_xp_backends(cpu_only=True, reason="implicit device->host transfer") def test_maxRstat_one_cluster_linkage(self, xp): for i in range(4): self.check_maxRstat_one_cluster_linkage(i, xp) @@ -907,8 +891,7 @@ def check_maxRstat_one_cluster_linkage(self, i, xp): expectedMD = calculate_maximum_inconsistencies(Z, R, 1, xp) xp_assert_close(MD, expectedMD, atol=1e-15) - @skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment', - cpu_only=True) + @skip_xp_backends(cpu_only=True, reason="implicit device->host transfer") def test_maxRstat_Q_linkage(self, xp): for method in ['single', 'complete', 'ward', 'centroid', 'median']: for i in range(4): @@ -1129,17 +1112,18 @@ def calculate_maximum_distances(Z, xp): # Used for testing correctness of maxdists. n = Z.shape[0] + 1 B = xp.zeros((n-1,), dtype=Z.dtype) - q = xp.zeros((3,)) for i in range(0, n - 1): - q[:] = 0.0 + q = xp.zeros((3,)) left = Z[i, 0] right = Z[i, 1] if left >= n: - q[0] = B[xp.asarray(left, dtype=xp.int64) - n] + b_left = B[xp.asarray(left, dtype=xp.int64) - n] + q = xpx.at(q, 0).set(b_left) if right >= n: - q[1] = B[xp.asarray(right, dtype=xp.int64) - n] - q[2] = Z[i, 2] - B[i] = xp.max(q) + b_right = B[xp.asarray(right, dtype=xp.int64) - n] + q = xpx.at(q, 1).set(b_right) + q = xpx.at(q, 2).set(Z[i, 2]) + B = xpx.at(B, i).set(xp.max(q)) return B @@ -1148,17 +1132,18 @@ def calculate_maximum_inconsistencies(Z, R, k=3, xp=np): n = Z.shape[0] + 1 dtype = xp.result_type(Z, R) B = xp.zeros((n-1,), dtype=dtype) - q = xp.zeros((3,)) for i in range(0, n - 1): - q[:] = 0.0 + q = xp.zeros((3,)) left = Z[i, 0] right = Z[i, 1] if left >= n: - q[0] = B[xp.asarray(left, dtype=xp.int64) - n] + b_left = B[xp.asarray(left, dtype=xp.int64) - n] + q = xpx.at(q, 0).set(b_left) if right >= n: - q[1] = B[xp.asarray(right, dtype=xp.int64) - n] - q[2] = R[i, k] - B[i] = xp.max(q) + b_right = B[xp.asarray(right, dtype=xp.int64) - n] + q = xpx.at(q, 1).set(b_right) + q = xpx.at(q, 2).set(R[i, k]) + B = xpx.at(B, i).set(xp.max(q)) return B diff --git a/scipy/cluster/tests/test_vq.py b/scipy/cluster/tests/test_vq.py index d0321e7d81d7..38c7e7a5d922 100644 --- a/scipy/cluster/tests/test_vq.py +++ b/scipy/cluster/tests/test_vq.py @@ -100,8 +100,6 @@ def test_whiten(self, xp): def whiten_lock(self): return Lock() - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_whiten_zero_std(self, xp, whiten_lock): desired = xp.asarray([[0., 1.0, 2.86666544], [0., 1.0, 1.32460034], @@ -334,8 +332,6 @@ def test_kmeans2_high_dim(self, xp): data = xp.reshape(data, (20, 20))[:10, :] kmeans2(data, 2) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_kmeans2_init(self, xp): rng = np.random.default_rng(12345678) data = xp.asarray(TESTDATA_2D) @@ -390,8 +386,6 @@ def test_kmeans_large_thres(self, xp): xp_assert_close(res[0], xp.asarray([4.], dtype=xp.float64)) xp_assert_close(res[1], xp.asarray(2.3999999999999999, dtype=xp.float64)[()]) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_kmeans2_kpp_low_dim(self, xp): # Regression test for gh-11462 rng = np.random.default_rng(2358792345678234568) @@ -401,8 +395,6 @@ def test_kmeans2_kpp_low_dim(self, xp): xp_assert_close(res, prev_res) @pytest.mark.thread_unsafe - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_kmeans2_kpp_high_dim(self, xp): # Regression test for gh-11462 rng = np.random.default_rng(23587923456834568) @@ -427,8 +419,6 @@ def test_kmeans_diff_convergence(self, xp): xp_assert_close(res[0], xp.asarray([-0.4, 8.], dtype=xp.float64)) xp_assert_close(res[1], xp.asarray(1.0666666666666667, dtype=xp.float64)[()]) - @skip_xp_backends('jax.numpy', - reason='jax arrays do not support item assignment') def test_kmeans_and_kmeans2_random_seed(self, xp): seed_list = [ diff --git a/scipy/cluster/vq.py b/scipy/cluster/vq.py index 34045c1357fe..ddf4f0a66611 100644 --- a/scipy/cluster/vq.py +++ b/scipy/cluster/vq.py @@ -137,8 +137,8 @@ def whiten(obs, check_finite=True): obs = _asarray(obs, check_finite=check_finite, xp=xp) std_dev = xp.std(obs, axis=0) zero_std_mask = std_dev == 0 - if xp.any(zero_std_mask): - std_dev[zero_std_mask] = 1.0 + std_dev = xpx.at(std_dev, zero_std_mask).set(1.0) + if check_finite and xp.any(zero_std_mask): warnings.warn("Some columns have standard deviation zero. " "The values of these columns will not change.", RuntimeWarning, stacklevel=2) @@ -607,15 +607,16 @@ def _kpp(data, k, rng, xp): for i in range(k): if i == 0: - init[i, :] = data[rng_integers(rng, data.shape[0]), :] - + data_idx = rng_integers(rng, data.shape[0]) else: D2 = cdist(init[:i,:], data, metric='sqeuclidean').min(axis=0) probs = D2/D2.sum() cumprobs = probs.cumsum() r = rng.uniform() cumprobs = np.asarray(cumprobs) - init[i, :] = data[np.searchsorted(cumprobs, r), :] + data_idx = np.searchsorted(cumprobs, r) + + init = xpx.at(init)[i, :].set(data[data_idx, :]) if ndim == 1: init = init[:, 0]