Skip to content

Commit 6353877

Browse files
author
jax authors
committed
Merge pull request #20183 from jakevdp:key-reuse-concatenate
PiperOrigin-RevId: 614851106
2 parents 8ae93d5 + 3eff032 commit 6353877

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

jax/experimental/key_reuse/_core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,16 @@ def _slice_signature(eqn):
373373

374374
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
375375

376+
def _concatenate_signature(eqn):
377+
num_vals = len(eqn.invars)
378+
# TODO(jakevdp): should this signature be more granular?
379+
if num_vals == 1:
380+
return KeyReuseSignature(Forward(0, 0))
381+
else:
382+
return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0))
383+
384+
key_reuse_signatures_dynamic[lax.concatenate_p] = _concatenate_signature
385+
376386
def _pjit_key_type_signature(eqn):
377387
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
378388

tests/key_reuse_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ def f(key):
209209
assert_consumed(key2)
210210
self.check_key_reuse(f, jax.random.key(0))
211211

212+
def test_concatenate(self):
213+
def f(key1, key2):
214+
assert_unconsumed(key1)
215+
assert_unconsumed(key2)
216+
keys = jax.lax.concatenate([key1, key2], dimension=0)
217+
assert_consumed(key1)
218+
assert_consumed(key2)
219+
assert_unconsumed(keys)
220+
key1 = jax.random.split(jax.random.key(0))
221+
key2 = jax.random.split(jax.random.key(1))
222+
self.check_key_reuse(f, key1, key2)
223+
212224
def test_slice(self):
213225
def f(keys):
214226
assert_unconsumed(keys)

0 commit comments

Comments
 (0)