Skip to content

Commit 8ae93d5

Browse files
author
jax authors
committed
Merge pull request #20181 from jakevdp:reuse-signature-repr
PiperOrigin-RevId: 614821824
2 parents fe44afc + 6cf740c commit 8ae93d5

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

jax/experimental/key_reuse/_core.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections import defaultdict
18-
from functools import partial, reduce, wraps
18+
from functools import partial, reduce, total_ordering, wraps
1919
from typing import Any, Callable, Iterator, NamedTuple
2020

2121
import jax
@@ -43,6 +43,7 @@
4343

4444
# Create Source() and Sink() objects which validate inputs, have
4545
# correct equality semantics, and are hashable & immutable.
46+
@total_ordering
4647
class _SourceSinkBase:
4748
idx: int
4849
mask: bool | np.ndarray
@@ -74,6 +75,15 @@ def __eq__(self, other):
7475
and np.shape(self.mask) == np.shape(other.mask)
7576
and np.all(self.mask == other.mask))
7677

78+
def __lt__(self, other):
79+
if isinstance(other, Forward):
80+
return True
81+
elif isinstance(other, _SourceSinkBase):
82+
return ((self.__class__.__name__, self.idx)
83+
< (other.__class__.__name__, other.idx))
84+
else:
85+
return NotImplemented
86+
7787
def __hash__(self):
7888
if isinstance(self.mask, bool):
7989
return hash((self.__class__, self.idx, self.mask))
@@ -100,6 +110,9 @@ class Forward(NamedTuple):
100110
in_idx: int
101111
out_idx: int
102112

113+
def __repr__(self):
114+
return f"Forward({self.in_idx}, {self.out_idx})"
115+
103116

104117
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
105118
# objects, with a few convenience methods related to key reuse checking.
@@ -109,6 +122,9 @@ class KeyReuseSignature:
109122
def __init__(self, *args):
110123
self._args = frozenset(args)
111124

125+
def __repr__(self):
126+
return f"KeyReuseSignature{tuple(sorted(self._args))}"
127+
112128
def __eq__(self, other):
113129
return isinstance(other, KeyReuseSignature) and self._args == other._args
114130

tests/key_reuse_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,15 @@ def test_signature_equality_semantics(self):
692692
self.assertNotEquivalent(
693693
KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0)))
694694

695+
def test_reprs(self):
696+
self.assertEqual(repr(Sink(0)), "Sink(0)")
697+
self.assertEqual(repr(Source(0)), "Source(0)")
698+
self.assertEqual(repr(Forward(0, 1)), "Forward(0, 1)")
699+
self.assertEqual(repr(KeyReuseSignature(Sink(1), Source(0))),
700+
"KeyReuseSignature(Sink(1), Source(0))")
701+
self.assertEqual(repr(KeyReuseSignature(Sink(1), Sink(0))),
702+
"KeyReuseSignature(Sink(0), Sink(1))")
703+
695704

696705

697706
@jtu.with_config(jax_enable_checks=False)

0 commit comments

Comments
 (0)