Commit 01412f7

pbroadcast
1 parent e7eb207 commit 01412f7

3 files changed: +98 −1 lines changed

jax/_src/lax/parallel.py

Lines changed: 67 additions & 0 deletions
@@ -247,6 +247,36 @@ def _canonicalize_axis_index_groups(axis_index_groups):
     return
   return tuple(map(tuple, axis_index_groups))
 
+
+def pbroadcast(x, axis_name, source):
+  """Perform a collective broadcast and replicate from ``source``.
+
+  This is equivalent to
+  ```
+  def pbroadcast(x, axis_name, source):
+    masked = jnp.where(axis_index(axis_name) == source, x, zeros_like(x))
+    return psum(masked, axis_name)
+  ```
+  but implemented in a hardware-optimized way.
+
+  If ``x`` is a pytree then the result is equivalent to mapping this function
+  to each leaf in the tree.
+
+  This function is an analog of the CollectiveBroadcast HLO.
+
+  Args:
+    x: array(s) with a mapped axis named ``axis_name``.
+    axis_name: hashable Python object used to name a pmapped axis (see the
+      :func:`jax.pmap` documentation for more details).
+    source: int, the index along ``axis_name`` whose value should be copied.
+
+  Returns:
+    Array(s) with the value of ``x`` from the ``source`` index of ``axis_name``.
+  """
+  return tree_util.tree_map(
+      partial(pbroadcast_p.bind, axis_name=axis_name, source=source), x)
+
+
 def ppermute(x, axis_name, perm):
   """Perform a collective permutation according to the permutation ``perm``.
 
@@ -927,6 +957,43 @@ def _collective_batcher(prim, args, dims, **params):
 batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher
 core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name')
 
+def _pbroadcast_transpose_rule(t, x, source, axis_name):
+  is_source = axis_index(axis_name) == source
+  tsum = psum(t, axis_name)
+  return [lax_numpy.where(is_source, tsum, lax_numpy.zeros_like(t))]
+
+def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source):
+  (v,), (d,) = vals_in, dims_in
+  if not isinstance(axis_name, (tuple, list)):
+    axis_name = (axis_name,)
+  remaining_axes = tuple(axis for axis in axis_name if axis != frame_name)
+  if remaining_axes:
+    raise NotImplementedError("pbroadcast batcher only supports a single axis")
+  assert axis_name[0] == frame_name, "pbroadcast batcher called with a wrong axis!"
+  assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!"
+  if axis_size == 1 and remaining_axes:
+    return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d
+  if d is batching.not_mapped:
+    return v, d
+  return lax_numpy.take(v, [source] * axis_size, d), d
+
+def _pbroadcast_lowering(ctx, x, *, axis_name, source):
+  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None)
+  def source_to_front(group):
+    return [group[source]] + list(group[:source]) + list(group[source + 1:])
+  replica_groups = [source_to_front(group) for group in replica_groups]
+  channel = ctx.module_context.new_channel()
+  return hlo.CollectiveBroadcastOp(
+      x, replica_groups=_replica_groups_hlo(replica_groups)).results
+
+pbroadcast_p = core.AxisPrimitive('pbroadcast')
+pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
+ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule)
+mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering)
+batching.primitive_batchers[pbroadcast_p] = partial(_collective_batcher, pbroadcast_p)
+batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher
+core.axis_substitution_rules[pbroadcast_p] = partial(_subst_all_names_in_param, 'axis_name')
+
 
 def _moveaxis(src, dst, x):
   perm = [i for i in range(x.ndim) if i != src]
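
For orientation, the equivalence claimed in the docstring can be exercised directly. The following is a minimal sketch, not part of the commit: it compares `lax.pbroadcast` against the masked-`psum` reference formulation under `jax.pmap`. It assumes a jax build that includes this commit and a multi-device backend where the CollectiveBroadcast HLO is supported (the tests in this commit skip CPU and TPU).

```python
# Illustrative sketch, not part of the commit. Assumes a backend that supports
# CollectiveBroadcast (the commit's tests skip CPU and TPU).
import jax
import jax.numpy as jnp
from jax import lax

def reference_pbroadcast(x, axis_name, source):
  # The docstring's reference semantics: zero out every shard except `source`,
  # then psum so every shard ends up holding the source shard's value.
  masked = jnp.where(lax.axis_index(axis_name) == source, x, jnp.zeros_like(x))
  return lax.psum(masked, axis_name)

n = jax.device_count()
x = jnp.arange(4 * n, dtype=jnp.float32).reshape((n, 4))
got = jax.pmap(lambda v: lax.pbroadcast(v, axis_name='i', source=0), 'i')(x)
want = jax.pmap(lambda v: reference_pbroadcast(v, 'i', 0), 'i')(x)
assert (got == want).all() and (got == x[0]).all()  # every shard holds shard 0
```

Note the `source_to_front` step in the lowering: CollectiveBroadcast broadcasts from the first replica of each replica group, so the rule rotates `source` to the front of every group instead of emitting a mask-and-reduce.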

jax/lax/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -342,6 +342,7 @@
   all_to_all_p as all_to_all_p,
   axis_index as axis_index,
   axis_index_p as axis_index_p,
+  pbroadcast as pbroadcast,
   pmax as pmax,
   pmax_p as pmax_p,
   pmean as pmean,

tests/pmap_test.py

Lines changed: 30 additions & 1 deletion
@@ -1109,12 +1109,41 @@ def testAxisGroups(self):
     self.assertEqual((tuple(sorted(groups[0])),),
                      ((0, 1, 2, 3, 4, 5, 6, 7,),))  # order doesn't matter
 
+  @jtu.skip_on_devices("cpu", "tpu")
+  def testCollectiveBroadcast(self):
+    device_count = jax.device_count()
+    f = lambda x: lax.pbroadcast(x, source=0, axis_name='i')
+    f = self.pmap(f, 'i')
+    x = jnp.arange(4 * device_count).reshape((device_count, 4))
+    ans = f(x)
+    expected = np.take(x, [0] * device_count, axis=0)
+    self.assertAllClose(ans, expected, check_dtypes=False)
+
+  @jtu.skip_on_devices("cpu", "tpu")
+  def testCollectiveBroadcastVmap(self):
+    device_count = jax.device_count()
+    f = lambda x: lax.pbroadcast(x, source=0, axis_name='i')
+    x = np.arange(device_count * 16, dtype=np.float32)
+    x = x.reshape((device_count, 4, 4))
+    ans = self.pmap(vmap(f), 'i')(x)
+    expected = jnp.broadcast_to(x[0:1], x.shape)
+    self.assertAllClose(ans, expected, check_dtypes=False)
+
+  @jtu.skip_on_devices("cpu", "tpu")
+  def testCollectiveBroadcastGrad(self):
+    device_count = jax.device_count()
+    f = lambda x: lax.pbroadcast(x, source=0, axis_name='i')
+    x = np.arange(device_count, dtype=np.float32)
+    ans = self.pmap(grad(f), 'i')(x)
+    expected = np.zeros_like(x)
+    expected[0] = device_count
+    self.assertAllClose(ans, expected, check_dtypes=False)
+
   def testCollectivePermute(self):
     device_count = jax.device_count()
     rotation = [(i, (i + 1) % device_count) for i in range(device_count)]
     f = lambda x: lax.ppermute(x, perm=rotation, axis_name='i')
     f = self.pmap(f, 'i')
-
     x = jnp.arange(4 * device_count).reshape((device_count, 4))
     ans = f(x)
     expected = np.roll(x, shift=1, axis=0)
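
The expected value in `testCollectiveBroadcastGrad` follows from linearity: the transpose rule added above (`_pbroadcast_transpose_rule`) psums the incoming cotangents and deposits the total on the source shard, so `device_count` cotangents of 1 land on index 0 and every other index gets 0. Below is a small sketch of that arithmetic, not part of the commit; it uses the docstring's masked-psum formulation rather than the new primitive, so it runs on any backend.

```python
# Illustrative sketch, not part of the commit: reproduces the gradient that
# testCollectiveBroadcastGrad expects, via the reference formulation.
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def reference_pbroadcast(x, axis_name, source):
  masked = jnp.where(lax.axis_index(axis_name) == source, x, jnp.zeros_like(x))
  return lax.psum(masked, axis_name)

n = jax.device_count()
x = np.arange(n, dtype=np.float32)
# Each shard's scalar output is x[0]; transposing the psum sends a cotangent
# of 1 from every shard back to shard 0, giving a gradient of n there, 0 elsewhere.
g = jax.pmap(jax.grad(lambda v: reference_pbroadcast(v, 'i', 0)), 'i')(x)
expected = np.zeros_like(x)
expected[0] = n
np.testing.assert_allclose(np.asarray(g), expected)
```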
