|
22 | 22 |
|
23 | 23 | from __future__ import annotations
|
24 | 24 |
|
| 25 | +import atexit |
25 | 26 | from collections.abc import Mapping
|
26 | 27 | import dataclasses
|
27 | 28 | from functools import lru_cache, partial
|
|
103 | 104 | _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
|
104 | 105 | name="jax_cpu_enable_gloo_collectives",
|
105 | 106 | default=False,
|
106 |
| - help="If True, enable cross-process collectives on CPU using Gloo.", |
| 107 | + help="Deprecated, please use jax_cpu_collectives_implementation instead.", |
| 108 | +) |
| 109 | + |
| 110 | +_CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string( |
| 111 | + name='jax_cpu_collectives_implementation', |
| 112 | + default='none', |
| 113 | + help='Cross-process collective implementation used on CPU. Either "none", ' |
| 114 | + '"gloo" or "mpi"' |
107 | 115 | )
|
108 | 116 |
|
109 | 117 | # TODO(yueshengys): turn default back to True after resolving memory increase
|
@@ -229,10 +237,29 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
|
229 | 237 |
|
230 | 238 | def make_cpu_client() -> xla_client.Client:
|
231 | 239 | collectives: xla_client._xla.CpuCollectives | None = None
|
| 240 | + |
| 241 | + collectives_impl = _CPU_COLLECTIVES_IMPLEMENTATION.value |
232 | 242 | if _CPU_ENABLE_GLOO_COLLECTIVES.value:
|
| 243 | + collectives_impl = 'gloo' |
| 244 | + warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is deprecated. ' |
| 245 | + 'Please use `jax.config.update(' |
| 246 | + '"jax_cpu_collectives_implementation", "gloo")` instead.', |
| 247 | + DeprecationWarning, |
| 248 | + ) |
| 249 | + if collectives_impl == 'gloo': |
233 | 250 | collectives = xla_client._xla.make_gloo_tcp_collectives( # type: ignore
|
234 | 251 | distributed_client=distributed.global_state.client,
|
235 | 252 | )
|
| 253 | + elif collectives_impl == 'mpi' and xla_extension_version >= 251: |
| 254 | + collectives = xla_client._xla.make_mpi_collectives() # type: ignore |
| 255 | + collectives.Init() # type: ignore |
| 256 | + atexit.register(collectives.Finalize) # type: ignore |
| 257 | + elif collectives_impl != 'none': |
| 258 | + collectives_impls = ['none', 'gloo' |
| 259 | + ] + (['mpi'] if xla_extension_version >= 251 else []) |
| 260 | + raise RuntimeError(f"Unknown collectives implementation " |
| 261 | + f"{collectives_impl}. Available implementations are " |
| 262 | + f"{collectives_impls}.") |
236 | 263 | if xla_extension_version >= 257:
|
237 | 264 | return xla_client.make_cpu_client( # type: ignore
|
238 | 265 | asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
|
|
0 commit comments