Skip to content

Commit 59e9ee3

Browse files
author
jax authors
committed
Merge pull request #20142 from jakevdp:key-reuse-sig
PiperOrigin-RevId: 614024760
2 parents 777209c + 0644f19 commit 59e9ee3

File tree

2 files changed: +181 additions, -61 deletions

jax/experimental/key_reuse/_core.py

Lines changed: 110 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
from collections import defaultdict
1818
from functools import partial, reduce, wraps
19-
from typing import Any, Callable, NamedTuple
19+
from typing import Any, Callable, Iterator, NamedTuple
2020

2121
import jax
2222
from jax import lax
@@ -41,36 +41,91 @@
4141
import numpy as np
4242

4343

44-
class Sink(NamedTuple):
44+
# Create Source() and Sink() objects which validate inputs, have
45+
# correct equality semantics, and are hashable & immutable.
46+
class _SourceSinkBase:
4547
idx: int
46-
mask: bool | np.ndarray = True
48+
mask: bool | np.ndarray
49+
50+
def __init__(self, idx: int, mask: bool | np.bool_ | np.ndarray = True):
51+
assert isinstance(idx, int)
52+
if isinstance(mask, np.ndarray):
53+
assert mask.dtype == np.dtype('bool')
54+
if np.all(mask):
55+
mask = True
56+
elif not np.any(mask):
57+
mask = False
58+
elif mask.flags.writeable:
59+
mask = np.array(mask, copy=True)
60+
mask.flags.writeable = False
61+
elif isinstance(mask, np.bool_):
62+
mask = bool(mask)
63+
else:
64+
assert isinstance(mask, bool)
65+
super().__setattr__("idx", idx)
66+
super().__setattr__("mask", mask)
67+
68+
def __setattr__(self, *args, **kwargs):
69+
raise ValueError(f"{self.__class__.__name__} is immutable")
70+
71+
def __eq__(self, other):
72+
return (self.__class__ == other.__class__
73+
and self.idx == other.idx
74+
and np.shape(self.mask) == np.shape(other.mask)
75+
and np.all(self.mask == other.mask))
76+
77+
def __hash__(self):
78+
if isinstance(self.mask, bool):
79+
return hash((self.__class__, self.idx, self.mask))
80+
else:
81+
mask = np.asarray(self.mask)
82+
return hash((self.__class__, self.idx, mask.shape,
83+
tuple(mask.flatten().tolist())))
4784

4885
def __repr__(self):
49-
if isinstance(self.mask, bool) and self.mask:
50-
return f"Sink({self.idx})"
51-
else:
52-
return f"Sink({self.idx}, mask={self.mask})"
86+
if self.mask is True:
87+
return f"{self.__class__.__name__}({self.idx})"
88+
return f"{self.__class__.__name__}({self.idx}, {self.mask})"
5389

5490

55-
class Source(NamedTuple):
56-
idx: int
57-
mask: bool | np.ndarray = True
91+
class Sink(_SourceSinkBase):
92+
pass
93+
94+
95+
class Source(_SourceSinkBase):
96+
pass
5897

59-
def __repr__(self):
60-
if isinstance(self.mask, bool) and self.mask:
61-
return f"Source({self.idx})"
62-
else:
63-
return f"Source({self.idx}, mask={self.mask})"
6498

6599
class Forward(NamedTuple):
66100
in_idx: int
67101
out_idx: int
68102

69103

70-
class KeyReuseSignature(NamedTuple):
71-
sinks: list[Sink]
72-
sources: list[Source]
73-
forwards: list[Forward] = []
104+
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
105+
# objects, with a few convenience methods related to key reuse checking.
106+
class KeyReuseSignature:
107+
_args: frozenset[Source | Sink | Forward]
108+
109+
def __init__(self, *args):
110+
self._args = frozenset(args)
111+
112+
def __eq__(self, other):
113+
return isinstance(other, KeyReuseSignature) and self._args == other._args
114+
115+
def __hash__(self):
116+
return hash(self._args)
117+
118+
@property
119+
def sinks(self) -> Iterator[Sink]:
120+
yield from (s for s in self._args if isinstance(s, Sink))
121+
122+
@property
123+
def sources(self) -> Iterator[Source]:
124+
yield from (s for s in self._args if isinstance(s, Source))
125+
126+
@property
127+
def forwards(self) -> Iterator[Forward]:
128+
yield from (s for s in self._args if isinstance(s, Forward))
74129

75130
def check_signature(self, *args, funcname="function", context=None):
76131
for sink in self.sinks:
@@ -145,34 +200,33 @@ def _check_consumed_value(eqn, consumed):
145200
# The behavior of most primitives can be described via simple signatures.
146201
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
147202

