@@ -247,6 +247,36 @@ def _canonicalize_axis_index_groups(axis_index_groups):
    return
  return tuple(map(tuple, axis_index_groups))

+
+def pbroadcast(x, axis_name, source):
+  """Perform a collective broadcast and replicate from ``source``.
+
+  This is equivalent to
+  ```
+  def pbroadcast(x, axis_name, source):
+    masked = jnp.where(axis_index(axis_name) == source, x, zeros_like(x))
+    return psum(masked, axis_name)
+  ```
+  but implemented in a hardware-optimized way.
+
+  If ``x`` is a pytree then the result is equivalent to mapping this function to
+  each leaf in the tree.
+
+  This function is an analog of the CollectiveBroadcast HLO.
+
+  Args:
+    x: array(s) with a mapped axis named ``axis_name``.
+    axis_name: hashable Python object used to name a pmapped axis (see the
+      :func:`jax.pmap` documentation for more details).
+    source: int, representing which index into ``axis_name`` should be copied.
+
+  Returns:
+    Array(s) with ``x`` being copied from the ``source`` index slice of ``axis_name``.
+  """
+  return tree_util.tree_map(
+      partial(pbroadcast_p.bind, axis_name=axis_name, source=source), x)
+
+
def ppermute(x, axis_name, perm):
  """Perform a collective permutation according to the permutation ``perm``.

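A minimal sketch of the semantics, following the reference implementation from the docstring above. The helper name `pbroadcast_reference`, the axis name `'i'`, and the use of `jax.pmap` are illustrative assumptions, not part of this change; the new primitive itself instead lowers to the hardware-optimized CollectiveBroadcast path.

```python
# Illustrative sketch only: mirrors the masked-psum reference from the docstring.
import jax
import jax.numpy as jnp
from jax import lax

def pbroadcast_reference(x, axis_name, source):
  # Zero out every shard except `source`, then sum across the mapped axis,
  # leaving every shard holding the source's value.
  masked = jnp.where(lax.axis_index(axis_name) == source, x, jnp.zeros_like(x))
  return lax.psum(masked, axis_name)

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)
out = jax.pmap(lambda v: pbroadcast_reference(v, 'i', source=0), axis_name='i')(x)
# Every entry of `out` now equals x[0].
```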
@@ -927,6 +957,43 @@ def _collective_batcher(prim, args, dims, **params):
batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher
core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name')

+def _pbroadcast_transpose_rule(t, x, source, axis_name):
+  is_source = axis_index(axis_name) == source
+  tsum = psum(t, axis_name)
+  return [lax_numpy.where(is_source, tsum, lax_numpy.zeros_like(t))]
+
+def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source):
+  (v,), (d,) = vals_in, dims_in
+  if not isinstance(axis_name, (tuple, list)):
+    axis_name = (axis_name,)
+  remaining_axes = tuple(axis for axis in axis_name if axis != frame_name)
+  if remaining_axes:
+    raise NotImplementedError("pbroadcast batcher only supports a single axis")
+  assert axis_name[0] == frame_name, "pbroadcast batcher called with a wrong axis!"
+  assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!"
+  if axis_size == 1 and remaining_axes:
+    return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d
+  if d is batching.not_mapped:
+    return v, d
+  return lax_numpy.take(v, [source] * axis_size, d), d
+
+def _pbroadcast_lowering(ctx, x, *, axis_name, source):
+  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None)
+  def source_to_front(group):
+    return [group[source]] + list(group[:source]) + list(group[source + 1:])
+  replica_groups = [source_to_front(group) for group in replica_groups]
+  channel = ctx.module_context.new_channel()
+  return hlo.CollectiveBroadcastOp(
+      x, replica_groups=_replica_groups_hlo(replica_groups)).results
+
+pbroadcast_p = core.AxisPrimitive('pbroadcast')
+pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
+ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule)
+mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering)
+batching.primitive_batchers[pbroadcast_p] = partial(_collective_batcher, pbroadcast_p)
+batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher
+core.axis_substitution_rules[pbroadcast_p] = partial(_subst_all_names_in_param, 'axis_name')
+

def _moveaxis(src, dst, x):
  perm = [i for i in range(x.ndim) if i != src]
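`ad.deflinear2` registers `pbroadcast` as a linear primitive, and `_pbroadcast_transpose_rule` transposes "broadcast from ``source``" into "sum all cotangents onto ``source``, zero elsewhere". A small dense-matrix check of that identity; the sizes and values below are made up purely for illustration.

```python
# Illustrative check that the transpose of "broadcast from `source`" sums the
# cotangents onto `source` and zeros the rest, matching _pbroadcast_transpose_rule.
import jax.numpy as jnp

n, source = 4, 1
# Forward map over the mapped axis as a dense matrix: every row copies x[source].
B = jnp.zeros((n, n)).at[:, source].set(1.0)

y_bar = jnp.arange(1.0, n + 1.0)   # cotangents arriving at each device
x_bar = B.T @ y_bar                # all cotangents accumulate onto `source`
# x_bar == [0., 10., 0., 0.], i.e. where(is_source, psum(t), zeros_like(t))
```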
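In `_pbroadcast_lowering`, each replica group is reordered so the source replica comes first, consistent with CollectiveBroadcast broadcasting from the first replica listed in each group. A standalone illustration of that reordering; the group contents are chosen arbitrarily.

```python
# Illustration of the replica-group reordering used by _pbroadcast_lowering:
# the source replica is rotated to the front of each group.
def source_to_front(group, source):
  return [group[source]] + list(group[:source]) + list(group[source + 1:])

print(source_to_front([0, 1, 2, 3], source=2))  # [2, 0, 1, 3]
print(source_to_front([4, 5, 6, 7], source=2))  # [6, 4, 5, 7]
```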