@@ -1348,6 +1348,7 @@ def test_vmap_fold_in_shape(self):
1348
1348
out = vmap (vmap (random .fold_in ), in_axes = (1 , 0 ))(keys (), msgs .T )
1349
1349
self .assertEqual (out .shape , (3 , 2 ))
1350
1350
1351
+ @jax .enable_key_reuse_checks (False )
1351
1352
def test_vmap_split_mapped_key (self ):
1352
1353
key = self .make_key (73 )
1353
1354
mapped_keys = random .split (key , num = 3 )
@@ -1408,24 +1409,57 @@ def test_vmap_split_not_mapped_key(self):
1408
1409
self .assertArraysEqual (random .key_data (vk ),
1409
1410
random .key_data (single_split_key ))
1410
1411
1411
- def test_vmap_split_mapped_key (self ):
1412
+ @jax .enable_key_reuse_checks (False )
1413
+ def test_vmap_split_mapped_key_shape (self ):
1412
1414
key = self .make_key (73 )
1413
1415
mapped_keys = random .split (key , num = 3 )
1414
- forloop_keys = [random .split (k ) for k in mapped_keys ]
1415
1416
vmapped_keys = vmap (random .split )(mapped_keys )
1416
1417
self .assertEqual (vmapped_keys .shape , (3 , 2 , * key .shape ))
1417
- for fk , vk in zip (forloop_keys , vmapped_keys ):
1418
- self .assertArraysEqual (random .key_data (fk ),
1418
+
1419
+ @jax .enable_key_reuse_checks (False )
1420
+ def test_vmap_split_mapped_key_values (self ):
1421
+ key = self .make_key (73 )
1422
+ mapped_keys = random .split (key , num = 3 )
1423
+ vmapped_keys = vmap (random .split )(mapped_keys )
1424
+ ref_keys = [random .split (k ) for k in mapped_keys ]
1425
+ for rk , vk in zip (ref_keys , vmapped_keys ):
1426
+ self .assertArraysEqual (random .key_data (rk ),
1419
1427
random .key_data (vk ))
1420
1428
1421
- def test_vmap_random_bits (self ):
1422
- rand_fun = lambda key : random .randint (key , (), 0 , 100 )
1429
+ @jax .enable_key_reuse_checks (False )
1430
+ def test_vmap_random_bits_shape (self ):
1431
+ rand_fun = lambda key , shape = (): random .randint (key , shape , 0 , 100 )
1423
1432
key = self .make_key (73 )
1424
1433
mapped_keys = random .split (key , num = 3 )
1425
- forloop_rand_nums = [rand_fun (k ) for k in mapped_keys ]
1426
1434
rand_nums = vmap (rand_fun )(mapped_keys )
1427
1435
self .assertEqual (rand_nums .shape , (3 ,))
1428
- self .assertArraysEqual (rand_nums , jnp .array (forloop_rand_nums ))
1436
+
1437
+ @jtu .skip_on_devices ("tpu" )
1438
+ @jax .enable_key_reuse_checks (False )
1439
+ def test_vmap_random_bits_value (self ):
1440
+ rand_fun = lambda key , shape = (): random .randint (key , shape , 0 , 100 )
1441
+ key = self .make_key (73 )
1442
+ mapped_keys = random .split (key , num = 3 )
1443
+ rand_nums = vmap (rand_fun )(mapped_keys )
1444
+ ref_nums = rand_fun (mapped_keys [0 ], shape = (3 ,))
1445
+ self .assertArraysEqual (rand_nums , ref_nums )
1446
+
1447
+ def test_vmap_random_bits_distribution (self ):
1448
+ dtype = jnp .float32
1449
+ keys = lambda : jax .random .split (self .make_key (0 ), 10 )
1450
+
1451
+ def rand (key ):
1452
+ nums = jax .vmap (lambda key : random .uniform (key , (1000 ,), dtype ))(key )
1453
+ return nums .flatten ()
1454
+
1455
+ crand = jax .jit (rand )
1456
+
1457
+ uncompiled_samples = rand (keys ())
1458
+ compiled_samples = crand (keys ())
1459
+
1460
+ for samples in [uncompiled_samples , compiled_samples ]:
1461
+ self ._CheckCollisions (samples , jnp .finfo (dtype ).nmant )
1462
+ self ._CheckKolmogorovSmirnovCDF (samples , scipy .stats .uniform ().cdf )
1429
1463
1430
1464
def test_cannot_add (self ):
1431
1465
key = self .make_key (73 )
@@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
1455
1489
def make_key (self , seed ):
1456
1490
return random .PRNGKey (seed , impl = "unsafe_rbg" )
1457
1491
1492
+ @jtu .skip_on_devices ("tpu" )
1493
+ @jax .enable_key_reuse_checks (False )
1494
+ def test_vmap_split_mapped_key_values (self ):
1495
+ key = self .make_key (73 )
1496
+ mapped_keys = random .split (key , num = 3 )
1497
+ vmapped_keys = vmap (random .split )(mapped_keys )
1498
+ ref_keys = random .split (mapped_keys [0 ], (3 , 2 ))
1499
+ self .assertArraysEqual (random .key_data (vmapped_keys ),
1500
+ random .key_data (ref_keys ))
1458
1501
1459
1502
def _sampler_unimplemented_with_custom_prng (* args , ** kwargs ):
1460
1503
raise SkipTest ('sampler only implemented for default RNG' )
0 commit comments