148-
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)])
149-
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)])
150-
key_reuse_signatures[random.random_clone_p] = KeyReuseSignature([], [Source(0)])
151-
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
203+
key_reuse_signatures[consume_p] = KeyReuseSignature(Sink(0), Forward(0, 0))
204+
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature(Forward(0, 0))
205+
key_reuse_signatures[random.random_clone_p] = KeyReuseSignature(Source(0))
206+
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature(Sink(0))
152207
# TODO(jakevdp): should fold_in sink its input key?
153-
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
154-
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)])
155-
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
156-
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
157-
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])
208+
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature(Source(0))
209+
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature(Source(0))
210+
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature(Sink(0), Source(0))
211+
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0))
158212
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
159-
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)])
160-
key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)])
161-
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)])
162-
key_reuse_signatures[lax.device_put_p] = KeyReuseSignature([], [], [Forward(0, 0)])
163-
key_reuse_signatures[lax.reshape_p] = KeyReuseSignature([], [], [Forward(0, 0)])
164-
key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature([], [], [Forward(0, 0)])
165-
key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)], [])
213+
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0))
214+
key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0))
215+
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0))
216+
key_reuse_signatures[lax.device_put_p] = KeyReuseSignature(Forward(0, 0))
217+
key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0))
218+
key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0))
219+
key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0))
166220
# TODO(jakevdp): should unwrap sink its input key?
167-
key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([], [], [])
168-
key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], [])
169-
key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0, 0)])
170-
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)])
171-
key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)])
172-
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)])
221+
key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature()
222+
key_reuse_signatures[debug_callback_p] = KeyReuseSignature()
223+
key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature(Forward(0, 0))
224+
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature(Sink(1), Forward(0, 0))
225+
key_reuse_signatures[lax.gather_p] = KeyReuseSignature(Forward(0, 0))
226+
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature(Sink(2), Forward(0, 0))
173227
# Equality checks don't consume
174-
key_reuse_signatures[lax.eq_p] = KeyReuseSignature([], [], [])
175-
key_reuse_signatures[lax.ne_p] = KeyReuseSignature([], [], [])
228+
key_reuse_signatures[lax.eq_p] = KeyReuseSignature()
229+
key_reuse_signatures[lax.ne_p] = KeyReuseSignature()
176230

177231
# Rules which require more dynamic logic.
178232
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}
@@ -182,8 +236,7 @@ def unknown_signature(eqn):
182236
def is_key(var: core.Atom):
183237
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
184238
return KeyReuseSignature(
185-
sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)],
186-
sources=[],
239+
*(Sink(idx) for idx, var in enumerate(eqn.invars) if is_key(var))
187240
)
188241

189242
@weakref_lru_cache
@@ -216,7 +269,6 @@ def sink(var: core.Atom, mask=True):
216269
return True
217270
consumed[var] = np.logical_or(consumed.get(var, False), mask)
218271

219-
220272
def source(var: core.Atom, mask=False):
221273
if not is_key(var):
222274
return
@@ -262,13 +314,13 @@ def is_consumed(var: core.Atom):
262314
source(eqn.outvars[src.idx])
263315

264316
return KeyReuseSignature(
265-
sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars)
266-
if is_key(v) and np.any(consumed.get(v, False))],
267-
sources=[Source(i) for i, v in enumerate(jaxpr.outvars)
268-
if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)],
269-
forwards=[Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type]
270-
for idx_out, outvar in enumerate(jaxpr.outvars)
271-
if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars]
317+
*(Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars)
318+
if is_key(v) and np.any(consumed.get(v, False))),
319+
*(Source(i) for i, v in enumerate(jaxpr.outvars)
320+
if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)),
321+
*(Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type]
322+
for idx_out, outvar in enumerate(jaxpr.outvars)
323+
if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars)
272324
)
273325

274326

@@ -292,16 +344,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
292344
def _slice_signature(eqn):
293345
in_aval = eqn.invars[0].aval
294346
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
295-
return KeyReuseSignature([], [], [Forward(0, 0)])
347+
return KeyReuseSignature(Forward(0, 0))
296348
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
297-
return KeyReuseSignature([], [], [Forward(0, 0)])
349+
return KeyReuseSignature(Forward(0, 0))
298350
start_indices = eqn.params['start_indices']
299351
limit_indices = eqn.params['limit_indices']
300352
strides = eqn.params['strides'] or (1,) * len(start_indices)
301353
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
302354
sink = np.zeros(in_aval.shape, dtype=bool)
303355
sink[idx] = True
304-
return KeyReuseSignature([Sink(0, sink)], [Source(0)])
356+
return KeyReuseSignature(Sink(0, sink), Source(0))
305357

306358
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
307359

@@ -329,7 +381,7 @@ def _cond_key_type_signature(eqn):
329381
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
330382
combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in
331383
set.intersection(*(set(sig.forwards) for sig in signatures))]
332-
return KeyReuseSignature(combined_sinks, combined_sources, combined_forwards)
384+
return KeyReuseSignature(*combined_sinks, *combined_sources, *combined_forwards)
333385

