
Commit 0b28a4b

Author: jax authors (committed)

Strip device_assignment on GPU platform.

This makes the hash invariant in the multi-process case.

PiperOrigin-RevId: 617093247

Parent: a453301

File tree

2 files changed: +29 −2 lines

jax/_src/cache_key.py

Lines changed: 12 additions & 2 deletions

@@ -90,7 +90,10 @@ def get(module: ir.Module,
           lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
       ("compile_options",
        lambda hash_obj: _hash_serialized_compile_options(
-           hash_obj, compile_options)),
+           hash_obj, compile_options,
+           # In case of GPU multi-process tasks we need to strip device
+           # assignment to use cache key as invariant between processes.
+           strip_device_assignment=(backend.platform == "gpu"))),
       ("accelerator_config",
        lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
       ("compression",
@@ -172,7 +175,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend):
   _hash_platform(hash_obj, backend)
 
 
-def _hash_serialized_compile_options(hash_obj, compile_options_obj):
+def _hash_serialized_compile_options(hash_obj, compile_options_obj,
+                                     strip_device_assignment=False):
   # Do not mess with the original CompileOptions object since it is passed to
   # the compiler. Create a deep copy for the purpose of cache key generation.
   compile_options_copy = copy.deepcopy(compile_options_obj)
@@ -211,6 +215,12 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj):
     debug_options.xla_gpu_cuda_data_dir = ""
   # LINT.ThenChange(:xla_flags)
 
+  if strip_device_assignment and compile_options_copy.device_assignment:
+    replica_count = compile_options_copy.device_assignment.replica_count()
+    computation_count = compile_options_copy.device_assignment.computation_count()
+    compile_options_copy.device_assignment = xla_client.DeviceAssignment.create(
+        np.ndarray([replica_count, computation_count])
+    )
   return hash_obj.update(compile_options_copy.SerializeAsString())
tests/cache_key_test.py

Lines changed: 17 additions & 0 deletions

@@ -155,6 +155,23 @@ def test_different_computations(self):
         cache_key.get(computation2, devices, compile_options, backend),
     )
 
+  def test_different_device_assignment(self):
+    computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
+    devices = np.array([[jax.local_devices()[0]]])
+    compile_options_1 = compiler.get_compile_options(
+        num_replicas=1, num_partitions=1, device_assignment=np.array([[0]])
+    )
+    compile_options_2 = compiler.get_compile_options(
+        num_replicas=1, num_partitions=1, device_assignment=np.array([[1]])
+    )
+    backend = xla_bridge.get_backend()
+    hash_1 = cache_key.get(computation, devices, compile_options_1, backend)
+    hash_2 = cache_key.get(computation, devices, compile_options_2, backend)
+    if backend.platform == "gpu":
+      self.assertEqual(hash_1, hash_2)
+    else:
+      self.assertNotEqual(hash_1, hash_2)
+
   @parameterized.parameters([False, True])
   def test_identical_computations_different_metadata(self, include_metadata):
     f = lambda x, y: lax.mul(lax.add(x, y), 2)
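The stripping step can also be exercised in isolation. A sketch under the assumption that jaxlib's xla_client is importable as shown; per the diff above, DeviceAssignment.create takes a 2D array of device ids indexed by (replica, computation), and replica_count()/computation_count() recover its shape:

import numpy as np

from jax._src.lib import xla_client

# Pretend this process was handed device id 3.
original = xla_client.DeviceAssignment.create(np.array([[3]]))

# A placeholder with the same shape but canonical contents serializes
# identically on every process, so the compile-options hash no longer
# depends on which device ids this process happened to get.
placeholder = xla_client.DeviceAssignment.create(
    np.zeros((original.replica_count(), original.computation_count()),
             dtype=int)
)
print(placeholder.replica_count(), placeholder.computation_count())  # 1 1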
