Support gradient accumulation #1238

Open
wants to merge 25 commits into base: main
Conversation

janEbert (Author)

First, the batched backward calculation is refactored into its own function. Then, gradient accumulation is implemented by moving the data iterator inside the train_step method and consuming data from it as necessary. I added some extra handling for non-infinite data iterators, but if you dislike that additional complexity, I can remove it to simplify the code.

The feature is enabled by giving an additional --training.global_batch_size, which has a sensible default of 1 gradient accumulation step (i.e., no actual accumulation).
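For orientation, a minimal sketch of what the accumulation loop inside train_step looks like (simplified and hedged: helper names such as next_batch and forward_backward_step follow the shape this PR converges on later in the review, and details like gradient clipping, pipeline parallelism, and metrics are omitted):

```python
# Hedged sketch, not the literal PR code.
self.optimizers.zero_grad()
for _ in range(self.gradient_accumulation_steps):
    input_dict, labels = self.next_batch(data_iterator)
    # The loss function is rescaled by 1/gradient_accumulation_steps, so the
    # accumulated gradients approximate those of one large global batch.
    loss = self.forward_backward_step(input_dict, labels)
self.optimizers.step()
self.lr_schedulers.step()
```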

@tianyu-l thanks for the ping.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 29, 2025
@janEbert (Author)

Maybe it would also make sense to rename --training.batch_size to --training.local_batch_size accordingly to differentiate it further from the --training.global_batch_size config.

@fegin (Contributor) left a comment

Thanks for the PR. I suggest that we don't let train_step() be aware of data_iterator. Please see the detailed comments.

Also, this PR doesn't change the parallelization, which is not correct. We will have to call set_requires_gradient_sync if FSDP is applied. We can raise an exception if DDP is used and accumulation_steps > 1 for now.
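For context, a hedged sketch of what that suggestion could look like with FSDP2 (fully_shard), whose wrapped modules expose set_requires_gradient_sync; the loop shape and helper names are illustrative assumptions, and, as discussed further down, the PR deliberately does not adopt this:

```python
# Skip the gradient reduce-scatter on all but the last micro-batch of an
# accumulation window. This saves communication but keeps unsharded gradients
# in memory until the final micro-batch.
for micro_step in range(gradient_accumulation_steps):
    is_last_micro_step = micro_step == gradient_accumulation_steps - 1
    for model_part in model_parts:
        if hasattr(model_part, "set_requires_gradient_sync"):  # FSDP2 module
            model_part.set_requires_gradient_sync(is_last_micro_step)
    input_dict, labels = next_batch(data_iterator)
    loss = forward_backward_step(input_dict, labels)
```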

unwrapped_loss_fn = self.loss_fn

@functools.wraps(unwrapped_loss_fn)
def accumulated_loss_fn(*args, **kwargs):
Contributor

We should just modify build_loss_fn to take accumulation_steps to let the loss function decide the usage.

Contributor

I'm OK either way.
I think being more explicit about grad accumulation handling doesn't look bad.
Also if we go with explicit global_batch_size and implicit grad_accu_steps, then we'll need to do another check & computation in the loss function.

Author

I moved the wrapping functionality to torchtitan.components.loss and called it rescale_accumulated_loss. Not quite what you wanted, but that way we can re-use the Trainer.gradient_accumulation_step value more easily.
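A minimal sketch of what such a wrapper can look like (hedged: the real helper lives in torchtitan.components.loss and its exact signature may differ):

```python
import functools
from typing import Callable

import torch


def rescale_accumulated_loss(
    unwrapped_loss_fn: Callable[..., torch.Tensor], accumulation_steps: int
) -> Callable[..., torch.Tensor]:
    """Rescale a mean-reduced loss so that accumulating over
    `accumulation_steps` micro-batches reproduces the mean over the full
    global batch."""

    @functools.wraps(unwrapped_loss_fn)
    def accumulated_loss_fn(*args, **kwargs) -> torch.Tensor:
        return unwrapped_loss_fn(*args, **kwargs) / accumulation_steps

    return accumulated_loss_fn
```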

@tianyu-l (Contributor) left a comment

Thank you for adding this feature!
I left several comments. Please see if they make sense.

@@ -192,6 +192,11 @@ class Training:
batch_size: int = 8
Contributor

yeah let's call it local_batch_size

Author

Did a rename across the codebase wherever JobConfig.training.batch_size or --training.batch_size was used. Not sure how you'd like me to handle the compatibility breakage that this introduces.
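After the rename, the relevant part of the Training config looks roughly like this (a sketch under the assumption that -1 is the "unset" sentinel, which matches the `< 0` check in the derivation code below; the docstrings are illustrative):

```python
from dataclasses import dataclass


@dataclass
class Training:
    local_batch_size: int = 8
    """Batch size per data-parallel rank (formerly --training.batch_size)."""

    global_batch_size: int = -1
    """Global batch size; a negative value means local_batch_size * dp_degree,
    i.e. a single gradient accumulation step."""
```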

Comment on lines 123 to 141
if job_config.training.global_batch_size < 0:
    job_config.training.global_batch_size = (
        job_config.training.batch_size * dp_degree
    )
assert job_config.training.global_batch_size > 0
assert (
    job_config.training.global_batch_size
    % (job_config.training.batch_size * dp_degree)
    == 0
), (
    f"global batch size must be multiple of local batch size times "
    f"data-parallel degree ({job_config.training.global_batch_size} "
    f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
)

self.gradient_accumulation_steps = job_config.training.global_batch_size // (
    job_config.training.batch_size * dp_degree
)
assert self.gradient_accumulation_steps > 0
Contributor

nit comment

Suggested change

Before:

if job_config.training.global_batch_size < 0:
    job_config.training.global_batch_size = (
        job_config.training.batch_size * dp_degree
    )
assert job_config.training.global_batch_size > 0
assert (
    job_config.training.global_batch_size
    % (job_config.training.batch_size * dp_degree)
    == 0
), (
    f"global batch size must be multiple of local batch size times "
    f"data-parallel degree ({job_config.training.global_batch_size} "
    f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
)
self.gradient_accumulation_steps = job_config.training.global_batch_size // (
    job_config.training.batch_size * dp_degree
)
assert self.gradient_accumulation_steps > 0

After:

global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
    global_batch_size = job_config.training.batch_size * dp_degree
    self.gradient_accumulation_steps = 1
else:
    assert global_batch_size > (job_config.training.batch_size * dp_degree)
    assert (
        job_config.training.global_batch_size
        % (job_config.training.batch_size * dp_degree)
        == 0
    ), (
        f"global batch size must be multiple of local batch size times "
        f"data-parallel degree ({global_batch_size} "
        f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
    )
    self.gradient_accumulation_steps = global_batch_size // (
        job_config.training.batch_size * dp_degree
    )

Author

I don't really agree with not re-using the code that would become the else case here, but I can still change it to your recommendation. For now, I put the addition of the global_batch_size variable into its own commit, which probably already has the readability improvements you'd like. I also added a comment in the if case noting that this global batch size results in 1 gradient accumulation step.

@@ -183,6 +205,15 @@ def __init__(self, job_config: JobConfig):

self.loss_fn = self.train_spec.build_loss_fn(job_config)

unwrapped_loss_fn = self.loss_fn
Contributor

Let's put the self.gradient_accumulation_steps derivation code right before here, to group gradient accum logic together as much as possible.

I understand that it is desirable to fail early on infeasible global batch size, even before parallelism and other heavy things are applied. But I'd suggest we prioritize readability. What do you think?

Author

Sounds fair! :)

Author

Moved this.


# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
Contributor

can we call it forward_backward_step?

Author

Done. By the way, if you'd prefer me to squash these changes into the previous commits, I'd be happy to clean up the commit chain.

Comment on lines 416 to 417
model_parts = self.model_parts
world_mesh = self.world_mesh
Contributor

similarly, maybe not worth keeping these two

Author

Done.

@@ -336,6 +337,7 @@ def __init__(
)
self.ntokens_since_last_log = 0
self.data_loading_times = []
self.accumulated_losses = []
Contributor

Since it represents a core training concept, rather than directly used for metrics logging, let's put this in Trainer, instead of MetricsProcessor.

Author

Done. Also added the gradient_accumulation_steps attribute to the Trainer's dataclass attributes.

except StopIteration:
    # If data runs out during gradient accumulation, that
    # entire step will not be executed.
    return True
Contributor

Instead of explicit return True, can we just call next and let the StopIteration exception propagate to train_step and catch over there?

Author

I initially had it implemented this way, but thought the try block would encapsulate too much code. If anything else raised a StopIteration, it would make debugging much more difficult. Hence the minimized try scope.

Contributor

I would prefer to directly raise StopIteration and let the outer loop catch it. As mentioned in the discussion above, the original design keeps train_step() simple, without a data dependency, so there is no other StopIteration afaik. If other places actually raise StopIteration, we should figure that out.

If we really want to avoid ambiguity, we can have a customized next(), like next_batch(), which would raise a customized DataDepleteException().

Contributor

That's considerate. I think it's quite unlikely other places would also raise StopIteration? Maybe microbatching in pipeline parallel? But over there the number of microbatches should be fixed ahead of time.

Anyway, if you think we need to deal with this explicitly, we should catch the StopIteration exception and raise a customized DataloaderStopIteration exception to be caught by the caller, instead of returning True.

Author

Went with a combination of these suggestions; a Trainer.next_batch method basically just calls next(data_iterator), but catches and re-raises its StopIteration as a new DataloaderStopIteration.
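Roughly, the helper looks like this (a simplified sketch; details in the actual PR, such as device placement and metrics handling, are omitted):

```python
def next_batch(
    self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    try:
        return next(data_iterator)
    except StopIteration as e:
        # Re-raise under a dedicated type so the outer training loop can tell
        # dataloader exhaustion apart from any other StopIteration.
        raise DataloaderStopIteration() from e
```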

self.step += 1
self.gc_handler.run(self.step)
self.train_step(inputs, labels)
data_ran_out = self.train_step(data_iterator)
Contributor

we can catch the StopIteration here and do different treatment on self.checkpointer.save in try vs. catch.

Author

See above.

Author

This has been changed, but we now simply break on DataloaderStopIteration to avoid changing the checkpointing logic.

Author

This does change the general logic (e.g., torch_profiler and memory_profiler won't be stepped anymore) compared to the previous code, but it is a bit nicer to read than adding an extra variable check in the while condition, IMO.

@tianyu-l tianyu-l linked an issue May 30, 2025 that may be closed by this pull request
@fegin (Contributor) commented May 30, 2025

@tianyu-l Let me know what you think about the proposal above. Don't want @janEbert to be stuck in two different reviews.

@tianyu-l (Contributor) left a comment

> Also, this PR doesn't change the parallelization, which is not correct. We will have to call set_requires_gradient_sync if FSDP is applied.

@fegin For background please see #292 (comment)

I think for us we don't want the potential memory overhead and code complexity, although it can save some communications which could've been hidden anyway.


def train_step(
    self,
    data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
) -> bool | None:
Contributor

We should just return bool and change all other returns to return False to keep the semantics consistent. This would need to change if we keep the return value as the design option. But I prefer try/catch. See the response below.

Author

Reverted this and refactored to the try/except solution as per the other discussions. The return type is back to an implicit None.


@fegin (Contributor) commented May 30, 2025

The review order looks pretty confusing, lol. A summary of the big discussions:

  1. Keep the design with a new forward_backward_step and global_batch to align with RL use case.
  2. Avoid returning a value from train_step(), using a customized Exception for data depletion.

cc @tianyu-l

@tianyu-l (Contributor) commented Jun 3, 2025

Hey @janEbert, how about we work a bit more on the PR?

Sorry for the confusion in the reviews. I think we have agreed on the direction:

  1. Keep the design with a new forward_backward_step and global_batch to align with the RL use case.
  2. Avoid returning a value from train_step(), using a customized Exception for data depletion.

Please also add a test case in https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py

janEbert added a commit to janEbert/torchtitan that referenced this pull request Jun 3, 2025
@fegin said:
> TorchTitan currently doesn't perform force checkpoint if data is
> depleted. We can fix this but I suggest that we don't do this in this
> PR.

(See
pytorch#1238 (comment).)
@janEbert (Author) commented Jun 3, 2025

I believe I have incorporated all the feedback. Let me know how you like the changes. FYI I'm currently at a conference and will be on vacation from Friday, so it would be great to get this done before then, even if I may only sporadically find time. :)

janEbert added 16 commits June 3, 2025 16:08
Previously `int | None`. Makes it possible to obtain the automatic
calculation of it when it has already been set in a TOML config.
@fegin said:
> TorchTitan currently doesn't perform force checkpoint if data is
> depleted. We can fix this but I suggest that we don't do this in this
> PR.

(See
pytorch#1238 (comment).)
I.e., a new `DataloaderStopIteration` that inherits from
`StopIteration`.

Accordingly, no longer return an optional `bool` to indicate depletion
and adapt the remainder of the code to catch the new exception instead.
This concerns only renaming
- `--training.batch_size` to `--training.local_batch_size` and
- `job_config.training.batch_size` to
  `job_config.training.local_batch_size`.
I.e., the method in `Trainer`.
Instead use a new helper variable `global_batch_size` for all logic.
Improves readability.
These were only used in 1 or 2 locations each.
... from `MetricsProcessor`.
@janEbert (Author) commented Jun 3, 2025

Rebased because of local_batch_size changes.

@tianyu-l (Contributor) left a comment

Looks almost good! Please address final comments.

Also the addition of forward_backward_step breaks the FLUX model training.
Could you help refactor the train_step to forward_backward_step over there? Probably just

  1. remove the optimizer.zero_grad
  2. remove https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L152-L180
  3. return loss

For the eval step (https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L182): it should be done in Trainer.train(), but since we are not using grad accumulation in FLUX training, it is OK to leave it in forward_backward_step to accelerate landing of this PR, as long as CI tests pass. @wwwjn and I will work together on fixing it later.

OverrideDefinitions(
    [
        [
            # Default local batch size = 8, and `ngpu=2`, so
Contributor

Let's explicitly specify the local batch size as well, in case some future PR changes the default without changing the test here.

Author

Done.
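For reference, the test entry might look roughly like this (hedged: the positional-argument order of OverrideDefinitions is assumed from the existing entries in integration_tests.py, and the concrete values are just an example); with ngpu=2 and a local batch size of 8, a global batch size of 32 gives 2 gradient accumulation steps:

```python
OverrideDefinitions(
    [
        [
            "--training.local_batch_size 8",
            "--training.global_batch_size 32",
        ],
    ],
    "Gradient accumulation",
    "gradient_accumulation",
    ngpu=2,
),
```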

@@ -333,7 +338,7 @@ class Parallelism:
pipeline_parallel_microbatch_size: int = 1
"""
The size of each pipeline parallel microbatch (default 1).
This value is used to compute the total number of microbatches by dividing batch_size with
This value is used to compute the total number of microbatches by dividing local batch_size with
Contributor

Suggested change

Before: This value is used to compute the total number of microbatches by dividing local batch_size with
After: This value is used to compute the total number of microbatches by dividing local_batch_size with

Author

Great catch! I didn't see the underscore on my dirty screen lol

Comment on lines 35 to 38
class DataloaderStopIteration(StopIteration):
    """An exception that indicates dataloader exhaustion."""

    pass
Contributor
Author

Done.

try:
    self.train_step(data_iterator)
except DataloaderStopIteration:
    logger.info("Ran out of data; last step was canceled.")
Contributor

Suggested change

Before: logger.info("Ran out of data; last step was canceled.")
After: logger.warning("Ran out of data; last step was canceled.")

Author

Done.


# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
def next_batch(
Contributor

This function sounds less necessary, especially when we already have dataloader and batch_generator. Given how short it is, it seems not too bad just running the try-catch in train_step?

Author

To me, it makes the train_step look cleaner and it was nice to have it re-usable for the FLUX refactor. Does that change your mind? :)

Author

I was thinking of patching the data iterator's __next__ method on the fly to ensure the DataloaderStopIteration is raised, but didn't want to introduce too much black magic. It would require modifying the ParallelAwareDataloader.__iter__ method to apply the patch to the returned iterator. What do you think of that option?

Contributor

I would suggest keeping the current implementation. Monkey patching is usually not a good idea. I also agree this function makes train_step cleaner.

As a future benefit, we may want to do data loader pipelining, which overlaps the to("cuda") with the computation. This function gives us a good place to implement it.

job_config.training.local_batch_size * dp_degree
)
assert self.gradient_accumulation_steps > 0
self.loss_fn = rescale_accumulated_loss(
Contributor

This is a comment, not a suggestion:

The code sounds to me like it assumes the loss function we use performs a "mean" reduction, instead of the "sum" that is also available in e.g. cross entropy loss. But I believe this assumption is also made in PyTorch DDP, FSDP, PP, and is universally accepted as the default now. So I think it's OK.
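To make that assumption concrete: with reduction="mean" and equal micro-batch sizes (which holds for fixed-length token batches), rescaling each micro-batch loss by 1/accumulation_steps and summing reproduces the loss of the whole global batch. A small self-contained check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

# Loss over the full "global" batch, mean-reduced over all 8 samples.
full_batch_loss = F.cross_entropy(logits, labels)

# Same data split into 4 micro-batches of 2 samples each.
accumulation_steps = 4
accumulated_loss = sum(
    F.cross_entropy(logits[i * 2 : (i + 1) * 2], labels[i * 2 : (i + 1) * 2])
    / accumulation_steps
    for i in range(accumulation_steps)
)

assert torch.allclose(full_batch_loss, accumulated_loss)
```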

Author

Good point. I added a docstring to the function to explicitly mention this.

Contributor

Yes, CP also assumes mean. A docstring will be nice, thanks!

@janEbert (Author) commented Jun 4, 2025

PTAL.

@wwwjn (Contributor) commented Jun 4, 2025

> Looks almost good! Please address final comments.
>
> Also the addition of forward_backward_step breaks the FLUX model training. Could you help refactor the train_step to forward_backward_step over there? Probably just
>
>   1. remove the optimizer.zero_grad
>   2. remove https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L152-L180
>   3. return loss
>
> For the eval step (https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L182): it should be done in Trainer.train(), but since we are not using grad accumulation in FLUX training, it is OK to leave it in forward_backward_step to accelerate landing of this PR, as long as CI tests pass. @wwwjn and I will work together on fixing it later.

Agree! The current change on the FLUX side looks good to me. In the future I will also test grad accumulation with FLUX, and ideally I will move eval_step() into the main trainer's train loop and reuse the main trainer's train_step() in FLUX.

@fegin (Contributor) left a comment

LGTM. There are some typing nits, but overall the implementation is clean.


def forward_backward_step(
    self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
):
Contributor

Can we type the return value?


def train_step(
    self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
):
Contributor

ditto, can we type the return value?
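For illustration, the typed versions could look like this (a sketch; forward_backward_step returns the micro-batch loss tensor, and train_step returns nothing after the revert discussed above):

```python
def forward_backward_step(
    self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
    ...

def train_step(
    self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
) -> None:
    ...
```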

Labels: CLA Signed
Successfully merging this pull request may close these issues.

[Feature] Add gradient accumulation
5 participants