Skip to content

Commit 596756f

Browse files
author
jax authors
committed
Enhance compilation cache key generation with a custom hook.
The custom hook is called every time the cache key is generated. It can be programmed to add a custom string that is then hashed as part of the cache key. Testing: test workloads. PiperOrigin-RevId: 610586945
1 parent fab8f6c commit 596756f

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

jax/_src/cache_key.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ def get_flag_prefixes() -> list[str]:
5151
return _extra_flag_prefixes
5252

5353

54+
def custom_hook() -> str:
55+
"""Custom hook for any addition to the cache key.
56+
57+
The custom hook will be called everytime get() is called and can be
58+
defined to return a string that will be hashed into the cache key.
59+
"""
60+
return ""
61+
62+
5463
def get(module: ir.Module,
5564
devices: np.ndarray,
5665
compile_options: xla_client.CompileOptions,
@@ -86,6 +95,7 @@ def get(module: ir.Module,
8695
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
8796
("compression",
8897
lambda hash_obj: _hash_string(hash_obj, compression_algorithm)),
98+
("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())),
8999
]
90100

91101
hash_obj = hashlib.sha256()

tests/cache_key_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,21 @@ def test_different_key(self):
127127
cache_key.get(computation, devices, compile_options_filled, backend),
128128
)
129129

130+
def test_custom_hook(self):
131+
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
132+
devices = np.array([[jax.local_devices()[0]]])
133+
compile_options = compiler.get_compile_options(
134+
num_replicas=1, num_partitions=1
135+
)
136+
backend = xla_bridge.get_backend()
137+
original_custom_hook = cache_key.custom_hook
138+
cache_key.custom_hook = lambda: "hook1"
139+
key1 = cache_key.get(computation, devices, compile_options, backend)
140+
cache_key.custom_hook = lambda: "hook2"
141+
key2 = cache_key.get(computation, devices, compile_options, backend)
142+
cache_key.custom_hook = original_custom_hook
143+
self.assertNotEqual(key1, key2)
144+
130145
def test_different_computations(self):
131146
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
132147
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()

0 commit comments

Comments
 (0)