Support gradient accumulation #1238
base: main
Conversation
Maybe it would also make sense to rename
Thanks for the PR. I suggest that we don't let `train_step()` be aware of `data_iterator`. Please see the detailed comments.
Also, this PR doesn't change the parallelization, which is not correct. We will have to call `set_requires_gradient_sync` if FSDP is applied. We can raise an exception if DDP is used and `accumulation_steps > 1` for now.
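For illustration, deferring gradient sync during accumulation could look roughly like this (a sketch, not the PR's code; it assumes FSDP2's `FSDPModule.set_requires_gradient_sync(bool)`, and the helper name is made up):

```python
# Illustrative sketch only: with FSDP applied, communicate gradients only on the
# last microbatch of an accumulation window; DDP would need a different mechanism.
from torch.distributed._composable.fsdp import FSDPModule  # import path may differ by torch version


def accumulate_with_deferred_sync(model_parts, microbatches, forward_backward_step):
    num_microbatches = len(microbatches)
    for i, (input_dict, labels) in enumerate(microbatches):
        is_last = i == num_microbatches - 1
        for part in model_parts:
            if isinstance(part, FSDPModule):
                # Skip the gradient reduce-scatter until the final microbatch.
                part.set_requires_gradient_sync(is_last)
        forward_backward_step(input_dict, labels)
```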
torchtitan/train.py
Outdated
```python
unwrapped_loss_fn = self.loss_fn

@functools.wraps(unwrapped_loss_fn)
def accumulated_loss_fn(*args, **kwargs):
```
We should just modify `build_loss_fn` to take `accumulation_steps` and let the loss function decide the usage.
I'm OK either way. I think being more explicit about grad accumulation handling doesn't look bad. Also, if we go with an explicit `global_batch_size` and an implicit `grad_accu_steps`, then we'll need to do another check & computation in the loss function.
I moved the wrapping functionality to `torchtitan.components.loss` and called it `rescale_accumulated_loss`. Not quite what you wanted, but that way we can re-use the `Trainer.gradient_accumulation_steps` value more easily.
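Roughly along these lines (a simplified sketch; the exact signature and placement in the PR may differ):

```python
import functools


def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps: int):
    """Wrap a (mean-reducing) loss function so that gradients accumulated over
    `accumulation_steps` microbatches match those of one full global batch."""

    @functools.wraps(unwrapped_loss_fn)
    def accumulated_loss_fn(*args, **kwargs):
        loss = unwrapped_loss_fn(*args, **kwargs)
        return loss / accumulation_steps

    return accumulated_loss_fn
```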
Thank you for adding this feature!
I left several comments. Please see if they make sense.
torchtitan/config_manager.py
Outdated
```
@@ -192,6 +192,11 @@ class Training:
    batch_size: int = 8
```
yeah, let's call it `local_batch_size`
Did a rename across the codebase wherever `JobConfig.training.batch_size` or `--training.batch_size` was used. Not sure how you'd like me to handle the compatibility breakage that this introduces.
torchtitan/train.py
Outdated
```python
if job_config.training.global_batch_size < 0:
    job_config.training.global_batch_size = (
        job_config.training.batch_size * dp_degree
    )
assert job_config.training.global_batch_size > 0
assert (
    job_config.training.global_batch_size
    % (job_config.training.batch_size * dp_degree)
    == 0
), (
    f"global batch size must be multiple of local batch size times "
    f"data-parallel degree ({job_config.training.global_batch_size} "
    f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
)

self.gradient_accumulation_steps = job_config.training.global_batch_size // (
    job_config.training.batch_size * dp_degree
)
assert self.gradient_accumulation_steps > 0
```
nit comment
```diff
-if job_config.training.global_batch_size < 0:
-    job_config.training.global_batch_size = (
-        job_config.training.batch_size * dp_degree
-    )
-assert job_config.training.global_batch_size > 0
-assert (
-    job_config.training.global_batch_size
-    % (job_config.training.batch_size * dp_degree)
-    == 0
-), (
-    f"global batch size must be multiple of local batch size times "
-    f"data-parallel degree ({job_config.training.global_batch_size} "
-    f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
-)
-self.gradient_accumulation_steps = job_config.training.global_batch_size // (
-    job_config.training.batch_size * dp_degree
-)
-assert self.gradient_accumulation_steps > 0
+global_batch_size = job_config.training.global_batch_size
+if global_batch_size < 0:
+    global_batch_size = job_config.training.batch_size * dp_degree
+    self.gradient_accumulation_steps = 1
+else:
+    assert global_batch_size > (job_config.training.batch_size * dp_degree)
+    assert (
+        job_config.training.global_batch_size
+        % (job_config.training.batch_size * dp_degree)
+        == 0
+    ), (
+        f"global batch size must be multiple of local batch size times "
+        f"data-parallel degree ({global_batch_size} "
+        f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
+    )
+    self.gradient_accumulation_steps = global_batch_size // (
+        job_config.training.batch_size * dp_degree
+    )
```
I don't really agree with not re-using the code that would become the `else` case here, but I can still change it to your recommendation. For now, I put the addition of the `global_batch_size` variable into its own commit, which probably already has the readability improvements you'd like. I also added a comment in the `if` case noting that this global batch size results in 1 gradient accumulation step.
torchtitan/train.py
Outdated
```
@@ -183,6 +205,15 @@ def __init__(self, job_config: JobConfig):

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

        unwrapped_loss_fn = self.loss_fn
```
Let's put the `self.gradient_accumulation_steps` derivation code right before here, to group the gradient accumulation logic together as much as possible. I understand that it is desirable to fail early on an infeasible global batch size, even before parallelism and other heavy things are applied, but I'd suggest we prioritize readability. What do you think?
Sounds fair! :)
Moved this.
torchtitan/train.py
Outdated
```python
# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
```
can we call it `forward_backward_step`?
Done. By the way, if you'd prefer me to squash these changes into the previous commits, I'd be happy to clean up the commit chain.
torchtitan/train.py
Outdated
```python
model_parts = self.model_parts
world_mesh = self.world_mesh
```
similarly, maybe not worth keeping these two
Done.
torchtitan/components/metrics.py
Outdated
```
@@ -336,6 +337,7 @@ def __init__(
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.accumulated_losses = []
```
Since it represents a core training concept, rather than being directly used for metrics logging, let's put this in `Trainer` instead of `MetricsProcessor`.
Done. Also added the `gradient_accumulation_steps` attribute to the `Trainer`'s dataclass attributes.
torchtitan/train.py
Outdated
```python
except StopIteration:
    # If data runs out during gradient accumulation, that
    # entire step will not be executed.
    return True
```
Instead of an explicit `return True`, can we just call `next` and let the `StopIteration` exception propagate to `train_step` and catch it over there?
I initially had it implemented this way, but thought the `try` block would encapsulate too much code. If anything else raised a `StopIteration`, it would make debugging much more difficult. Hence the minimized `try` scope.
I would prefer to directly raise `StopIteration` and let the outer loop catch it. As mentioned in the above discussion, the original design is to keep `train_step()` simple, without data dependency, so there is no other `StopIteration` afaik. If other places actually raise `StopIteration`, we should figure that out.
If we really want to avoid ambiguity, we can have a customized `next()`, like a `next_batch()` that raises a customized `DataDepleteException()`.
That's considerate. I think it's quite unlikely other places would also raise `StopIteration`? Maybe microbatching in pipeline parallel? But over there the number of microbatches should be fixed ahead of time.
Anyway, if you think we need to deal with this explicitly, we should catch the `StopIteration` exception and raise a customized `DataloaderStopIteration` exception to be caught by the caller, instead of `return True`.
Went with a combination of these suggestions: a `Trainer.next_batch` method basically just calls `next(data_iterator)`, but catches and re-raises its `StopIteration` as a new `DataloaderStopIteration`.
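In sketch form (simplified; the real method in the PR may do more, e.g. device movement or metrics bookkeeping):

```python
class DataloaderStopIteration(StopIteration):
    """An exception that indicates dataloader exhaustion."""


def next_batch(self, data_iterator):
    """Sketch of Trainer.next_batch: fetch one batch, translating exhaustion into
    a dedicated exception so callers can tell it apart from any other StopIteration."""
    try:
        input_dict, labels = next(data_iterator)
    except StopIteration as ex:
        raise DataloaderStopIteration() from ex
    return input_dict, labels
```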
torchtitan/train.py
Outdated
```diff
 self.step += 1
 self.gc_handler.run(self.step)
-self.train_step(inputs, labels)
+data_ran_out = self.train_step(data_iterator)
```
we can catch the `StopIteration` here and do different treatment on `self.checkpointer.save` in try vs. catch.
See above.
This has been changed, but we now simply `break` in case of the `DataloaderStopIteration` to avoid changing the checkpointing logic.
This does change the general logic (e.g., `torch_profiler` and `memory_profiler` won't be `step`ped anymore) compared to the previous code, but it is a bit nicer to read than adding an extra variable check in the `while` condition, IMO.
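For context, a rough sketch of the resulting loop shape (simplified, not the exact PR code; attribute names are illustrative, and `DataloaderStopIteration` is the exception introduced above):

```python
import logging

logger = logging.getLogger(__name__)


def train(self):
    # Sketch of Trainer.train(): once data runs out mid-step, we break out of the
    # loop, so the profiler stepping and checkpointing below are skipped for that step.
    data_iterator = iter(self.dataloader)
    while self.step < self.job_config.training.steps:
        self.step += 1
        self.gc_handler.run(self.step)
        try:
            self.train_step(data_iterator)
        except DataloaderStopIteration:
            logger.warning("Ran out of data; last step was canceled.")
            break
        # profiler/memory-profiler stepping, checkpointing, metrics logging, ...
```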
> Also, this PR doesn't change the parallelization, which is not correct. We will have to call set_requires_gradient_sync if FSDP is applied.

@fegin For background please see #292 (comment). I think for us we don't want the potential memory overhead and code complexity, although it can save some communications which could've been hidden anyway.
torchtitan/train.py
Outdated
```python
def train_step(
    self,
    data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
) -> bool | None:
```
We should just return `bool` and change all other returns to `return False` to keep the semantics consistent. This only applies if we keep the return value as the design option, but I prefer try/catch; see the response below.
Reverted this and refactored to the try-catch solution as per the other discussions. The return type is back to an implicit `None`.
The review order looks pretty confusing, lol. The summary of some big discussions:

cc. @tianyu-l
Hey @janEbert, how about we work a bit more on the PR? Sorry for the confusion in the reviews. I think we have agreed on the direction:
Please also add a test case in https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py
I believe I have incorporated all the feedback. Let me know how you like the changes. FYI I'm currently at a conference and on vacation from Friday, so it would be great to get this done before Friday, even if I may only sporadically find time. :)
Previously `int | None`. Makes it possible to obtain the automatic calculation of it when it has already been set in a TOML config.
@fegin said:
> TorchTitan currently doesn't perform force checkpoint if data is depleted. We can fix this but I suggest that we don't do this in this PR.

(See pytorch#1238 (comment).)
I.e., a new `DataloaderStopIteration` that inherits from `StopIteration`. Accordingly, no longer return an optional `bool` to indicate depletion and adapt the remainder of the code to catch the new exception instead.
This concerns only renaming
- `--training.batch_size` to `--training.local_batch_size` and
- `job_config.training.batch_size` to `job_config.training.local_batch_size`.
I.e., the method in `Trainer`.
Instead use a new helper variable `global_batch_size` for all logic. Improves readability.
Improve readability.
These were only used in 1 or 2 locations each.
... from `MetricsProcessor`.
Rebased because of
Looks almost good! Please address the final comments.

Also, the addition of `forward_backward_step` breaks the FLUX model training. Could you help refactor `train_step` to `forward_backward_step` over there? Probably just:
- remove the `optimizer.zero_grad`
- remove https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L152-L180
- `return loss`

For the eval step https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L182: it should be done in `Trainer.train()`, but since we are not using grad accumulation in FLUX training, it is OK to leave it in `forward_backward_step` to accelerate landing of this PR, as long as CI tests pass. @wwwjn and I will work together on fixing it later.
tests/integration_tests.py
Outdated
```python
OverrideDefinitions(
    [
        [
            # Default local batch size = 8, and `ngpu=2`, so
```
Let's explicitly specify the local batch size as well, in case some future PR changes the default without changing the test here.
Done.
torchtitan/config_manager.py
Outdated
```diff
@@ -333,7 +338,7 @@ class Parallelism:
     pipeline_parallel_microbatch_size: int = 1
     """
     The size of each pipeline parallel microbatch (default 1).
-    This value is used to compute the total number of microbatches by dividing batch_size with
+    This value is used to compute the total number of microbatches by dividing local batch_size with
```
```diff
-    This value is used to compute the total number of microbatches by dividing local batch_size with
+    This value is used to compute the total number of microbatches by dividing local_batch_size with
```
Great catch! I didn't see the underscore on my dirty screen lol
torchtitan/train.py
Outdated
```python
class DataloaderStopIteration(StopIteration):
    """An exception that indicates dataloader exhaustion."""

    pass
```
Done.
torchtitan/train.py
Outdated
```python
try:
    self.train_step(data_iterator)
except DataloaderStopIteration:
    logger.info("Ran out of data; last step was canceled.")
```
```diff
-    logger.info("Ran out of data; last step was canceled.")
+    logger.warning("Ran out of data; last step was canceled.")
```
Done.
```python
# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
def next_batch(
```
This function sounds less necessary, especially when we already have `dataloader` and `batch_generator`. Given how short it is, it seems not too bad just running the try-catch in `train_step`?
To me, it makes the `train_step` look cleaner, and it was nice to have it re-usable for the FLUX refactor. Does that change your mind? :)
I was thinking of patching the data iterator's `__next__` method on the fly to ensure the `DataloaderStopIteration` is raised, but I didn't want to put in too much black magic. It would require modifying the `ParallelAwareDataloader.__iter__` method to apply the patch to the returned iterator. What do you think of that option?
I would suggest keeping the current implementation. Monkey patching is usually not a good idea. I also agree this function makes `train_step` cleaner.
As a future benefit, we may want to do data loader pipelining, which overlaps the `to("cuda")` transfer with the computation. This function gives us a good place to implement it.
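For illustration, such pipelining could later look roughly like this (a hypothetical extension, not part of this PR; it relies on pinned host memory for `non_blocking=True` to actually overlap):

```python
import torch


def next_batch(self, data_iterator, device: torch.device):
    """Hypothetical extension of Trainer.next_batch: start async host-to-device
    copies here so they can overlap with compute that runs before the tensors
    are consumed."""
    input_dict, labels = next(data_iterator)
    input_dict = {k: v.to(device, non_blocking=True) for k, v in input_dict.items()}
    labels = labels.to(device, non_blocking=True)
    return input_dict, labels
```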
```python
    job_config.training.local_batch_size * dp_degree
)
assert self.gradient_accumulation_steps > 0
self.loss_fn = rescale_accumulated_loss(
```
This is a comment, not a suggestion: the code sounds to me like it assumes the loss function must perform a "mean" reduction, instead of the "sum" reduction also available in e.g. cross-entropy loss. But I believe this assumption is also made in PyTorch DDP, FSDP, and PP, and is universally accepted as the default now, so I think it's OK.
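To make the assumption concrete, a small self-contained check (illustrative, not from the PR): with a mean-reduced loss, dividing each microbatch loss by the number of accumulation steps makes the accumulated gradient match the full-batch gradient; with a sum-reduced loss, no rescaling would be needed at all.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(8, 4)  # a "global batch" of 8 samples
y = torch.randn(8)


def mean_loss(w, x, y):
    return ((x @ w - y) ** 2).mean()


# Gradient of the mean loss over the full batch.
mean_loss(w, x, y).backward()
full_grad = w.grad.clone()

# Gradient accumulated over 2 microbatches, each rescaled by 1/accumulation_steps.
w.grad = None
accumulation_steps = 2
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (mean_loss(w, xb, yb) / accumulation_steps).backward()

print(torch.allclose(w.grad, full_grad))  # True (up to floating-point tolerance)
```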
Good point. I added a docstring to the function to explicitly mention this.
Yes, CP also assumes `mean`. A docstring will be nice, thanks!
... toward `forward_backward_step` design.
PTAL.
Agree! The current change on the FLUX side looks good to me. In the future I will also test grad accumulation w/ FLUX. Ideally in the future I will move
LGTM. There are some typing nits, but overall the implementation is clean.
```python
def forward_backward_step(
    self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
):
```
Can we type the return value?
```python
def train_step(
    self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
):
```
ditto, can we type the return value?
First, the batched backward calculation is refactored into its own function. Then, gradient accumulation is implemented by moving the data iterator inside the `train_step` method and consuming data from it as necessary. I added some extra handling for non-infinite data iterators, but if you dislike that additional complexity, I can remove it to simplify the code.

The feature is enabled by giving an additional `--training.global_batch_size`, which has a sensible default of 1 gradient accumulation step (i.e., no actual accumulation).

@tianyu-l thanks for the ping.
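As an illustration of the resulting arithmetic (made-up values; mirrors the derivation shown earlier in the thread):

```python
# Illustrative only: how the global batch size maps to gradient accumulation steps.
local_batch_size = 2    # --training.local_batch_size (per data-parallel rank)
dp_degree = 4           # data-parallel degree
global_batch_size = 16  # --training.global_batch_size

assert global_batch_size % (local_batch_size * dp_degree) == 0
gradient_accumulation_steps = global_batch_size // (local_batch_size * dp_degree)
print(gradient_accumulation_steps)  # 2

# Leaving --training.global_batch_size at its negative default is equivalent to
# global_batch_size = local_batch_size * dp_degree, i.e. 1 accumulation step.
```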