[REQUEST] Fast access to whole model gradient during training #7644

@JohannesAck

Is your feature request related to a problem? Please describe.
I'm trying to manipulate the gradient during training, but the current safe_get_full_grad API is prohibitively slow.

I'm basically doing this:

from deepspeed.utils import safe_get_full_grad, safe_set_full_grad

def training_step(engine, model, data):
    loss = get_loss(model, data)
    engine.backward(loss)
    # grad_ph: dict of buffers keyed by parameter name, set up in advance
    # to hold the gradient temporarily
    for pname, param in model.named_parameters():
        grad_ph[pname].copy_(safe_get_full_grad(param).detach())
    # magic_modification stands for the user-defined gradient manipulation
    grad_modified = magic_modification(grad_ph)
    for pname, param in model.named_parameters():
        safe_set_full_grad(param, grad_modified[pname])
    engine.step()

In my case (a GH200 cluster), collecting the gradient across 8 GPUs for a 1B model with safe_get_full_grad takes 300 ms, while the optimizer step + sync takes only ~70 ms. Intuitively, getting the gradient should be faster than that.
I believe the issue might be that each call to safe_get_full_grad results in its own tiny all_reduce instead of a single large all_reduce.
Interestingly, setting the gradient is a lot faster than getting it.
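To illustrate why many small collectives could cost more than one large one, here is a rough back-of-the-envelope sketch using a simple latency-plus-bandwidth cost model. Every number in it (tensor count, bytes per element, per-call latency, bandwidth) is a hypothetical placeholder, not a measurement from the setup above, and the helper name collective_time_s is made up for this example.

```python
def collective_time_s(num_calls, total_bytes, latency_s, bandwidth_bytes_per_s):
    """Latency-plus-bandwidth estimate: fixed cost per call plus bytes over bandwidth."""
    return num_calls * latency_s + total_bytes / bandwidth_bytes_per_s


# All figures below are hypothetical placeholders, not measurements.
NUM_PARAM_TENSORS = 300        # assumed number of parameter tensors in the model
TOTAL_GRAD_BYTES = 4 * 10**9   # assumed: 1e9 gradient elements at 4 bytes each
LATENCY_S = 0.5e-3             # assumed fixed overhead per collective call
BANDWIDTH_BYTES_PER_S = 100e9  # assumed effective interconnect bandwidth

# One collective per parameter tensor versus one collective over everything.
per_param = collective_time_s(
    NUM_PARAM_TENSORS, TOTAL_GRAD_BYTES, LATENCY_S, BANDWIDTH_BYTES_PER_S
)
fused = collective_time_s(1, TOTAL_GRAD_BYTES, LATENCY_S, BANDWIDTH_BYTES_PER_S)

print(f"one call per parameter: {per_param * 1e3:.1f} ms")
print(f"one fused call:         {fused * 1e3:.1f} ms")
```

Under this model the bytes moved are identical in both cases, so the whole gap comes from the per-call overhead being paid once per parameter tensor.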

Sidenote: For zero-1, no communication should be necessary to get the gradient at all, but the current implementation still communicates a lot. Why is that? Also, for zero-2, shouldn't all_gather be sufficient instead of all_reduce?
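The all_gather versus all_reduce question can be illustrated with a toy, single-process simulation. It assumes each rank already holds a fully reduced, disjoint shard of the gradient. Whether that matches what the current implementation does internally is an assumption here, and the helpers all_reduce_sum and all_gather_concat are plain-Python stand-ins, not real collectives.

```python
def all_reduce_sum(buffers):
    """Stand-in for all_reduce: element-wise sum over the per-rank buffers."""
    return [sum(vals) for vals in zip(*buffers)]


def all_gather_concat(shards):
    """Stand-in for all_gather: concatenate the per-rank shards in rank order."""
    return [x for shard in shards for x in shard]


# Hypothetical full gradient, partitioned evenly across 3 simulated ranks.
full_grad = [0.5, -1.0, 2.0, 3.5, -0.25, 4.0]
world_size = 3
shard_len = len(full_grad) // world_size
shards = [full_grad[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

# all_reduce route: each rank zero-pads its shard to full length, then all sum.
padded = []
for rank, shard in enumerate(shards):
    buf = [0.0] * len(full_grad)
    buf[rank * shard_len:(rank + 1) * shard_len] = shard
    padded.append(buf)

via_all_reduce = all_reduce_sum(padded)
via_all_gather = all_gather_concat(shards)
```

Both routes rebuild the same full gradient in this toy case, but the all_reduce route sends and sums full-length buffers that are mostly zeros, while the all_gather route only moves each shard.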

Describe the solution you'd like
I would like a function along the lines of safe_get_full_model_grads that returns the gradients for all parameters of the model and does so more efficiently than calling safe_get_full_grad repeatedly.
Ideally it would also avoid any cross-device communication for zero stage 1.
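One possible shape for the bookkeeping behind such a function is sketched below: lay out every parameter's gradient in one flat buffer so that a single collective could cover the whole model, then split the buffer back into per-parameter views. This is only a sketch with plain Python lists standing in for tensors. The names plan_flat_layout, pack and unpack are hypothetical and not part of any existing API.

```python
def plan_flat_layout(sizes):
    """Map each parameter name to its (offset, numel) slot in one flat buffer."""
    layout, offset = {}, 0
    for name, numel in sizes.items():
        layout[name] = (offset, numel)
        offset += numel
    return layout, offset


def pack(grads, layout, total):
    """Copy every per-parameter gradient into its slot of a single flat buffer."""
    flat = [0.0] * total
    for name, (offset, numel) in layout.items():
        flat[offset:offset + numel] = grads[name]
    return flat


def unpack(flat, layout):
    """Split the flat buffer back into a dict of per-parameter gradients."""
    return {name: flat[o:o + n] for name, (o, n) in layout.items()}


# Hypothetical per-parameter gradients (lists stand in for tensors).
grads = {"w1": [0.1, 0.2, 0.3], "b1": [1.0], "w2": [-2.0, 4.0]}
layout, total = plan_flat_layout({name: len(g) for name, g in grads.items()})
flat = pack(grads, layout, total)  # a single collective would run on `flat` here
restored = unpack(flat, layout)
```

The layout only depends on parameter shapes, so it could be computed once and reused on every training step.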

Describe alternatives you've considered
I don't see a good alternative with the currently provided APIs.

Additional context
See my annotated profiler trace here, with gradient accumulation across two batches and zero stage 2:

[Image: annotated profiler trace]

I would be happy to contribute a PR if I can get some pointers about how this should be implemented!

Thanks for providing this awesome project!
