Skip to content

Commit d20a2f1

Browse files
author
jax authors
committed
Merge pull request #20317 from inailuig:mpi_collectives
PiperOrigin-RevId: 627208382
2 parents 16b4f69 + 1e32fb5 commit d20a2f1

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ Remember to align the itemized text with the first line of an item within a list
1212
* Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
1313
following their addition in the array API 2023 standard, soon to be
1414
adopted by NumPy.
15+
* Added a new config option `jax_cpu_collectives_implementation` to select the
16+
implementation of cross-process collective operations used by the CPU backend.
17+
Choices available are `'none'`(default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26).
18+
If set to `'none'`, cross-process collective operations are disabled.
1519

1620
* Changes
1721
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
@@ -49,6 +53,8 @@ Remember to align the itemized text with the first line of an item within a list
4953
deprecation is completed.
5054
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
5155
related functions now raise an error, following a similar change in NumPy.
56+
* The config option `jax_cpu_enable_gloo_collectives` is deprecated.
57+
Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
5258

5359
## jaxlib 0.4.27
5460

jax/_src/xla_bridge.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from __future__ import annotations
2424

25+
import atexit
2526
from collections.abc import Mapping
2627
import dataclasses
2728
from functools import lru_cache, partial
@@ -103,7 +104,14 @@
103104
_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
104105
name="jax_cpu_enable_gloo_collectives",
105106
default=False,
106-
help="If True, enable cross-process collectives on CPU using Gloo.",
107+
help="Deprecated, please use jax_cpu_collectives_implementation instead.",
108+
)
109+
110+
_CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string(
111+
name='jax_cpu_collectives_implementation',
112+
default='none',
113+
help='Cross-process collective implementation used on CPU. Either "none", '
114+
'"gloo" or "mpi"'
107115
)
108116

109117
# TODO(yueshengys): turn default back to True after resolving memory increase
@@ -229,10 +237,29 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
229237

230238
def make_cpu_client() -> xla_client.Client:
231239
collectives: xla_client._xla.CpuCollectives | None = None
240+
241+
collectives_impl = _CPU_COLLECTIVES_IMPLEMENTATION.value
232242
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
243+
collectives_impl = 'gloo'
244+
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is deprecated. '
245+
'Please use `jax.config.update('
246+
'"jax_cpu_collectives_implementation", "gloo")` instead.',
247+
DeprecationWarning,
248+
)
249+
if collectives_impl == 'gloo':
233250
collectives = xla_client._xla.make_gloo_tcp_collectives( # type: ignore
234251
distributed_client=distributed.global_state.client,
235252
)
253+
elif collectives_impl == 'mpi' and xla_extension_version >= 251:
254+
collectives = xla_client._xla.make_mpi_collectives() # type: ignore
255+
collectives.Init() # type: ignore
256+
atexit.register(collectives.Finalize) # type: ignore
257+
elif collectives_impl != 'none':
258+
collectives_impls = ['none', 'gloo'
259+
] + (['mpi'] if xla_extension_version >= 251 else [])
260+
raise RuntimeError(f"Unknown collectives implementation "
261+
f"{collectives_impl}. Available implementations are "
262+
f"{collectives_impls}.")
236263
if xla_extension_version >= 257:
237264
return xla_client.make_cpu_client( # type: ignore
238265
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,

0 commit comments

Comments
 (0)