Skip to content

Commit c7b7b01

Browse files
Jake VanderPlasjax authors
authored andcommitted
remove test of deprecated jax.random.shuffle API
PiperOrigin-RevId: 623499655
1 parent abfbb0a commit c7b7b01

File tree

1 file changed

+0
-16
lines changed

1 file changed

+0
-16
lines changed

tests/random_lax_test.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -210,22 +210,6 @@ def testTruncatedNormal(self, dtype):
210210
for samples in [uncompiled_samples, compiled_samples]:
211211
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf)
212212

213-
@jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer)
214-
def testShuffle(self, dtype):
215-
key = lambda: self.make_key(0)
216-
x = np.arange(100).astype(dtype)
217-
rand = lambda key: random.shuffle(key, x)
218-
crand = jax.jit(rand)
219-
220-
with self.assertWarns((DeprecationWarning, FutureWarning)):
221-
perm1 = rand(key())
222-
with self.assertWarns((DeprecationWarning, FutureWarning)):
223-
perm2 = crand(key())
224-
225-
self.assertAllClose(perm1, perm2)
226-
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
227-
self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
228-
229213
@jtu.sample_product(
230214
[dict(shape=shape, replace=replace, axis=axis,
231215
input_range_or_shape=input_range_or_shape)

0 commit comments

Comments
 (0)