refactor ParallelDims and CheckpointManager #1384
```diff
@@ -26,8 +26,8 @@
 )
 from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.utils.data import DataLoader
 
+from torchtitan.components.dataloader import BaseDataLoader
 from torchtitan.components.ft import FTManager
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
```
```diff
@@ -180,17 +180,19 @@ class CheckpointManager:
 
     def __init__(
         self,
-        dataloader: DataLoader,
+        dataloader: BaseDataLoader | None,
         model_parts: list[nn.Module],
        optimizers: OptimizersContainer,
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
         job_config: JobConfig,
-        ft_manager: FTManager,
+        ft_manager: FTManager | None = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+        self.ft_manager = (
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
+        )
 
         if self.ft_manager:
             optimizers.init_cache_state_dict()
```

Review comment on the `dataloader: BaseDataLoader | None` change:

> @ebsmothers I think it won't look too bad when I specify

Reply:

> Thanks for the heads up, I think this is a reasonable compromise
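With `ft_manager` now optional, the guard in `__init__` must handle both "not passed" and "passed but disabled". A small standalone illustration of the same short-circuit pattern from the diff; `FakeFTManager` and `resolve_ft_manager` are hypothetical stand-ins for demonstration only:

```python
from dataclasses import dataclass, field
from typing import Any


# Hypothetical stand-in for torchtitan's FTManager, just to show the guard.
@dataclass
class FakeFTManager:
    enabled: bool
    manager: Any = field(default_factory=object)


def resolve_ft_manager(ft_manager: FakeFTManager | None) -> Any | None:
    # Mirrors the diff: short-circuit on None first, then check .enabled,
    # so callers that never construct fault tolerance can pass nothing.
    return ft_manager.manager if ft_manager and ft_manager.enabled else None


assert resolve_ft_manager(None) is None
assert resolve_ft_manager(FakeFTManager(enabled=False)) is None
assert resolve_ft_manager(FakeFTManager(enabled=True)) is not None
```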
Review comment:

> nit: We don't need this line here to "build world mesh" explicitly, right? In line 134, `parallel_dims.world_mesh["tp"]` will call `build_mesh` internally.

Reply:

> It will be used in line 136 to set determinism.
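The exchange above hinges on the mesh being built lazily: the first access to `world_mesh` triggers construction, so indexing like `parallel_dims.world_mesh["tp"]` works without an explicit build call, while code that needs the mesh eagerly (e.g. to set determinism) can still touch it first. A minimal sketch of that cached-property pattern under stated assumptions — `LazyParallelDims` and its `_build_mesh` body are hypothetical, not torchtitan's actual `ParallelDims`, and running it requires an initialized process group (e.g. under `torchrun`):

```python
from functools import cached_property

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


class LazyParallelDims:
    def __init__(self, dp: int, tp: int) -> None:
        self.dp = dp
        self.tp = tp

    @cached_property
    def world_mesh(self) -> DeviceMesh:
        # First access (e.g. parallel_dims.world_mesh["tp"]) builds the mesh;
        # subsequent accesses reuse the cached result.
        return self._build_mesh()

    def _build_mesh(self) -> DeviceMesh:
        # Assumes dp * tp ranks are available in the current job.
        return init_device_mesh(
            "cpu", (self.dp, self.tp), mesh_dim_names=("dp", "tp")
        )
```

Under this pattern, deleting an explicit `build_mesh()` call is safe for correctness, but keeping one can still be deliberate when the built mesh is consumed immediately afterwards, as the reply points out for the determinism setup.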