@@ -90,7 +90,10 @@ def get(module: ir.Module,
       lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
      ("compile_options",
       lambda hash_obj: _hash_serialized_compile_options(
-          hash_obj, compile_options)),
+          hash_obj, compile_options,
+          # In the case of GPU multi-process tasks, strip the device
+          # assignment so the cache key is invariant across processes.
+          strip_device_assignment=(backend.platform == "gpu"))),
      ("accelerator_config",
       lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
      ("compression",
@@ -172,7 +175,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend):
   _hash_platform(hash_obj, backend)


-def _hash_serialized_compile_options(hash_obj, compile_options_obj):
+def _hash_serialized_compile_options(hash_obj, compile_options_obj,
+                                     strip_device_assignment=False):
   # Do not mess with the original CompileOptions object since it is passed to
   # the compiler. Create a deep copy for the purpose of cache key generation.
   compile_options_copy = copy.deepcopy(compile_options_obj)
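The deep copy matters because the same `CompileOptions` object is later handed to the compiler; the hash function scrubs fields only in the copy. A minimal sketch of the copy-then-scrub pattern, with a hypothetical dataclass standing in for the real proto:

```python
# Copy-then-scrub sketch; FakeCompileOptions is a hypothetical stand-in for
# the CompileOptions proto, which must never be mutated in place.
import copy
import hashlib
from dataclasses import dataclass, field

@dataclass
class FakeCompileOptions:
  num_replicas: int = 1
  env_option_overrides: dict = field(default_factory=dict)

def hash_options(hash_obj, options):
  options_copy = copy.deepcopy(options)   # leave the caller's object intact
  options_copy.env_option_overrides = {}  # drop fields that vary per host
  hash_obj.update(repr(options_copy).encode())

opts = FakeCompileOptions(num_replicas=8,
                          env_option_overrides={"xla_dump_to": "/tmp/p0"})
hash_options(hashlib.sha256(), opts)
assert opts.env_option_overrides          # the original is untouched
```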
@@ -211,6 +215,12 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj):
     debug_options.xla_gpu_cuda_data_dir = ""
   # LINT.ThenChange(:xla_flags)

+  if strip_device_assignment and compile_options_copy.device_assignment:
+    replica_count = compile_options_copy.device_assignment.replica_count()
+    computation_count = compile_options_copy.device_assignment.computation_count()
+    compile_options_copy.device_assignment = xla_client.DeviceAssignment.create(
+        np.ndarray([replica_count, computation_count])
+    )
   return hash_obj.update(compile_options_copy.SerializeAsString())
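The effect of the new branch: in a multi-process GPU job each process sees a different device assignment, so without stripping, the serialized `CompileOptions` (and hence the key) differs per process even for identical programs. The commit replaces the assignment with a fresh one that preserves only the `(replica_count, computation_count)` shape. A self-contained sketch of the same idea, with dicts standing in for the protos and a hypothetical `cache_key` helper:

```python
# Self-contained sketch of device-assignment stripping; dicts stand in for
# CompileOptions / DeviceAssignment, and cache_key is a hypothetical helper.
import copy
import hashlib

def cache_key(options, strip_device_assignment):
  options = copy.deepcopy(options)
  if strip_device_assignment and options.get("device_assignment"):
    da = options["device_assignment"]
    # Keep only the (replica_count, computation_count) shape and drop the
    # per-process device ids, so every process hashes the same bytes.
    options["device_assignment"] = {"shape": (len(da), len(da[0]))}
  return hashlib.sha256(repr(sorted(options.items())).encode()).hexdigest()

# Process 0 and process 1 hold different GPU ids but the same topology.
proc0 = {"device_assignment": [[0, 1]]}
proc1 = {"device_assignment": [[2, 3]]}
assert cache_key(proc0, strip_device_assignment=True) == \
       cache_key(proc1, strip_device_assignment=True)
assert cache_key(proc0, strip_device_assignment=False) != \
       cache_key(proc1, strip_device_assignment=False)
```

With stripping enabled, both processes hash the same normalized options and can share one cache entry; with it disabled, each process would miss the cache and compile its own copy.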