15
15
from __future__ import annotations
16
16
17
17
from collections import defaultdict
18
- from functools import partial , reduce , wraps
18
+ from functools import partial , reduce , total_ordering , wraps
19
19
from typing import Any , Callable , Iterator , NamedTuple
20
20
21
21
import jax
43
43
44
44
# Create Source() and Sink() objects which validate inputs, have
45
45
# correct equality semantics, and are hashable & immutable.
46
+ @total_ordering
46
47
class _SourceSinkBase :
47
48
idx : int
48
49
mask : bool | np .ndarray
@@ -74,6 +75,15 @@ def __eq__(self, other):
74
75
and np .shape (self .mask ) == np .shape (other .mask )
75
76
and np .all (self .mask == other .mask ))
76
77
78
+ def __lt__ (self , other ):
79
+ if isinstance (other , Forward ):
80
+ return True
81
+ elif isinstance (other , _SourceSinkBase ):
82
+ return ((self .__class__ .__name__ , self .idx )
83
+ < (other .__class__ .__name__ , other .idx ))
84
+ else :
85
+ return NotImplemented
86
+
77
87
def __hash__ (self ):
78
88
if isinstance (self .mask , bool ):
79
89
return hash ((self .__class__ , self .idx , self .mask ))
@@ -100,6 +110,9 @@ class Forward(NamedTuple):
100
110
in_idx : int
101
111
out_idx : int
102
112
113
+ def __repr__ (self ):
114
+ return f"Forward({ self .in_idx } , { self .out_idx } )"
115
+
103
116
104
117
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
105
118
# objects, with a few convenience methods related to key reuse checking.
@@ -109,6 +122,9 @@ class KeyReuseSignature:
109
122
def __init__ (self , * args ):
110
123
self ._args = frozenset (args )
111
124
125
+ def __repr__ (self ):
126
+ return f"KeyReuseSignature{ tuple (sorted (self ._args ))} "
127
+
112
128
def __eq__ (self , other ):
113
129
return isinstance (other , KeyReuseSignature ) and self ._args == other ._args
114
130
0 commit comments