Skip to content

Commit 07e45c3

Browse files
author
jax authors
committed
Merge pull request #20236 from jakevdp:key-reuse-stack
PiperOrigin-RevId: 618014760
2 parents d57bb8c + 7e60331 commit 07e45c3

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

jax/_src/prng.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from jax._src import dtypes
3535
from jax._src import pretty_printer as pp
3636
from jax._src import sharding_specs
37+
from jax._src import source_info_util
3738
from jax._src import tree_util as tree_util_internal
3839
from jax._src import typing
3940
from jax._src import op_shardings
@@ -154,6 +155,7 @@ class behave like an array whose base elements are keys, hiding the
154155
_impl: PRNGImpl
155156
_base_array: typing.Array
156157
_consumed: bool | np.ndarray # Used in jax.experimental.key_reuse.
158+
_source_info: None | source_info_util.SourceInfo = None
157159

158160
def __init__(self, impl, key_data: Any):
159161
assert not isinstance(key_data, core.Tracer)

jax/experimental/key_reuse/_core.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections import defaultdict
18+
import contextlib
1819
from functools import partial, reduce, total_ordering, wraps
1920
from typing import Any, Callable, Iterator, NamedTuple
2021

@@ -31,6 +32,7 @@
3132
from jax._src import prng
3233
from jax._src import random
3334
from jax._src import source_info_util
35+
from jax._src import traceback_util
3436
from jax._src import util
3537
from jax._src.ad_checkpoint import remat_p
3638
from jax._src.debugging import debug_callback_p
@@ -41,6 +43,27 @@
4143
import numpy as np
4244

4345

46+
traceback_util.register_exclusion(__file__)
47+
48+
_source_context_message = (
49+
'PRNG key first used at the above location was subsequently reused'
50+
' at the following location:')
51+
52+
def key_reuse_error_with_source_traceback(
53+
message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError:
54+
err = KeyReuseError(message)
55+
if traceback is not None:
56+
filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback())
57+
if filtered_tb:
58+
context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb)
59+
context_err.__context__ = err.__context__
60+
context_err.__cause__ = err.__cause__
61+
context_err.__suppress_context__ = err.__suppress_context__
62+
err.__context__ = None
63+
err.__cause__ = context_err
64+
return err
65+
66+
4467
# Create Source() and Sink() objects which validate inputs, have
4568
# correct equality semantics, and are hashable & immutable.
4669
@total_ordering
@@ -145,19 +168,23 @@ def forwards(self) -> Iterator[Forward]:
145168

146169
def check_signature(self, *args, funcname="function", context=None):
147170
for sink in self.sinks:
148-
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
171+
key = args[sink.idx]
172+
if not isinstance(key, prng.PRNGKeyArray):
149173
continue
150-
if np.any(args[sink.idx]._consumed & sink.mask):
174+
if np.any(key._consumed & sink.mask):
151175
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
152176
if context:
153177
msg += " {context}"
154-
raise KeyReuseError(msg)
178+
raise key_reuse_error_with_source_traceback(
179+
msg, key._source_info and key._source_info.traceback)
155180

156181
def update_consumption(self, args_in, args_out):
157182
for sink in self.sinks:
158183
arg = args_in[sink.idx]
159184
if isinstance(arg, prng.PRNGKeyArray):
160185
arg._consumed = arg._consumed | sink.mask
186+
if np.any(sink.mask):
187+
arg._source_info = source_info_util.current()
161188
for arg in args_out:
162189
if isinstance(arg, prng.PRNGKeyArray):
163190
arg._consumed = True

0 commit comments

Comments
 (0)