Skip to content

Commit ac1a53d

Browse files
junwhanahn authored and jax authors committed
Optimize _create_copy_plan by iterating over only the shards that are needed for materialization
For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX. The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.

PiperOrigin-RevId: 624969222
1 parent 3a09404 commit ac1a53d

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

jax/_src/array.py

Lines changed: 10 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -125,23 +125,13 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
125125
def _cached_index_calc(s, shape):
126126
map_ = s.addressable_devices_indices_map(shape)
127127
seen_h_indices = set()
128-
m = {}
129-
for d, index in map_.items():
128+
l = []
129+
for array_index, index in enumerate(map_.values()):
130130
h_index = hashed_index(index)
131131
if h_index not in seen_h_indices:
132132
seen_h_indices.add(h_index)
133-
m[d] = index
134-
return m
135-
136-
137-
def _create_copy_plan(arrays, s: Sharding, shape: Shape):
138-
di_map = _cached_index_calc(s, shape)
139-
copy_plan = []
140-
for a in arrays:
141-
ind = di_map.get(a.sharding._internal_device_list[0], None) # type:ignore
142-
if ind is not None:
143-
copy_plan.append((ind, a))
144-
return copy_plan
133+
l.append((array_index, index))
134+
return l
145135

146136

147137
@functools.lru_cache(maxsize=4096)
@@ -607,9 +597,8 @@ def copy_to_host_async(self):
607597
if self.is_fully_replicated:
608598
self._copy_single_device_array_to_host_async()
609599
return
610-
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
611-
for _, arr in copy_plan:
612-
arr._copy_single_device_array_to_host_async()
600+
for i, _ in _cached_index_calc(self.sharding, self.shape):
601+
self._arrays[i]._copy_single_device_array_to_host_async()
613602

614603
@property
615604
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
@@ -631,13 +620,12 @@ def _value(self) -> np.ndarray:
631620
"`jax.experimental.multihost_utils.process_allgather` "
632621
"for this use case.")
633622

634-
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
635-
for _, arr in copy_plan:
636-
arr._copy_single_device_array_to_host_async()
623+
for i, _ in _cached_index_calc(self.sharding, self.shape):
624+
self._arrays[i]._copy_single_device_array_to_host_async()
637625

638626
npy_value = np.empty(self.shape, self.dtype)
639-
for ind, arr in copy_plan:
640-
npy_value[ind] = arr._single_device_array_to_np_array()
627+
for i, ind in _cached_index_calc(self.sharding, self.shape):
628+
npy_value[ind] = self._arrays[i]._single_device_array_to_np_array()
641629
self._npy_value = npy_value # type: ignore
642630
self._npy_value.flags.writeable = False
643631
# https://docs.python.org/3/library/typing.html#typing.cast

0 commit comments

Comments
 (0)