refactor ParallelDims and CheckpointManager #1384
Conversation
cc @ebsmothers
@@ -180,17 +180,19 @@ class CheckpointManager:
     def __init__(
         self,
-        dataloader: DataLoader,
+        dataloader: BaseDataLoader | None,
@ebsmothers
I kept this field as required but made its value optional -- the code still works. I didn't make it completely optional with a `None` default, because that would require more if-else in this file. I think it won't look too bad when I specify `dataloader=None` in forge_engine.py. Let me know if that's OK with you.
Thanks for the heads up, I think this is a reasonable compromise
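For concreteness, a minimal sketch of the compromise described above, assuming a simplified `CheckpointManager`: the `dataloader` argument stays required, but its value may be `None`, so only a single guard is needed and callers such as forge_engine.py pass `dataloader=None` explicitly. Names here are illustrative, not the actual implementation.

```python
class CheckpointManager:
    def __init__(
        self,
        dataloader: "BaseDataLoader | None",
        # ... other arguments elided in this sketch ...
    ) -> None:
        self.states: dict[str, object] = {}
        # Register the dataloader for checkpointing only when one is provided,
        # avoiding extra if-else branches elsewhere in the class.
        if dataloader is not None:
            self.states["dataloader"] = dataloader


# A caller without a dataloader spells the absence out explicitly:
# checkpointer = CheckpointManager(dataloader=None, ...)
```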
Force-pushed from 668ee1e to 0b4cad7.
LGTM!
@@ -54,7 +55,7 @@ def parallelize_deepseekv3(
     apply_non_moe_tp(
Need to add the same `job_config.training.seq_len % parallel_dims.seq_len_divisor == 0` check for TP here. You could add it, or I could add it in the next PR.
OK, I can add them.
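For reference, a hedged sketch of the divisibility check being requested, wrapped as a small helper so it reads standalone; the helper name and error message are illustrative, not existing code.

```python
def validate_seq_len(seq_len: int, seq_len_divisor: int) -> None:
    # The sequence length must divide evenly across the sharded sequence dim.
    if seq_len % seq_len_divisor != 0:
        raise ValueError(
            f"seq_len ({seq_len}) must be divisible by "
            f"seq_len_divisor ({seq_len_divisor})"
        )


# e.g. in parallelize_deepseekv3, before apply_non_moe_tp when TP is enabled:
# validate_seq_len(job_config.training.seq_len, parallel_dims.seq_len_divisor)
```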
     )
     # Build world mesh for parallelism
-    world_mesh = parallel_dims.build_mesh(device_type=device_type)
+    world_mesh = parallel_dims.world_mesh
nit: We don't need this line here to build the world mesh explicitly, right? On line 134, `parallel_dims.world_mesh["tp"]` will call `build_mesh` internally.
It will be used on line 136 to set determinism.
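To make the lazy behavior concrete, a minimal sketch of a cached `world_mesh` property, assuming `build_mesh` only runs on first access; the caching attribute and the `device_type` default are assumptions, not the actual code.

```python
from torch.distributed.device_mesh import DeviceMesh


class ParallelDims:
    _world_mesh: "DeviceMesh | None" = None

    @property
    def world_mesh(self) -> DeviceMesh:
        if self._world_mesh is None:
            # Built on first access (e.g. parallel_dims.world_mesh["tp"]), so
            # an explicit "build world mesh" call at the call site is optional.
            self._world_mesh = self.build_mesh(device_type="cuda")
        return self._world_mesh

    def build_mesh(self, device_type: str) -> DeviceMesh:
        # Actual mesh construction elided in this sketch.
        raise NotImplementedError
```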
@@ -457,7 +454,9 @@ def train_step(
     [p for m in self.model_parts for p in m.parameters()],
     self.job_config.training.max_norm,
     foreach=True,
-    pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
+    pp_mesh=(
+        parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
nit: it's getting a bit inconsistent whether we pass `parallel_dims` or a mesh object. I think it is not a big deal, though.
That's a good catch!
Passing in only `ParallelDims` seems sufficient, but I'm concerned it would break BC, since some users use this function as a standalone util -- so I think we should keep it not taking `ParallelDims`.
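A rough sketch of the BC point: if the util keeps accepting an optional mesh instead of a `ParallelDims`, standalone users are unaffected and the trainer resolves the mesh itself. The signature below is illustrative, not the exact API.

```python
from typing import Iterable

import torch
from torch.distributed.device_mesh import DeviceMesh


def clip_grad_norm_(
    parameters: Iterable[torch.nn.Parameter],
    max_norm: float,
    foreach: bool = True,
    pp_mesh: "DeviceMesh | None" = None,
) -> torch.Tensor:
    # With pipeline parallelism, gradient norms must be reduced across the pp
    # mesh before clipping; without it this behaves like the vanilla util.
    ...


# Trainer-side call: the trainer resolves the mesh from ParallelDims itself,
# so the util never has to depend on ParallelDims.
# clip_grad_norm_(
#     params,
#     job_config.training.max_norm,
#     foreach=True,
#     pp_mesh=parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None,
# )
```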
This PR does the following (a rough sketch of the resulting interface follows the list):
- moves `world_mesh` into `ParallelDims`, as they have a close relationship
- moves `enable_loss_parallel` out of the `ParallelDims` constructor
- moves `seq_len_divisor` to `ParallelDims`
- makes `dataloader` and `ft_manager` optional in `CheckpointManager`
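Putting the bullets together, a hedged sketch of what the refactored `ParallelDims` surface might look like; the field names, the `seq_len_divisor` formula, and the mesh layout shown here are assumptions inferred from the diff snippets in this conversation, not the exact implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, field

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    tp: int
    pp: int
    world_size: int
    # enable_loss_parallel is no longer a constructor argument; callers decide
    # loss-parallel behavior themselves.
    _world_mesh: DeviceMesh | None = field(default=None, init=False, repr=False)

    @property
    def seq_len_divisor(self) -> int:
        # Assumption: the sequence length must be divisible by the TP degree,
        # since TP/sequence parallel shards along the sequence dimension.
        return self.tp

    @property
    def world_mesh(self) -> DeviceMesh:
        # Built lazily on first access, as in the earlier sketch.
        if self._world_mesh is None:
            self._world_mesh = init_device_mesh(
                "cuda",
                (self.pp, self.dp_replicate * self.dp_shard, self.tp),
                mesh_dim_names=("pp", "dp", "tp"),
            )
        return self._world_mesh
```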