Skip to content

Commit 3eff032

Browse files
committed
[key reuse] define rule for lax.concatenate
1 parent b6e985f commit 3eff032

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

jax/experimental/key_reuse/_core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,16 @@ def _slice_signature(eqn):
357357

358358
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
359359

360+
def _concatenate_signature(eqn):
361+
num_vals = len(eqn.invars)
362+
# TODO(jakevdp): should this signature be more granular?
363+
if num_vals == 1:
364+
return KeyReuseSignature(Forward(0, 0))
365+
else:
366+
return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0))
367+
368+
key_reuse_signatures_dynamic[lax.concatenate_p] = _concatenate_signature
369+
360370
def _pjit_key_type_signature(eqn):
361371
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
362372

tests/key_reuse_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ def f(key):
209209
assert_consumed(key2)
210210
self.check_key_reuse(f, jax.random.key(0))
211211

212+
def test_concatenate(self):
213+
def f(key1, key2):
214+
assert_unconsumed(key1)
215+
assert_unconsumed(key2)
216+
keys = jax.lax.concatenate([key1, key2], dimension=0)
217+
assert_consumed(key1)
218+
assert_consumed(key2)
219+
assert_unconsumed(keys)
220+
key1 = jax.random.split(jax.random.key(0))
221+
key2 = jax.random.split(jax.random.key(1))
222+
self.check_key_reuse(f, key1, key2)
223+
212224
def test_slice(self):
213225
def f(keys):
214226
assert_unconsumed(keys)

0 commit comments

Comments
 (0)