16
16
17
17
from collections import defaultdict
18
18
from functools import partial , reduce , wraps
19
- from typing import Any , Callable , NamedTuple
19
+ from typing import Any , Callable , Iterator , NamedTuple
20
20
21
21
import jax
22
22
from jax import lax
41
41
import numpy as np
42
42
43
43
44
- class Sink (NamedTuple ):
44
+ # Create Source() and Sink() objects which validate inputs, have
45
+ # correct equality semantics, and are hashable & immutable.
46
+ class _SourceSinkBase :
45
47
idx : int
46
- mask : bool | np .ndarray = True
48
+ mask : bool | np .ndarray
49
+
50
+ def __init__ (self , idx : int , mask : bool | np .bool_ | np .ndarray = True ):
51
+ assert isinstance (idx , int )
52
+ if isinstance (mask , np .ndarray ):
53
+ assert mask .dtype == np .dtype ('bool' )
54
+ if np .all (mask ):
55
+ mask = True
56
+ elif not np .any (mask ):
57
+ mask = False
58
+ elif mask .flags .writeable :
59
+ mask = np .array (mask , copy = True )
60
+ mask .flags .writeable = False
61
+ elif isinstance (mask , np .bool_ ):
62
+ mask = bool (mask )
63
+ else :
64
+ assert isinstance (mask , bool )
65
+ super ().__setattr__ ("idx" , idx )
66
+ super ().__setattr__ ("mask" , mask )
67
+
68
+ def __setattr__ (self , * args , ** kwargs ):
69
+ raise ValueError (f"{ self .__class__ .__name__ } is immutable" )
70
+
71
+ def __eq__ (self , other ):
72
+ return (self .__class__ == other .__class__
73
+ and self .idx == other .idx
74
+ and np .shape (self .mask ) == np .shape (other .mask )
75
+ and np .all (self .mask == other .mask ))
76
+
77
+ def __hash__ (self ):
78
+ if isinstance (self .mask , bool ):
79
+ return hash ((self .__class__ , self .idx , self .mask ))
80
+ else :
81
+ mask = np .asarray (self .mask )
82
+ return hash ((self .__class__ , self .idx , mask .shape ,
83
+ tuple (mask .flatten ().tolist ())))
47
84
48
85
def __repr__ (self ):
49
- if isinstance (self .mask , bool ) and self .mask :
50
- return f"Sink({ self .idx } )"
51
- else :
52
- return f"Sink({ self .idx } , mask={ self .mask } )"
86
+ if self .mask is True :
87
+ return f"{ self .__class__ .__name__ } ({ self .idx } )"
88
+ return f"{ self .__class__ .__name__ } ({ self .idx } , { self .mask } )"
53
89
54
90
55
- class Source (NamedTuple ):
56
- idx : int
57
- mask : bool | np .ndarray = True
91
+ class Sink (_SourceSinkBase ):
92
+ pass
93
+
94
+
95
+ class Source (_SourceSinkBase ):
96
+ pass
58
97
59
- def __repr__ (self ):
60
- if isinstance (self .mask , bool ) and self .mask :
61
- return f"Source({ self .idx } )"
62
- else :
63
- return f"Source({ self .idx } , mask={ self .mask } )"
64
98
65
99
class Forward (NamedTuple ):
66
100
in_idx : int
67
101
out_idx : int
68
102
69
103
70
- class KeyReuseSignature (NamedTuple ):
71
- sinks : list [Sink ]
72
- sources : list [Source ]
73
- forwards : list [Forward ] = []
104
+ # KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
105
+ # objects, with a few convenience methods related to key reuse checking.
106
+ class KeyReuseSignature :
107
+ _args : frozenset [Source | Sink | Forward ]
108
+
109
+ def __init__ (self , * args ):
110
+ self ._args = frozenset (args )
111
+
112
+ def __eq__ (self , other ):
113
+ return isinstance (other , KeyReuseSignature ) and self ._args == other ._args
114
+
115
+ def __hash__ (self ):
116
+ return hash (self ._args )
117
+
118
+ @property
119
+ def sinks (self ) -> Iterator [Sink ]:
120
+ yield from (s for s in self ._args if isinstance (s , Sink ))
121
+
122
+ @property
123
+ def sources (self ) -> Iterator [Source ]:
124
+ yield from (s for s in self ._args if isinstance (s , Source ))
125
+
126
+ @property
127
+ def forwards (self ) -> Iterator [Forward ]:
128
+ yield from (s for s in self ._args if isinstance (s , Forward ))
74
129
75
130
def check_signature (self , * args , funcname = "function" , context = None ):
76
131
for sink in self .sinks :
@@ -145,34 +200,33 @@ def _check_consumed_value(eqn, consumed):
145
200
# The behavior of most primitives can be described via simple signatures.
146
201
key_reuse_signatures : dict [core .Primitive , KeyReuseSignature ] = {}
147
202
148
- key_reuse_signatures [consume_p ] = KeyReuseSignature ([ Sink (0 )], [], [ Forward (0 , 0 )] )
149
- key_reuse_signatures [assert_consumed_value_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
150
- key_reuse_signatures [random .random_clone_p ] = KeyReuseSignature ([], [ Source (0 )] )
151
- key_reuse_signatures [prng .random_bits_p ] = KeyReuseSignature ([ Sink (0 )], [] )
203
+ key_reuse_signatures [consume_p ] = KeyReuseSignature (Sink (0 ), Forward (0 , 0 ))
204
+ key_reuse_signatures [assert_consumed_value_p ] = KeyReuseSignature (Forward (0 , 0 ))
205
+ key_reuse_signatures [random .random_clone_p ] = KeyReuseSignature (Source (0 ))
206
+ key_reuse_signatures [prng .random_bits_p ] = KeyReuseSignature (Sink (0 ))
152
207
# TODO(jakevdp): should fold_in sink its input key?
153
- # key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
154
- key_reuse_signatures [prng .random_fold_in_p ] = KeyReuseSignature ([], [Source (0 )])
155
- key_reuse_signatures [prng .random_seed_p ] = KeyReuseSignature ([], [Source (0 )])
156
- key_reuse_signatures [prng .random_split_p ] = KeyReuseSignature ([Sink (0 )], [Source (0 )])
157
- key_reuse_signatures [random .random_gamma_p ] = KeyReuseSignature ([Sink (0 )], [])
208
+ key_reuse_signatures [prng .random_fold_in_p ] = KeyReuseSignature (Source (0 ))
209
+ key_reuse_signatures [prng .random_seed_p ] = KeyReuseSignature (Source (0 ))
210
+ key_reuse_signatures [prng .random_split_p ] = KeyReuseSignature (Sink (0 ), Source (0 ))
211
+ key_reuse_signatures [random .random_gamma_p ] = KeyReuseSignature (Sink (0 ))
158
212
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
159
- key_reuse_signatures [lax .broadcast_in_dim_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
160
- key_reuse_signatures [lax .copy_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
161
- key_reuse_signatures [lax .convert_element_type_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
162
- key_reuse_signatures [lax .device_put_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
163
- key_reuse_signatures [lax .reshape_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
164
- key_reuse_signatures [lax .squeeze_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
165
- key_reuse_signatures [prng .random_wrap_p ] = KeyReuseSignature ([], [ Source (0 )], [] )
213
+ key_reuse_signatures [lax .broadcast_in_dim_p ] = KeyReuseSignature (Forward (0 , 0 ))
214
+ key_reuse_signatures [lax .copy_p ] = KeyReuseSignature (Forward (0 , 0 ))
215
+ key_reuse_signatures [lax .convert_element_type_p ] = KeyReuseSignature (Forward (0 , 0 ))
216
+ key_reuse_signatures [lax .device_put_p ] = KeyReuseSignature (Forward (0 , 0 ))
217
+ key_reuse_signatures [lax .reshape_p ] = KeyReuseSignature (Forward (0 , 0 ))
218
+ key_reuse_signatures [lax .squeeze_p ] = KeyReuseSignature (Forward (0 , 0 ))
219
+ key_reuse_signatures [prng .random_wrap_p ] = KeyReuseSignature (Source (0 ))
166
220
# TODO(jakevdp): should unwrap sink its input key?
167
- key_reuse_signatures [prng .random_unwrap_p ] = KeyReuseSignature ([], [], [] )
168
- key_reuse_signatures [debug_callback_p ] = KeyReuseSignature ([], [] )
169
- key_reuse_signatures [lax .dynamic_slice_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
170
- key_reuse_signatures [lax .dynamic_update_slice_p ] = KeyReuseSignature ([ Sink (1 )], [], [ Forward (0 , 0 )] )
171
- key_reuse_signatures [lax .gather_p ] = KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
172
- key_reuse_signatures [lax .scatter_p ] = KeyReuseSignature ([ Sink (2 )], [], [ Forward (0 , 0 )] )
221
+ key_reuse_signatures [prng .random_unwrap_p ] = KeyReuseSignature ()
222
+ key_reuse_signatures [debug_callback_p ] = KeyReuseSignature ()
223
+ key_reuse_signatures [lax .dynamic_slice_p ] = KeyReuseSignature (Forward (0 , 0 ))
224
+ key_reuse_signatures [lax .dynamic_update_slice_p ] = KeyReuseSignature (Sink (1 ), Forward (0 , 0 ))
225
+ key_reuse_signatures [lax .gather_p ] = KeyReuseSignature (Forward (0 , 0 ))
226
+ key_reuse_signatures [lax .scatter_p ] = KeyReuseSignature (Sink (2 ), Forward (0 , 0 ))
173
227
# Equality checks don't consume
174
- key_reuse_signatures [lax .eq_p ] = KeyReuseSignature ([], [], [] )
175
- key_reuse_signatures [lax .ne_p ] = KeyReuseSignature ([], [], [] )
228
+ key_reuse_signatures [lax .eq_p ] = KeyReuseSignature ()
229
+ key_reuse_signatures [lax .ne_p ] = KeyReuseSignature ()
176
230
177
231
# Rules which require more dynamic logic.
178
232
key_reuse_signatures_dynamic : dict [core .Primitive , Callable [..., KeyReuseSignature ]] = {}
@@ -182,8 +236,7 @@ def unknown_signature(eqn):
182
236
def is_key (var : core .Atom ):
183
237
return hasattr (var .aval , "dtype" ) and jax .dtypes .issubdtype (var .aval .dtype , jax .dtypes .prng_key )
184
238
return KeyReuseSignature (
185
- sinks = [Sink (idx , True ) for idx , var in enumerate (eqn .invars ) if is_key (var )],
186
- sources = [],
239
+ * (Sink (idx ) for idx , var in enumerate (eqn .invars ) if is_key (var ))
187
240
)
188
241
189
242
@weakref_lru_cache
@@ -216,7 +269,6 @@ def sink(var: core.Atom, mask=True):
216
269
return True
217
270
consumed [var ] = np .logical_or (consumed .get (var , False ), mask )
218
271
219
-
220
272
def source (var : core .Atom , mask = False ):
221
273
if not is_key (var ):
222
274
return
@@ -262,13 +314,13 @@ def is_consumed(var: core.Atom):
262
314
source (eqn .outvars [src .idx ])
263
315
264
316
return KeyReuseSignature (
265
- sinks = [ Sink (i , consumed [v ]) for i , v in enumerate (jaxpr .invars )
266
- if is_key (v ) and np .any (consumed .get (v , False ))] ,
267
- sources = [ Source (i ) for i , v in enumerate (jaxpr .outvars )
268
- if is_key (v ) and resolve_forwards (v ) not in jaxpr .invars and not consumed .get (v , False )] ,
269
- forwards = [ Forward (jaxpr .invars .index (resolve_forwards (outvar )), idx_out ) # type: ignore[arg-type]
270
- for idx_out , outvar in enumerate (jaxpr .outvars )
271
- if is_key (outvar ) and resolve_forwards (outvar ) in jaxpr .invars ]
317
+ * ( Sink (i , consumed [v ]) for i , v in enumerate (jaxpr .invars )
318
+ if is_key (v ) and np .any (consumed .get (v , False ))) ,
319
+ * ( Source (i ) for i , v in enumerate (jaxpr .outvars )
320
+ if is_key (v ) and resolve_forwards (v ) not in jaxpr .invars and not consumed .get (v , False )) ,
321
+ * ( Forward (jaxpr .invars .index (resolve_forwards (outvar )), idx_out ) # type: ignore[arg-type]
322
+ for idx_out , outvar in enumerate (jaxpr .outvars )
323
+ if is_key (outvar ) and resolve_forwards (outvar ) in jaxpr .invars )
272
324
)
273
325
274
326
@@ -292,16 +344,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
292
344
def _slice_signature (eqn ):
293
345
in_aval = eqn .invars [0 ].aval
294
346
if not jax .dtypes .issubdtype (in_aval .dtype , jax .dtypes .prng_key ):
295
- return KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
347
+ return KeyReuseSignature (Forward (0 , 0 ))
296
348
if any (core .is_symbolic_dim (s ) for s in in_aval .shape ):
297
- return KeyReuseSignature ([], [], [ Forward (0 , 0 )] )
349
+ return KeyReuseSignature (Forward (0 , 0 ))
298
350
start_indices = eqn .params ['start_indices' ]
299
351
limit_indices = eqn .params ['limit_indices' ]
300
352
strides = eqn .params ['strides' ] or (1 ,) * len (start_indices )
301
353
idx = tuple (slice (* tup ) for tup in util .safe_zip (start_indices , limit_indices , strides ))
302
354
sink = np .zeros (in_aval .shape , dtype = bool )
303
355
sink [idx ] = True
304
- return KeyReuseSignature ([ Sink (0 , sink )], [ Source (0 )] )
356
+ return KeyReuseSignature (Sink (0 , sink ), Source (0 ))
305
357
306
358
key_reuse_signatures_dynamic [lax .slice_p ] = _slice_signature
307
359
@@ -329,7 +381,7 @@ def _cond_key_type_signature(eqn):
329
381
combined_sources = [Source (i , reduce (np .logical_and , m )) for i , m in sources .items ()]
330
382
combined_forwards = [Forward (f .in_idx + 1 , f .out_idx ) for f in
331
383
set .intersection (* (set (sig .forwards ) for sig in signatures ))]
332
- return KeyReuseSignature (combined_sinks , combined_sources , combined_forwards )
384
+ return KeyReuseSignature (* combined_sinks , * combined_sources , * combined_forwards )
333
385
334
386
key_reuse_signatures_dynamic [lax .cond_p ] = _cond_key_type_signature
335
387
@@ -410,7 +462,7 @@ def _remat_key_type_signature(eqn):
410
462
# 2) will never create keys
411
463
# Therefore, the differentiated pass is a no-op.
412
464
if eqn .params ['differentiated' ]:
413
- return KeyReuseSignature ([], [] )
465
+ return KeyReuseSignature ()
414
466
return get_jaxpr_type_signature (eqn .params ['jaxpr' ])
415
467
416
468
key_reuse_signatures_dynamic [remat_p ] = _remat_key_type_signature
0 commit comments