334386
key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature
335387

@@ -410,7 +462,7 @@ def _remat_key_type_signature(eqn):
410462
# 2) will never create keys
411463
# Therefore, the differentiated pass is a no-op.
412464
if eqn.params['differentiated']:
413-
return KeyReuseSignature([], [])
465+
return KeyReuseSignature()
414466
return get_jaxpr_type_signature(eqn.params['jaxpr'])
415467

416468
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature

tests/key_reuse_test.py

Lines changed: 71 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@
2525
from jax._src import test_util as jtu
2626
from jax.errors import KeyReuseError
2727
from jax.experimental.key_reuse._core import (
28-
assert_consumed, assert_unconsumed, consume, consume_p)
28+
assert_consumed, assert_unconsumed, consume, consume_p,
29+
Source, Sink, Forward, KeyReuseSignature)
2930
from jax.experimental.key_reuse import _core
3031

3132
from jax import config
@@ -589,7 +590,7 @@ def f_good(x, key):
589590

590591

591592
@jtu.with_config(jax_enable_key_reuse_checks=True)
592-
class KeyReuseEager(jtu.JaxTestCase):
593+
class KeyReuseEagerTest(jtu.JaxTestCase):
593594
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
594595
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
595596
traced_bits_msg = "In random_bits, argument 0 is already consumed."
@@ -616,9 +617,76 @@ def f():
616617
f()
617618

618619

620+
class KeyReuseImplementationTest(jtu.JaxTestCase):
621+
622+
def assertEquivalent(self, a, b):
623+
self.assertEqual(a, b)
624+
self.assertEqual(hash(a), hash(b))
625+
626+
def assertNotEquivalent(self, a, b):
627+
self.assertNotEqual(a, b)
628+
self.assertNotEqual(hash(a), hash(b))
629+
630+
def test_source_sink_immutability(self):
631+
mask = np.array([True, False])
632+
orig_mask_writeable = mask.flags.writeable
633+
634+
sink = Sink(0, mask)
635+
source = Source(0, mask)
636+
637+
self.assertFalse(sink.mask.flags.writeable)
638+
self.assertFalse(source.mask.flags.writeable)
639+
self.assertEqual(mask.flags.writeable, orig_mask_writeable)
640+
641+
with self.assertRaises(ValueError):
642+
sink.idx = 1
643+
with self.assertRaises(ValueError):
644+
sink.mask = True
645+
with self.assertRaises(ValueError):
646+
source.idx = 1
647+
with self.assertRaises(ValueError):
648+
source.mask = True
649+
650+
def test_source_sink_forward_equivalence_semantics(self):
651+
652+
true_mask = np.array([True, True])
653+
false_mask = np.array([False, False])
654+
mixed_mask = np.array([True, False])
655+
656+
self.assertEquivalent(Source(0), Source(0, True))
657+
self.assertEquivalent(Source(0, True), Source(0, true_mask))
658+
self.assertEquivalent(Source(0, False), Source(0, false_mask))
659+
self.assertEquivalent(Source(0, mixed_mask), Source(0, mixed_mask))
660+
self.assertNotEquivalent(Source(0), Source(1))
661+
self.assertNotEquivalent(Source(0), Source(0, False))
662+
self.assertNotEquivalent(Source(0), Source(0, mixed_mask))
663+
664+
self.assertEquivalent(Sink(0), Sink(0, True))
665+
self.assertEquivalent(Sink(0, True), Sink(0, true_mask))
666+
self.assertEquivalent(Sink(0, False), Sink(0, false_mask))
667+
self.assertEquivalent(Sink(0, mixed_mask), Sink(0, mixed_mask))
668+
self.assertNotEquivalent(Sink(0), Sink(1))
669+
self.assertNotEquivalent(Sink(0), Sink(0, False))
670+
self.assertNotEquivalent(Sink(0), Sink(0, mixed_mask))
671+
672+
self.assertNotEquivalent(Source(0), Sink(0))
673+
674+
self.assertEquivalent(Forward(0, 1), Forward(0, 1))
675+
self.assertNotEquivalent(Forward(0, 1), Forward(1, 0))
676+
677+
def test_signature_equality_semantics(self):
678+
self.assertEquivalent(
679+
KeyReuseSignature(Sink(0), Source(1), Forward(1, 0)),
680+
KeyReuseSignature(Forward(1, 0), Source(1), Sink(0)))
681+
self.assertEquivalent(
682+
KeyReuseSignature(), KeyReuseSignature())
683+
self.assertNotEquivalent(
684+
KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0)))
685+
686+
619687

620688
@jtu.with_config(jax_enable_checks=False)
621-
class KeyReuseGlobalFlags(jtu.JaxTestCase):
689+
class KeyReuseGlobalFlagsTest(jtu.JaxTestCase):
622690
def test_key_reuse_flag(self):
623691

624692
@jax.jit

0 commit comments

Comments
 (0)