TP+ZeRO1 in JAX
#28159
-
Here's what the code looks like, btw:

```python
# Imports assumed by this snippet; `_split_into_microbatches` and
# `get_axis_size` are project-local helpers that aren't shown here.
from collections.abc import Callable
from typing import Any

import flax.nnx
import jax
import jax.numpy as jnp
from flax.nnx import Module, Optimizer, Param
from jax import Array


def _train_step_with_microbatching(
    model: Module,
    optimizer: Optimizer,
    inputs: Array,
    labels: Array,
    *,
    microbatch_size: int,
    data_parallel_axis: str,
    loss_function: Callable[[Array, Array], Array],
) -> tuple[Array, Any]:
    microbatched_inputs, microbatched_labels = _split_into_microbatches(
        inputs,
        labels,
        microbatch_size=microbatch_size,
        data_parallel_degree=get_axis_size(data_parallel_axis),
    )
    num_microbatches, *trailing_dims = microbatched_inputs.shape

    # Manually hoist the gathering of updated model parameters out of the loop
    # (otherwise it will be redone for every gradient accumulation step).
    model_graph, model_states = flax.nnx.split(model)
    model_states_sharding_for_compute = flax.nnx.get_partition_spec(model_states)
    model_states_for_compute = jax.lax.with_sharding_constraint(
        model_states, model_states_sharding_for_compute
    )
    model_for_compute = flax.nnx.merge(model_graph, model_states_for_compute)

    loss_accumulator = jnp.zeros((1,), dtype=jnp.float32)
    wgrad_accumulators = jax.tree.map(
        lambda x: jnp.zeros_like(x, dtype=jnp.float32),
        flax.nnx.state(model_for_compute, Param),
    )

    # Don't use scan, because it lowers to a while loop:
    # * XLA doesn't hoist DP collectives out of scan/while, so they are
    #   unnecessarily repeated for each microbatch.
    # * Even if we manually hoist out the collectives, we lose DP overlap,
    #   since the scan acts as an implicit optimization barrier.
    # * It causes problems on Neuron.
    # A simple for loop lowers to an unrolled representation which doesn't have
    # the issues described above.
    for microbatch_index in range(num_microbatches):
        # Prevent XLA from mixing up microbatches: we want to finish one
        # microbatch before starting on the next to get reasonable gradient
        # accumulation memory characteristics. The extra tuple elements are
        # false dependencies that only exist to enforce this ordering.
        input_microbatch, *false_deps = jax.lax.optimization_barrier(
            (microbatched_inputs[microbatch_index], loss_accumulator, wgrad_accumulators)
        )
        label_microbatch, *false_deps = jax.lax.optimization_barrier(
            (microbatched_labels[microbatch_index], loss_accumulator, wgrad_accumulators)
        )

        with jax.named_scope(f"microbatch_{microbatch_index}"):

            def fwd_loss_fn(model: Module, inputs: Array, labels: Array) -> Array:
                return loss_function(model(inputs), labels)  # type: ignore[operator]

            loss, wgrads = flax.nnx.value_and_grad(fwd_loss_fn, argnums=0)(
                model_for_compute,
                inputs=input_microbatch,
                labels=label_microbatch,
            )
            # TODO: we want to sum across microbatches, not across DP ranks (yet).
            # See: https://github.com/jax-ml/jax/discussions/28159
            # and: https://github.com/jax-ml/jax/discussions/16156
            loss_accumulator += loss
            wgrad_accumulators = jax.tree.map(jnp.add, wgrad_accumulators, wgrads)

    global_batch_loss = loss_accumulator / num_microbatches
    global_batch_wgrads = jax.tree.map(lambda x: x / num_microbatches, wgrad_accumulators)

    with jax.named_scope("optimizer_step"):
        optimizer.update(global_batch_wgrads)

    return global_batch_loss, global_batch_wgrads
```
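The two helpers used at the top of the snippet aren't shown above; something along these lines would do (a simplified sketch that assumes `microbatch_size` is per DP rank and a global `("dp", "tp")` mesh, not necessarily the real implementations):

```python
import numpy as np
import jax

# Assumption: a global ("dp", "tp") mesh set up elsewhere; a trivial layout is
# used here only to keep the sketch self-contained.
MESH = jax.sharding.Mesh(np.array(jax.devices()).reshape(-1, 1), ("dp", "tp"))


def get_axis_size(axis_name: str) -> int:
    # Size of a named mesh axis, e.g. the data-parallel degree.
    return MESH.shape[axis_name]


def _split_into_microbatches(
    inputs: jax.Array,
    labels: jax.Array,
    *,
    microbatch_size: int,
    data_parallel_degree: int,
) -> tuple[jax.Array, jax.Array]:
    # Reshape [global_batch, ...] -> [num_microbatches, microbatch, ...],
    # where each microbatch still spans all DP ranks.
    per_microbatch = microbatch_size * data_parallel_degree
    num_microbatches = inputs.shape[0] // per_microbatch
    inputs = inputs.reshape(num_microbatches, per_microbatch, *inputs.shape[1:])
    labels = labels.reshape(num_microbatches, per_microbatch, *labels.shape[1:])
    return inputs, labels
```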
-
Hello!
I'm working on writing a training loop with Flax NNX that uses TP+ZeRO1 with gradient accumulation. For the case where we have gradient accumulation, we're going for the following:
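In sharding terms, that's roughly this layout (an illustrative sketch only; the axis names and the 2-way DP split are just example values):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Example layout: 2-way data parallel x (device_count / 2)-way tensor parallel.
dp = 2 if jax.device_count() % 2 == 0 else 1
devices = np.array(jax.devices()).reshape(dp, -1)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# TP: weights are sharded along "tp" and replicated along "dp".
param_sharding = NamedSharding(mesh, P(None, "tp"))

# ZeRO-1: only the optimizer state (e.g. Adam moments) is additionally sharded
# along "dp", so each DP rank stores and updates 1/dp of it. (With ZeRO-3/FSDP
# the parameters themselves would be sharded along "dp" as well.)
opt_state_sharding = NamedSharding(mesh, P("dp", "tp"))

# The global batch is split along "dp".
batch_sharding = NamedSharding(mesh, P("dp"))
```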
I did find some examples of gradient accumulation in other JAX codebases; they both seem to use scan/while, roughly the pattern sketched below. The problem I encountered with this is that XLA doesn't hoist the DP communications out of the loop, because doing so would increase memory consumption (which is reasonable). I tried hoisting them manually, which sort of works, but then I lose DP overlap, since the compiler treats the loop boundaries more or less like optimization barriers.
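Here's the kind of scan-based accumulation I mean (a minimal sketch, not code from any particular library; `loss_fn` and `params` are placeholders):

```python
import jax
import jax.numpy as jnp


def scan_grad_accum(params, microbatched_inputs, microbatched_labels, loss_fn):
    """Scan-style gradient accumulation (simplified).

    `loss_fn(params, inputs, labels) -> scalar` is a placeholder for the real loss.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    zero_grads = jax.tree.map(jnp.zeros_like, params)

    def body(carry, microbatch):
        loss_acc, grad_acc = carry
        x, y = microbatch
        loss, grads = grad_fn(params, x, y)
        return (loss_acc + loss, jax.tree.map(jnp.add, grad_acc, grads)), None

    (loss_acc, grad_acc), _ = jax.lax.scan(
        body,
        (jnp.zeros((), jnp.float32), zero_grads),
        (microbatched_inputs, microbatched_labels),
    )
    num_microbatches = microbatched_inputs.shape[0]
    return (
        loss_acc / num_microbatches,
        jax.tree.map(lambda g: g / num_microbatches, grad_acc),
    )
```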
So the next thing I tried was "unrolling" by turning it into a normal Python for loop. Now I get reasonable overlap for both the all-gather (AG) and the reduce-scatter (RS), but I'm still seeing an RS launched in every gradient accumulation step. I think to get around this I would need to shard_map the gradient accumulation so it runs independently over DP, but I'm not sure how to make that work with GSPMD: I don't quite want to write per-device code, more like "per DP rank" code, roughly as in the sketch below. Anything you can suggest? Or am I approaching this completely wrong?
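For concreteness, this is roughly what I have in mind (an untested sketch; `grad_accum_fn` stands in for a per-rank microbatch loop, and I'm not sure the `auto={"tp"}` part composes with GSPMD the way I want):

```python
from functools import partial

import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P


def per_dp_rank_step(params, inputs, labels, *, mesh, grad_accum_fn):
    """Run gradient accumulation per DP rank, reducing over "dp" only once.

    `mesh` is a ("dp", "tp") mesh; `grad_accum_fn(params, x, y)` is a placeholder
    that loops over this rank's microbatches and returns (loss, grads).
    """

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P(), P("dp"), P("dp")),  # params replicated over dp, batch split over dp
        out_specs=(P(), P()),
        check_rep=False,
        auto=frozenset({"tp"}),  # stay manual only over "dp"; let GSPMD handle "tp"
    )
    def _step(params, x, y):
        # Each DP rank accumulates over its own microbatches independently...
        loss, grads = grad_accum_fn(params, x, y)
        # ...and the DP reduction happens once per global step,
        # not once per microbatch.
        loss = jax.lax.pmean(loss, axis_name="dp")
        grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name="dp"), grads)
        return loss, grads

    return _step(params, inputs, labels)
```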
Sorry for the long-winded post. The core of the issue seems to be that I'm going for TP+ZeRO1 while others are going for TP+ZeRO3 (FSDP). The scan-based approaches used by other libraries work well for their use case if they just want TP+FSDP, since all the scan iterations are identical.