File tree Expand file tree Collapse file tree 2 files changed +22
-0
lines changed
jax/experimental/key_reuse Expand file tree Collapse file tree 2 files changed +22
-0
lines changed Original file line number Diff line number Diff line change @@ -373,6 +373,16 @@ def _slice_signature(eqn):
373
373
374
374
key_reuse_signatures_dynamic [lax .slice_p ] = _slice_signature
375
375
376
+ def _concatenate_signature (eqn ):
377
+ num_vals = len (eqn .invars )
378
+ # TODO(jakevdp): should this signature be more granular?
379
+ if num_vals == 1 :
380
+ return KeyReuseSignature (Forward (0 , 0 ))
381
+ else :
382
+ return KeyReuseSignature (* (Sink (i ) for i in range (num_vals )), Source (0 ))
383
+
384
+ key_reuse_signatures_dynamic [lax .concatenate_p ] = _concatenate_signature
385
+
376
386
def _pjit_key_type_signature (eqn ):
377
387
return get_jaxpr_type_signature (eqn .params ['jaxpr' ].jaxpr )
378
388
Original file line number Diff line number Diff line change @@ -209,6 +209,18 @@ def f(key):
209
209
assert_consumed (key2 )
210
210
self .check_key_reuse (f , jax .random .key (0 ))
211
211
212
+ def test_concatenate (self ):
213
+ def f (key1 , key2 ):
214
+ assert_unconsumed (key1 )
215
+ assert_unconsumed (key2 )
216
+ keys = jax .lax .concatenate ([key1 , key2 ], dimension = 0 )
217
+ assert_consumed (key1 )
218
+ assert_consumed (key2 )
219
+ assert_unconsumed (keys )
220
+ key1 = jax .random .split (jax .random .key (0 ))
221
+ key2 = jax .random .split (jax .random .key (1 ))
222
+ self .check_key_reuse (f , key1 , key2 )
223
+
212
224
def test_slice (self ):
213
225
def f (keys ):
214
226
assert_unconsumed (keys )
You can’t perform that action at this time.
0 commit comments