|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | from collections import defaultdict
|
| 18 | +import contextlib |
18 | 19 | from functools import partial, reduce, total_ordering, wraps
|
19 | 20 | from typing import Any, Callable, Iterator, NamedTuple
|
20 | 21 |
|
|
31 | 32 | from jax._src import prng
|
32 | 33 | from jax._src import random
|
33 | 34 | from jax._src import source_info_util
|
| 35 | +from jax._src import traceback_util |
34 | 36 | from jax._src import util
|
35 | 37 | from jax._src.ad_checkpoint import remat_p
|
36 | 38 | from jax._src.debugging import debug_callback_p
|
|
41 | 43 | import numpy as np
|
42 | 44 |
|
43 | 45 |
|
| 46 | +traceback_util.register_exclusion(__file__) |
| 47 | + |
| 48 | +_source_context_message = ( |
| 49 | + 'PRNG key first used at the above location was subsequently reused' |
| 50 | + ' at the following location:') |
| 51 | + |
| 52 | +def key_reuse_error_with_source_traceback( |
| 53 | + message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError: |
| 54 | + err = KeyReuseError(message) |
| 55 | + if traceback is not None: |
| 56 | + filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback()) |
| 57 | + if filtered_tb: |
| 58 | + context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb) |
| 59 | + context_err.__context__ = err.__context__ |
| 60 | + context_err.__cause__ = err.__cause__ |
| 61 | + context_err.__suppress_context__ = err.__suppress_context__ |
| 62 | + err.__context__ = None |
| 63 | + err.__cause__ = context_err |
| 64 | + return err |
| 65 | + |
| 66 | + |
44 | 67 | # Create Source() and Sink() objects which validate inputs, have
|
45 | 68 | # correct equality semantics, and are hashable & immutable.
|
46 | 69 | @total_ordering
|
@@ -145,19 +168,23 @@ def forwards(self) -> Iterator[Forward]:
|
145 | 168 |
|
146 | 169 | def check_signature(self, *args, funcname="function", context=None):
|
147 | 170 | for sink in self.sinks:
|
148 |
| - if not isinstance(args[sink.idx], prng.PRNGKeyArray): |
| 171 | + key = args[sink.idx] |
| 172 | + if not isinstance(key, prng.PRNGKeyArray): |
149 | 173 | continue
|
150 |
| - if np.any(args[sink.idx]._consumed & sink.mask): |
| 174 | + if np.any(key._consumed & sink.mask): |
151 | 175 | msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
|
152 | 176 | if context:
|
153 | 177 | msg += " {context}"
|
154 |
| - raise KeyReuseError(msg) |
| 178 | + raise key_reuse_error_with_source_traceback( |
| 179 | + msg, key._source_info and key._source_info.traceback) |
155 | 180 |
|
156 | 181 | def update_consumption(self, args_in, args_out):
|
157 | 182 | for sink in self.sinks:
|
158 | 183 | arg = args_in[sink.idx]
|
159 | 184 | if isinstance(arg, prng.PRNGKeyArray):
|
160 | 185 | arg._consumed = arg._consumed | sink.mask
|
| 186 | + if np.any(sink.mask): |
| 187 | + arg._source_info = source_info_util.current() |
161 | 188 | for arg in args_out:
|
162 | 189 | if isinstance(arg, prng.PRNGKeyArray):
|
163 | 190 | arg._consumed = True
|
|
0 commit comments