Skip to content

Commit 29edfd8

Browse files
committed
define a loop-free untrue batching rule for rng_bit_generator
1 parent f0afc1b commit 29edfd8

File tree

5 files changed

+82
-16
lines changed

5 files changed

+82
-16
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,16 +2046,18 @@ def map(f, xs):
20462046
return ys
20472047

20482048
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
2049-
"""Calls RBG in a loop and stacks the results."""
2050-
key, = batched_args
2049+
keys, = batched_args
20512050
bd, = batch_dims
20522051
if bd is batching.not_mapped:
2053-
return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
2052+
return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype,
20542053
algorithm=algorithm), (None, None)
2055-
key = batching.moveaxis(key, bd, 0)
2056-
map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm)
2057-
stacked_keys, stacked_bits = map(map_body, key)
2058-
return (stacked_keys, stacked_bits), (0, 0)
2054+
keys = batching.moveaxis(keys, bd, 0)
2055+
batch_size = keys.shape[0]
2056+
key = keys[0]
2057+
new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape),
2058+
dtype=dtype, algorithm=algorithm)
2059+
new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0)
2060+
return (new_keys, bits), (0, 0)
20592061

20602062
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore
20612063

jax/_src/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1233,7 +1233,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False):
12331233
keys = keys.flatten()
12341234
alphas = a.flatten()
12351235

1236-
if use_vmap:
1236+
if use_vmap and _key_impl(key) is prng.threefry_prng_impl:
12371237
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
12381238
else:
12391239
samples = lax.map(

tests/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,9 @@ jax_test(
784784
"notsan", # Times out
785785
],
786786
},
787+
backend_variant_args = {
788+
"gpu": ["--jax_num_generated_cases=40"],
789+
},
787790
shard_count = {
788791
"cpu": 40,
789792
"gpu": 30,

tests/lax_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,6 +2652,24 @@ def testRngBitGeneratorReturnedKey(self):
26522652
new_key, _ = lax.rng_bit_generator(key, (0,))
26532653
self.assertAllClose(key, new_key)
26542654

2655+
def test_rng_bit_generator_vmap(self):
2656+
def f(key):
2657+
return lax.rng_bit_generator(key, shape=(5, 7))
2658+
2659+
keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32)
2660+
out_keys, bits = jax.vmap(f)(keys)
2661+
self.assertEqual(out_keys.shape, (3, 4))
2662+
self.assertEqual(bits.shape, (3, 5, 7))
2663+
2664+
def test_rng_bit_generator_vmap_vmap(self):
2665+
def f(key):
2666+
return lax.rng_bit_generator(key, shape=(5, 7))
2667+
2668+
keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32)
2669+
out_keys, bits = jax.vmap(jax.vmap(f))(keys)
2670+
self.assertEqual(out_keys.shape, (2, 3, 4))
2671+
self.assertEqual(bits.shape, (2, 3, 5, 7))
2672+
26552673
@jtu.sample_product(
26562674
dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types,
26572675
weak_type=[True, False],

tests/random_lax_test.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,6 +1348,7 @@ def test_vmap_fold_in_shape(self):
13481348
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
13491349
self.assertEqual(out.shape, (3, 2))
13501350

1351+
@jax.enable_key_reuse_checks(False)
13511352
def test_vmap_split_mapped_key(self):
13521353
key = self.make_key(73)
13531354
mapped_keys = random.split(key, num=3)
@@ -1408,24 +1409,57 @@ def test_vmap_split_not_mapped_key(self):
14081409
self.assertArraysEqual(random.key_data(vk),
14091410
random.key_data(single_split_key))
14101411

1411-
def test_vmap_split_mapped_key(self):
1412+
@jax.enable_key_reuse_checks(False)
1413+
def test_vmap_split_mapped_key_shape(self):
14121414
key = self.make_key(73)
14131415
mapped_keys = random.split(key, num=3)
1414-
forloop_keys = [random.split(k) for k in mapped_keys]
14151416
vmapped_keys = vmap(random.split)(mapped_keys)
14161417
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
1417-
for fk, vk in zip(forloop_keys, vmapped_keys):
1418-
self.assertArraysEqual(random.key_data(fk),
1418+
1419+
@jax.enable_key_reuse_checks(False)
1420+
def test_vmap_split_mapped_key_values(self):
1421+
key = self.make_key(73)
1422+
mapped_keys = random.split(key, num=3)
1423+
vmapped_keys = vmap(random.split)(mapped_keys)
1424+
ref_keys = [random.split(k) for k in mapped_keys]
1425+
for rk, vk in zip(ref_keys, vmapped_keys):
1426+
self.assertArraysEqual(random.key_data(rk),
14191427
random.key_data(vk))
14201428

1421-
def test_vmap_random_bits(self):
1422-
rand_fun = lambda key: random.randint(key, (), 0, 100)
1429+
@jax.enable_key_reuse_checks(False)
1430+
def test_vmap_random_bits_shape(self):
1431+
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
14231432
key = self.make_key(73)
14241433
mapped_keys = random.split(key, num=3)
1425-
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
14261434
rand_nums = vmap(rand_fun)(mapped_keys)
14271435
self.assertEqual(rand_nums.shape, (3,))
1428-
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
1436+
1437+
@jtu.skip_on_devices("tpu")
1438+
@jax.enable_key_reuse_checks(False)
1439+
def test_vmap_random_bits_value(self):
1440+
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
1441+
key = self.make_key(73)
1442+
mapped_keys = random.split(key, num=3)
1443+
rand_nums = vmap(rand_fun)(mapped_keys)
1444+
ref_nums = rand_fun(mapped_keys[0], shape=(3,))
1445+
self.assertArraysEqual(rand_nums, ref_nums)
1446+
1447+
def test_vmap_random_bits_distribution(self):
1448+
dtype = jnp.float32
1449+
keys = lambda: jax.random.split(self.make_key(0), 10)
1450+
1451+
def rand(key):
1452+
nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key)
1453+
return nums.flatten()
1454+
1455+
crand = jax.jit(rand)
1456+
1457+
uncompiled_samples = rand(keys())
1458+
compiled_samples = crand(keys())
1459+
1460+
for samples in [uncompiled_samples, compiled_samples]:
1461+
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
1462+
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
14291463

14301464
def test_cannot_add(self):
14311465
key = self.make_key(73)
@@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
14551489
def make_key(self, seed):
14561490
return random.PRNGKey(seed, impl="unsafe_rbg")
14571491

1492+
@jtu.skip_on_devices("tpu")
1493+
@jax.enable_key_reuse_checks(False)
1494+
def test_vmap_split_mapped_key_values(self):
1495+
key = self.make_key(73)
1496+
mapped_keys = random.split(key, num=3)
1497+
vmapped_keys = vmap(random.split)(mapped_keys)
1498+
ref_keys = random.split(mapped_keys[0], (3, 2))
1499+
self.assertArraysEqual(random.key_data(vmapped_keys),
1500+
random.key_data(ref_keys))
14581501

14591502
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
14601503
raise SkipTest('sampler only implemented for default RNG')

0 commit comments

Comments
 (0)