
[SimpleFSDP] Add CI for SimpleFSDP #1231


Merged · merged 1 commit into main from ruisi/simplefsdp_ci on May 30, 2025

Conversation

@ruisizhang123 (Collaborator) commented May 28, 2025

This PR adds CI for SimpleFSDP. It mainly adopts the FSDP-related integration tests from this file: https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py

A few things to list here:

  • Currently, DDP + TP and HSDP + TP (SimpleFSDP's replicate and hybrid_shard modes) are not supported. I'll add support for them in follow-up PRs.

  • DCP + SimpleFSDP fails the test_generate test on my end. The output indicates that some of the model weights are not correctly checkpointed when resharding from [dp:4] -> [tp:4]; [dp:4] -> [dp:2, tp:2] works fine.

[rank2]:   File "/home/ruisi/pytorch/torch/distributed/checkpoint/default_planner.py", line 471, in create_default_local_load_plan
[rank2]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank2]: RuntimeError: Missing key in checkpoint state_dict: model.tok_embeddings.weight.
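
The traceback comes from DCP's default load planner, which requires every FQN in the model's state_dict to be present in the checkpoint. A minimal illustrative sketch of that check (not the actual torch source; check_keys is a hypothetical helper):

def check_keys(model_keys, checkpoint_keys):
    # Mirror of the check in create_default_local_load_plan: every key the
    # model expects must exist in the checkpoint metadata.
    for fqn in model_keys:
        if fqn not in checkpoint_keys:
            raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")

# The [dp:4] checkpoint stores the parametrized key (see the FQN discussion
# below), while the [tp:4] load expects the plain key, so the check fails:
check_keys(
    {"model.tok_embeddings.weight"},
    {"model.tok_embeddings.parametrizations.weight.original"},
)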

@facebook-github-bot added the CLA Signed label on May 28, 2025
@ruisizhang123 force-pushed the ruisi/simplefsdp_ci branch 3 times, most recently from f9581c8 to 31a3ab5 on May 28, 2025 08:24
@ruisizhang123 requested a review from tianyu-l on May 28, 2025 16:31
@ruisizhang123 force-pushed the ruisi/simplefsdp_ci branch from 31a3ab5 to d6d3f0a on May 29, 2025 06:40
@tianyu-l (Contributor) left a comment

I had some suggestions on truncating the test set: some of the tests I believe are not appropriate to run, and others can be cut to relieve the CI burden.

For checkpointing, we marked it as a WIP feature. I'm a bit surprised it can run -- can we verify that the behavior is expected?

# "cpu_offload+opt_in_bwd+TP+DP+CP",
# ngpu=8,
# ),
OverrideDefinitions(
Contributor left a comment

This is for testing the generate script in https://github.com/pytorch/torchtitan/tree/main/scripts/generate. Let's remove it.

@ruisizhang123 (Collaborator, Author) May 29, 2025

I think it's needed to verify whether DCP can correctly load the checkpoint at this line:

begin = time.monotonic()
logger.info(f"Loading chkpt at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.")

Currently, this line does not correctly load weights from saved checkpoints. I have commented out the entry for this test as a TODO from L283. Maybe I should comment out this test and add a TODO as well.
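
As a sanity check of this loading path, here is a minimal single-process sketch (an illustration, not the actual test; it assumes a torch version where DCP supports non-distributed save/load, and the /tmp path is hypothetical) that round-trips a state_dict through dcp.save and dcp.load:

import torch
import torch.distributed.checkpoint as dcp
from torch import nn

model = nn.Linear(4, 4)
dcp.save(model.state_dict(), checkpoint_id="/tmp/dcp_ckpt")  # hypothetical path

# dcp.load fills the passed-in state_dict in place.
restored = nn.Linear(4, 4)
state_dict = restored.state_dict()
dcp.load(state_dict, checkpoint_id="/tmp/dcp_ckpt")
restored.load_state_dict(state_dict)
assert torch.equal(restored.weight, model.weight)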

Contributor left a comment

Let's add a test for DCP resharding, similar to https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py#L478. If we can reproduce the error, we should remove this test, since it exists to test the generate script, not (only) DCP loading.

@ruisizhang123 (Collaborator, Author) May 29, 2025

I did a quick verification for DCP. It turns out resharding from [dp:4] --> [dp:2, tp:2] still works properly. The loss is converging.

I took a closer look at test_generate. The error comes from resharding [dp:4] --> [tp:4] when the parallel dims are initialized here:

dist_utils.init_distributed(config)
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=-1,
    cp=1,
    tp=world_size,
    pp=1,
    world_size=world_size,
    enable_loss_parallel=False,
)
I'll add another config under the optional_checkpoint test for this.
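
A hedged sketch of what such an entry might look like, following the OverrideDefinitions pattern used elsewhere in the test file (the exact flag names and the test's final shape are assumptions, not the final config):

OverrideDefinitions(
    [
        # First run: train and save a checkpoint under [dp:4].
        [
            "--checkpoint.enable_checkpoint",
        ],
        # Second run: reload the same checkpoint under [tp:4] to exercise the
        # [dp:4] -> [tp:4] resharding path that currently fails.
        [
            "--checkpoint.enable_checkpoint",
            "--parallelism.tensor_parallel_degree 4",
        ],
    ],
    "optional checkpoint + [dp:4] -> [tp:4] resharding",
    "optional_checkpoint_resharding",
    ngpu=4,
),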

@ruisizhang123 (Collaborator, Author) commented May 29, 2025

For checkpointing, we marked it as a WIP feature. I'm a bit surprised it can run -- can we verify that the behavior is expected?

For DCP, it throws an error when loading the weight dicts in the test_generate test, as I mentioned in the PR description. Some of the weights are not properly checkpointed. We probably need some further effort on the DCP integration.

@ruisizhang123 force-pushed the ruisi/simplefsdp_ci branch 3 times, most recently from cacbd5b to 3316395 on May 29, 2025 18:07
@ruisizhang123 force-pushed the ruisi/simplefsdp_ci branch from 3316395 to 9200036 on May 29, 2025 19:07
@ruisizhang123 requested review from fegin and tianyu-l on May 29, 2025 20:18
@fegin (Contributor) commented May 30, 2025

Does SimpleFSDP change FQNs?

ruisizhang123 added a commit that referenced this pull request on May 30, 2025:
As titled, this PR fixes the incorrect test name in CI testing per this previous discussion: #1231 (comment)
@ruisizhang123 (Collaborator, Author)

Does SimpleFSDP change FQNs?

I think SimpleFSDP only parametrizes the parameters but doesn't explicitly change FQNs. (@tianyu-l can confirm this)

However, I found that the model weight key is renamed to model.tok_embeddings.parametrizations.weight.original after the model is wrapped by SimpleFSDP. When the weight is loaded by [tp:4], SimpleFSDP is not applied, so the expected key is the plain model.tok_embeddings.weight.

This should be the cause of the [dp:4] -> [tp:4] resharding error.
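
A minimal sketch in plain PyTorch (no SimpleFSDP; the Identity parametrization is only illustrative) showing how registering a parametrization renames the state_dict key in exactly this way:

import torch.nn as nn
from torch.nn.utils import parametrize

class Identity(nn.Module):
    def forward(self, x):
        return x

emb = nn.Embedding(8, 4)
print(list(emb.state_dict().keys()))
# ['weight']

parametrize.register_parametrization(emb, "weight", Identity())
print(list(emb.state_dict().keys()))
# ['parametrizations.weight.original']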

@tianyu-l (Contributor) left a comment

LGTM!

However, I found that the model weight key is renamed to model.tok_embeddings.parametrizations.weight.original after the model is wrapped by SimpleFSDP. When the weight is loaded by [tp:4], SimpleFSDP is not applied, so the expected key is the plain model.tok_embeddings.weight.

Yes, general DCP compatibility is broken by parametrization. IIRC @fmassa has a fix/improvement to parametrization so it does not interfere with parameter names for DCP. Let's follow up on the fix later.
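
Until that fix lands, a hypothetical workaround sketch (this is not the fix referenced above; deparametrize_keys is an invented helper) that maps parametrized keys back to their original FQNs before saving:

def deparametrize_keys(state_dict):
    # "m.parametrizations.weight.original" -> "m.weight"; other keys unchanged.
    out = {}
    for key, value in state_dict.items():
        if ".parametrizations." in key and key.endswith(".original"):
            prefix, rest = key.split(".parametrizations.", 1)
            tensor_name = rest[: -len(".original")]
            out[f"{prefix}.{tensor_name}"] = value
        else:
            out[key] = value
    return out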

@tianyu-l merged commit 0265e8d into main on May 30, 2025
7 checks passed
@tianyu-l deleted the ruisi/simplefsdp_ci branch on May 30, 2025 04:06