
Commit f1ae623

yueshengys authored and jax authors committed
Fix token management for ordered side-effects.
Right now, when there are multiple devices, we get an output token from each device, but we keep only the token from `device_0` and replicate it across devices to form the input tokens for the next function call with ordered side-effects. This is fine on TPU/GPU, where the dispatched computations effectively execute in sequence, but on CPU they can run in parallel, so we need to make sure the dependency is set correctly.

PiperOrigin-RevId: 623296894
1 parent 9809aa1 · commit f1ae623

File tree: 2 files changed, +25 −3 lines
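For context on what the commit protects, here is a minimal sketch of an ordered side effect whose cross-call ordering is enforced by these runtime tokens. It is illustrative only: the `record` callback, the device-count flag, and the use of `io_callback` are assumptions rather than part of the commit, and the actual failure mode involved computations sharded over multiple CPU devices.

```python
# Sketch only: expose several CPU devices before importing jax (standard XLA flag).
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

log = []

def record(x):
  # Host-side effect whose ordering across calls is what the runtime tokens guarantee.
  log.append(int(x))
  return np.asarray(x)

@jax.jit
def step(x):
  # ordered=True makes this an ordered effect, so each dispatch must wait on the
  # token produced by the previous one.
  return io_callback(record, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)

for i in range(3):
  step(jnp.int32(i)).block_until_ready()

print(log)  # expected [0, 1, 2]: calls must not be reordered, even on parallel CPU devices
```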

jax/_src/dispatch.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -108,7 +108,7 @@ def simple_impl(prim):
 RuntimeToken = Any

 class RuntimeTokenSet(threading.local):
-  """See docstring for effect.py module for the calling convention for tokens."""
+  """See docstring for effects.py module for the calling convention for tokens."""

   # For each ordered effect, the token returned by the last dispatched
   # computation, sharded over the devices in that computation.
@@ -125,6 +125,16 @@ def __init__(self):
   def get_token_input(self, eff: core.Effect,
                       devices: list[Device]) -> jax.Array:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
+
+    if isinstance(tok, jax.Array):
+      # The order of devices may change, so we need to reshard if necessary.
+      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
+      # scenario. Revise the logic later. A distributed shutdown barrier inside
+      # the XLA program may be needed.
+      return jax.device_put(tok, jax.sharding.PositionalSharding(devices))
+
+    # We only use replicated sharding for the first time when the token for the
+    # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
     sharded_tok = pxla.shard_args([s], [tok])[0]
     self.current_tokens[eff] = sharded_tok
```
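The new branch in `get_token_input` reshards an existing token with `jax.device_put` instead of recreating it with replicated sharding, so the token follows whatever device order the next computation uses. Below is a standalone sketch of that resharding call; the shape-(0,) stand-in array and the reversed device order are illustrative assumptions, and it relies on `jax.sharding.PositionalSharding` being available, as the patch itself does.

```python
import numpy as np
import jax

# Stand-in for a previously created token: an empty bool array on one device.
tok = jax.device_put(np.zeros(0, np.bool_), jax.devices()[0])

# A device order that may differ from the one used when the token was created.
devices = list(reversed(jax.devices()))

# Mirror of the added branch: reshard the existing jax.Array onto the new device order.
resharded = jax.device_put(tok, jax.sharding.PositionalSharding(devices))
print(resharded.sharding)
```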

jax/_src/interpreters/pxla.py

Lines changed: 14 additions & 2 deletions

```diff
@@ -1155,13 +1155,25 @@ def _add_tokens_to_inputs(self, input_bufs):

   def _handle_token_bufs(self, token_bufs, sharded_token):
     # token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
-    # token buffer (as a singleton list).
+    # token buffers.
     # sharded_token: ShardedToken, containing the RuntimeTokens for each device
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
-      dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      assert len(token_buf) > 0
+      if len(token_buf) == 1:
+        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      else:
+        token_devices = []
+        for token in token_buf:
+          assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
+          token_devices.append(token.sharding._device_assignment[0])
+        s = sharding_impls.PositionalSharding(token_devices)
+        global_token_array = jax.make_array_from_single_device_arrays(
+            (0,), s, token_buf
+        )
+        dispatch.runtime_tokens.set_token_result(eff, global_token_array)

   @profiler.annotate_function
   def __call__(self, *args):
```
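The multi-token branch in `_handle_token_bufs` stitches the per-device output tokens into a single global array instead of keeping only `token_buf[0]`, so the next dispatch depends on every device's token. A standalone sketch of the same assembly follows, using ordinary shape-(0,) arrays as stand-ins for runtime tokens; the buffer construction here is an assumption for illustration only.

```python
import numpy as np
import jax

devices = jax.devices()

# Stand-ins for the per-device output tokens: one shape-(0,) array per device.
token_buf = [jax.device_put(np.zeros(0, np.bool_), d) for d in devices]

# Recover each buffer's device, then build one global array spanning all of them,
# mirroring the else branch added above.
token_devices = [next(iter(t.devices())) for t in token_buf]
s = jax.sharding.PositionalSharding(token_devices)
global_token_array = jax.make_array_from_single_device_arrays((0,), s, token_buf)

print(global_token_array.sharding)
```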
