refactor ParallelDims and CheckpointManager #1384


Merged: 1 commit into main on Jul 14, 2025

Conversation

tianyu-l (Contributor) commented on Jul 12, 2025:

This PR does the following:

  1. move world_mesh into ParallelDims, as the two have a close relationship
  2. move enable_loss_parallel out of the ParallelDims constructor
  3. add a convenience property seq_len_divisor to ParallelDims
  4. make dataloader and ft_manager optional in CheckpointManager
  5. some minor improvements to typing and code organization

(A sketch of items 1 and 3 follows below.)
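To make items 1 and 3 concrete, here is a minimal sketch, not the actual torchtitan code: the field names, mesh dimension names, and the divisor formula are assumptions. The idea is that ParallelDims owns a lazily built world_mesh and exposes the sequence-length constraint as a property.

```python
from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


@dataclass
class ParallelDims:
    dp_shard: int = 1
    cp: int = 1
    tp: int = 1
    pp: int = 1

    @cached_property
    def world_mesh(self) -> DeviceMesh:
        # Built on first access and cached, so callers no longer invoke
        # build_mesh() explicitly (cf. the review thread below).
        return init_device_mesh(
            "cuda",
            (self.pp, self.dp_shard, self.cp, self.tp),
            mesh_dim_names=("pp", "dp_shard", "cp", "tp"),
        )

    @property
    def seq_len_divisor(self) -> int:
        # Assumed formula: Sequence Parallel needs seq_len divisible by the
        # TP degree; Context Parallel with load balancing needs seq_len
        # divisible by 2 * CP degree.
        return self.tp * (self.cp * 2)
```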

@facebook-github-bot added the CLA Signed label on Jul 12, 2025
tianyu-l (Contributor, Author) commented:

cc @ebsmothers

```diff
@@ -180,17 +180,19 @@ class CheckpointManager:

     def __init__(
         self,
-        dataloader: DataLoader,
+        dataloader: BaseDataLoader | None,
```

tianyu-l (Contributor, Author) commented:

@ebsmothers I kept this argument required and made its value optional -- the code still works. I didn't make it completely optional with a None default, because that would require more if-else in this file.

I think it won't look too bad when I specify dataloader=None in forge_engine.py. Let me know if that's OK with you.
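A minimal sketch of the pattern being described (illustrative only; the real CheckpointManager constructor takes more arguments than shown): the parameter stays required, so call sites must state their intent explicitly, but None is a legal value.

```python
from typing import Optional


class CheckpointManager:
    # Sketch only -- the real constructor takes more arguments.
    def __init__(
        self,
        dataloader: Optional[object],  # required argument, optional value
        ft_manager: Optional[object] = None,
    ) -> None:
        self.dataloader = dataloader
        self.ft_manager = ft_manager


# A call site with no dataloader must say so explicitly, e.g. the
# forge_engine.py case mentioned above:
ckpt = CheckpointManager(dataloader=None)
```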

ebsmothers (Contributor) replied:

Thanks for the heads up, I think this is a reasonable compromise.

@tianyu-l force-pushed the refactor branch 2 times, most recently from 668ee1e to 0b4cad7 on July 13, 2025 02:56
@tianyu-l requested a review from ebsmothers on July 13, 2025 07:21
wwwjn (Contributor) left a comment:

LGTM!

```diff
@@ -54,7 +55,7 @@ def parallelize_deepseekv3(
     apply_non_moe_tp(
```

wwwjn (Contributor) commented:

We need to add the same check, job_config.training.seq_len % parallel_dims.seq_len_divisor, for TP here. You could add it, or I could add it in the next PR.

tianyu-l (Contributor, Author) replied:

OK, I can add them.
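The check in question would look roughly like this: an illustrative sketch with an assumed helper name, modeled on the expression quoted above.

```python
def validate_seq_len(seq_len: int, seq_len_divisor: int) -> None:
    # Mirrors job_config.training.seq_len % parallel_dims.seq_len_divisor:
    # TP (and CP, when enabled) require the sequence length to divide evenly.
    if seq_len % seq_len_divisor != 0:
        raise ValueError(
            f"seq_len={seq_len} must be divisible by {seq_len_divisor}"
        )


validate_seq_len(4096, 8)  # e.g. tp=4, cp=1 -> divisor = 4 * (1 * 2) = 8
```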

```diff
     )
     # Build world mesh for parallelism
-    world_mesh = parallel_dims.build_mesh(device_type=device_type)
+    world_mesh = parallel_dims.world_mesh
```

wwwjn (Contributor) commented:

nit: We don't need this line here to "build world mesh" explicitly, right? In line 134, parallel_dims.world_mesh["tp"] will call build_mesh internally.

tianyu-l (Contributor, Author) replied:

It will be used in line 136 to set determinism.
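In other words, keeping the explicit access costs nothing. Using the ParallelDims sketch from earlier (assumed API), the first access builds and caches the mesh, and later indexing reuses it:

```python
# Continuing the earlier ParallelDims sketch (actually running this
# requires an initialized distributed process group):
pd = ParallelDims(dp_shard=2, cp=1, tp=4, pp=1)
world_mesh = pd.world_mesh  # first access builds the mesh; kept here
                            # for the determinism setup mentioned above
tp_mesh = world_mesh["tp"]  # reuses the cached mesh; no second build
```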

```diff
@@ -457,7 +454,9 @@ def train_step(
             [p for m in self.model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
             foreach=True,
-            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
+            pp_mesh=(
+                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
+            ),
```

wwwjn (Contributor) commented:

nit: We're getting a bit inconsistent about whether to pass parallel_dims or a mesh object. I think it's not a big deal though.

tianyu-l (Contributor, Author) replied:

That's a good catch!
Passing in just ParallelDims seems sufficient, but I'm concerned it would break BC, as some users use this function as a standalone util -- so I think we can change it to not pass in ParallelDims.
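A sketch of the BC-friendly shape being argued for (assumed names and a simplified body; the real utility also reduces the norm across the pipeline mesh): the function keeps taking an optional pp_mesh rather than a ParallelDims, so standalone users are unaffected.

```python
from typing import Iterable, Optional

import torch
from torch.distributed.device_mesh import DeviceMesh


def clip_grad_norm_(
    parameters: Iterable[torch.Tensor],
    max_norm: float,
    foreach: bool = True,
    pp_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    # Taking an optional mesh (not ParallelDims) preserves BC for users
    # who call this as a standalone util.
    total_norm = torch.nn.utils.clip_grad_norm_(
        parameters, max_norm, foreach=foreach
    )
    if pp_mesh is not None:
        # With pipeline parallelism, the norm must also be reduced across
        # pipeline stages; elided in this sketch.
        pass
    return total_norm
```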

@tianyu-l merged commit 6204cdf into main on Jul 14, 2025
10 checks passed
@tianyu-l deleted the refactor branch on July 14, 2025 20:43