Is your feature request related to a problem? Please describe.
I'm trying to manipulate gradients during training; however, the current safe_get_full_grad API is prohibitively slow.
I'm basically doing this:
```python
# safe_get_full_grad / safe_set_full_grad come from deepspeed.utils
from deepspeed.utils import safe_get_full_grad, safe_set_full_grad

def training_step(model, data):
    loss = get_loss(model, data)
    engine.backward(loss)
    # grad_ph: dict of preallocated tensors, set up in advance to hold the full gradients temporarily
    for pname, param in model.named_parameters():
        grad_ph[pname].copy_(safe_get_full_grad(param).detach())
    grad_modified = magic_modification(grad_ph)
    for pname, param in model.named_parameters():
        safe_set_full_grad(param, grad_modified[pname])
    engine.step()
```

In my case (a GH200 cluster), it takes ~300 ms to collect the gradients across 8 GPUs for a 1B-parameter model using safe_get_full_grad, while the optimizer step + sync only takes ~70 ms. Intuitively, getting the gradients should be faster?
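For context, here is roughly how I'm timing the collection loop (a minimal sketch; the helper name and the synchronize-based timing are mine, not from my actual code):

```python
import time
import torch
from deepspeed.utils import safe_get_full_grad

def time_grad_collection(model, grad_ph):
    # Rough wall-clock timing of the per-parameter gradient collection loop.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for pname, param in model.named_parameters():
        grad_ph[pname].copy_(safe_get_full_grad(param).detach())
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000  # milliseconds
```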
I believe the issue might be that each call to safe_get_full_grad results in its own tiny all_reduce instead of a single large all_reduce.
Interestingly, setting the gradient is a lot faster than getting it.
Side note: for ZeRO-1, no communication should be necessary to get the gradients at all, yet the current implementation still communicates a lot; why is that? Also, for ZeRO-2, an all_gather should be sufficient instead of an all_reduce?
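To make the communication pattern I mean concrete, here is a sketch in plain torch.distributed of one fused all_reduce over a flat buffer instead of one tiny all_reduce per parameter (this is only an illustration of the pattern, not how DeepSpeed's partitioned gradients actually work internally):

```python
import torch
import torch.distributed as dist

def allreduce_grads_fused(grads):
    # grads: list of per-parameter gradient tensors on the local GPU.
    # One large all_reduce over a flattened buffer...
    flat = torch.cat([g.reshape(-1) for g in grads])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat.div_(dist.get_world_size())  # average across ranks
    # ...then copy the reduced values back into the individual tensors.
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))
        offset += n
```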
Describe the solution you'd like
I would like a function along the lines of safe_get_full_model_grads that returns the gradients for all parameters of the model, and does so more efficiently than calling safe_get_full_grad repeatedly.
Ideally it would also avoid any cross-device communication for ZeRO stage 1.
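Roughly the shape of API I have in mind (the name, signature, and return type here are just a proposal, not anything that exists in DeepSpeed today):

```python
from typing import Dict
import torch

def safe_get_full_model_grads(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Proposed: return {param_name: full fp32 gradient} for every parameter,
    using a single fused communication op (or none at all for ZeRO-1)."""
    raise NotImplementedError  # proposal only

# Intended usage in my training step:
#   grads = safe_get_full_model_grads(model)
#   grads = magic_modification(grads)
#   for pname, param in model.named_parameters():
#       safe_set_full_grad(param, grads[pname])
```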
Describe alternatives you've considered
I don't see a good alternative with the currently provided APIs.
Additional context
See my annotated profiler trace here, with gradient accumulation across two batches and ZeRO stage 2:
I would be happy to contribute a PR if I can get some pointers about how this should be implemented!
Thanks for providing this awesome project!