|
22 | 22 |
|
23 | 23 | from __future__ import annotations
|
24 | 24 |
|
| 25 | +import atexit |
25 | 26 | from collections.abc import Mapping
|
26 | 27 | import dataclasses
|
27 | 28 | from functools import lru_cache, partial
|
|
103 | 104 | _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
|
104 | 105 | name="jax_cpu_enable_gloo_collectives",
|
105 | 106 | default=False,
|
106 |
| - help="If True, enable cross-process collectives on CPU using Gloo.", |
| 107 | + help="Deprecated, please use jax_cpu_collectives_implementation instead.", |
| 108 | +) |
| 109 | + |
| 110 | +_CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string( |
| 111 | + name='jax_cpu_collectives_implementation', |
| 112 | + default='none', |
| 113 | + help='Cross-process collective implementation used on CPU. Either "none", ' |
| 114 | + '"gloo" or "mpi"' |
107 | 115 | )
|
108 | 116 |
|
109 | 117 | # TODO(yueshengys): turn default back to True after resolving memory increase
|
@@ -228,11 +236,32 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
|
228 | 236 |
|
229 | 237 |
|
230 | 238 | def make_cpu_client() -> xla_client.Client:
|
231 |
| - collectives: xla_client._xla.CpuCollectives | None = None |
| 239 | + collectives: (xla_client._xla.CpuCollectives | |
| 240 | + xla_client._xla.MpiCollectives | |
| 241 | + None ) = None |
| 242 | + |
| 243 | + collectives_impl = _CPU_COLLECTIVES_IMPLEMENTATION.value |
232 | 244 | if _CPU_ENABLE_GLOO_COLLECTIVES.value:
|
| 245 | + collectives_impl = 'gloo' |
| 246 | + warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is deprecated. ' |
| 247 | + 'Please use `jax.config.update(' |
| 248 | + '"jax_cpu_collectives_implementation", "gloo")` instead.', |
| 249 | + DeprecationWarning, |
| 250 | + ) |
| 251 | + if collectives_impl == 'gloo': |
233 | 252 | collectives = xla_client._xla.make_gloo_tcp_collectives( # type: ignore
|
234 | 253 | distributed_client=distributed.global_state.client,
|
235 | 254 | )
|
| 255 | + elif collectives_impl == 'mpi' and xla_extension_version >= 251: |
| 256 | + collectives = xla_client._xla.make_mpi_collectives() # type: ignore |
| 257 | + collectives.Init() |
| 258 | + atexit.register(collectives.Finalize) |
| 259 | + elif collectives_impl != 'none': |
| 260 | + collectives_impls = ['none', 'gloo' |
| 261 | + ] + (['mpi'] if xla_extension_version >= 251 else []) |
| 262 | + raise RuntimeError(f"Unknown collectives implementation " |
| 263 | + f"{collectives_impl}. Available implementations are " |
| 264 | + f"{collectives_impls}.") |
236 | 265 | if xla_extension_version >= 257:
|
237 | 266 | return xla_client.make_cpu_client( # type: ignore
|
238 | 267 | asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
|
|
0 commit comments