
Commit 4be25d7

junwhanahn authored and jax authors committed
Optimize jax.device_put() dispatch for 1:1 device-to-device transfers
* Cache the sharding index comparison in addition to sharding index calculation. This helps when the list of indices is expensive to compare.

* Remove caching from `pxla.get_addressable_devices_for_shard_arg()` since `sharding._addressable_device_assignment` is already a cached property.

* Use `a is b` instead of `id(a) == id(b)` since the former is more concise.

PiperOrigin-RevId: 627080325
1 parent 1b1c6e7 commit 4be25d7
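For orientation, this is the public-API path the commit optimizes: `jax.device_put()` moving an already-committed array to another device whose sharding lays out the data identically (the 1:1 transfer case). A minimal sketch, assuming at least two visible devices (on CPU you can force this with `XLA_FLAGS=--xla_force_host_platform_device_count=2`); it exercises only the public API, not the internals changed below.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jax.device_put(jnp.arange(8.0), devices[0])   # commit the array to device 0
y = jax.device_put(x, devices[-1])                # 1:1 device-to-device transfer
print(y.sharding)                                 # SingleDeviceSharding on the target device
```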

3 files changed: +19 −19 lines changed


jax/_src/array.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -928,20 +928,26 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
 
 
+@functools.lru_cache(maxsize=4096)
+def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
+  src_indices = src_sharding.addressable_devices_indices_map(shape).values()
+  dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
+  return dst_indices, tuple(src_indices) == tuple(dst_indices)
+
+
 def _array_shard_arg(x, sharding):
   x._check_if_deleted()
 
-  x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
-  indices = sharding.addressable_devices_indices_map(x.shape).values()
+  indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
   if not x.is_fully_addressable:
-    if tuple(x_indices) == tuple(indices):
+    if same_indices:
       return x
     else:
       raise NotImplementedError(
           "Cannot reshard an input that is not fully addressable")
   else:
-    devices = pxla.get_addressable_devices_for_shard_arg(sharding)
-    if tuple(x_indices) == tuple(indices):
+    devices = sharding._addressable_device_assignment
+    if same_indices:
       return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
     # Resharding starts here:
     if dispatch.is_single_device_sharding(x.sharding):
```
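Why cache the comparison result and not just the index maps? Each map is a tuple of per-device index tuples, so the equality check itself costs O(number of addressable devices) and runs on every dispatch. A minimal sketch of the pattern, with illustrative stand-ins rather than the real `Sharding` API:

```python
import functools

def _indices(num_devices):
  # Stand-in for addressable_devices_indices_map(...).values(): one index
  # tuple per device, cheap to build here but costly to compare at scale.
  return tuple((slice(i, i + 1),) for i in range(num_devices))

@functools.lru_cache(maxsize=4096)
def _indices_and_eq(src_key, dst_key):
  # Cache the destination indices *and* the equality result under one key,
  # so repeated lookups skip the O(num_devices) tuple comparison entirely.
  src, dst = _indices(src_key), _indices(dst_key)
  return dst, src == dst

indices, same = _indices_and_eq(1024, 1024)   # computed once
_, same_again = _indices_and_eq(1024, 1024)   # cache hit: no comparison
assert same and same_again
```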

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -122,11 +122,6 @@ def shard_args(
 shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
 
 
-@lru_cache(maxsize=1024)
-def get_addressable_devices_for_shard_arg(
-    s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]:
-  return s._addressable_device_assignment
-
 @lru_cache(maxsize=1024)
 def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices
@@ -138,7 +133,7 @@ def _masked_array_error(x, sharding):
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
 
 def _shard_array(x, sharding):
-  devices = get_addressable_devices_for_shard_arg(sharding)
+  devices = sharding._addressable_device_assignment
   if x.dtype == dtypes.float0:
     x = np.zeros(x.shape, dtype=np.dtype(bool))
   aval = api_util.shaped_abstractify(x)
```
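The deleted wrapper cached a value that was already cached: per the commit message, `sharding._addressable_device_assignment` is a cached property, so the `lru_cache` layer only added a second dictionary lookup. A minimal sketch of why the outer cache is redundant, using a hypothetical class rather than JAX's actual `Sharding` implementation:

```python
import functools

class FakeSharding:
  def __init__(self, devices):
    self._devices = tuple(devices)

  @functools.cached_property
  def _addressable_device_assignment(self):
    # Computed on first access, then stored on the instance; every later
    # access is a plain attribute read, so no lru_cache layer is needed.
    return self._devices

s = FakeSharding(["cpu:0", "cpu:1"])
assert s._addressable_device_assignment is s._addressable_device_assignment
```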

jax/_src/sharding_impls.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -317,13 +317,13 @@ def __hash__(self):
   def __eq__(self, other):
     if not isinstance(other, NamedSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     if (self._parsed_pspec != other._parsed_pspec
         or self.memory_kind != other.memory_kind
         or self._manual_axes != other._manual_axes):
       return False
-    return id(self.mesh) == id(other.mesh) or self.mesh == other.mesh
+    return self.mesh is other.mesh or self.mesh == other.mesh
 
   def is_compatible_aval(self, aval_shape: Shape):
     assert self._parsed_pspec is not None
@@ -422,7 +422,7 @@ def __hash__(self):
   def __eq__(self, other):
     if not isinstance(other, SingleDeviceSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     return (self._device == other._device and
             self.memory_kind == other.memory_kind)
@@ -485,7 +485,7 @@ def __reduce__(self):
   def __eq__(self, other):
     if not isinstance(other, PmapSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     return (self.sharding_spec == other.sharding_spec and
             self.devices.shape == other.devices.shape and
@@ -741,12 +741,11 @@ def __hash__(self) -> int:
   def __eq__(self, other) -> bool:
     if not isinstance(other, PositionalSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     all_ids_equal = np.array_equal(self._ids,other._ids)
     mem_kind_equal = self.memory_kind == other.memory_kind
-    if (id(self._devices) == id(other._devices) and mem_kind_equal and
-        all_ids_equal):
+    if self._devices is other._devices and mem_kind_equal and all_ids_equal:
       return True
     return (mem_kind_equal and all_ids_equal and
             self._internal_device_list == other._internal_device_list)
@@ -852,7 +851,7 @@ def _hlo_sharding_hash(self):
   def __eq__(self, other):
     if not isinstance(other, GSPMDSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
             and self.memory_kind == other.memory_kind
```
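For live objects, `a is b` and `id(a) == id(b)` are equivalent identity checks; `is` is shorter and avoids two `id()` calls. The identity fast path used in these `__eq__` methods, sketched with a hypothetical class rather than a real JAX sharding:

```python
class MySharding:
  def __init__(self, spec, memory_kind=None):
    self.spec = spec
    self.memory_kind = memory_kind

  def __eq__(self, other):
    if not isinstance(other, MySharding):
      return False
    if self is other:  # identity fast path, as adopted in this commit
      return True
    return self.spec == other.spec and self.memory_kind == other.memory_kind

a = MySharding(("x", "y"))
assert a == a                       # short-circuits on identity
assert a == MySharding(("x", "y"))  # falls through to field comparison
```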
