[SimpleFSDP] Add CI for SimpleFSDP #1231
Conversation
(force-pushed f9581c8 to 31a3ab5, then to d6d3f0a)
I had some suggestions on truncating the test set: some of the tests I believe are not appropriate to run here, and others can be dropped to relieve CI burden.
For checkpointing, we marked it as a WIP feature. I'm a bit surprised it can run -- can we verify the behavior is expected?
#     "cpu_offload+opt_in_bwd+TP+DP+CP",
#     ngpu=8,
# ),
OverrideDefinitions(
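For context, an integration-test entry looks roughly like the following. This is a sketch with a stand-in dataclass; the real `OverrideDefinitions` lives in `tests/integration_tests.py`, and the field names and override flag here are assumptions, not copied from the PR.

```python
# Stand-in sketch of an OverrideDefinitions entry; field names and the
# override flag below are assumptions, not torchtitan's actual code.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class OverrideDefinitions:
    override_args: Sequence[Sequence[str]]  # CLI overrides, one list per sub-run
    test_descr: str = "default"
    test_name: str = "default"
    ngpu: int = 4  # GPUs requested for the test


# A re-enabled entry like the commented-out one above might read:
entry = OverrideDefinitions(
    [["--parallelism.tensor_parallel_degree", "2"]],  # hypothetical override
    "cpu_offload+opt_in_bwd+TP+DP+CP",
    "cpu_offload_opt_in_bwd_tp_dp_cp",
    ngpu=8,
)
print(entry.ngpu)  # 8
```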
This is for testing the generate script in https://github.com/pytorch/torchtitan/tree/main/scripts/generate
Let's remove it.
I think it's needed to verify that DCP can correctly load the checkpoint at this line:
torchtitan/scripts/generate/test_generate.py
Lines 147 to 150 in ed2bbc0
begin = time.monotonic()
logger.info(f"Loading chkpt at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.")
Currently, this line does not correctly load weights from saved checkpoints. I have commented out the entry for this test as a TODO from L283. Maybe I should comment out and add a TODO for this test as well.
Let's add a test for DCP resharding, similar to https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py#L478. If we can reproduce the error there, we should remove this test, since it is meant to test the generate script, not (only) DCP loading.
I did a quick verification for DCP. It turns out resharding from [dp:4] -> [dp:2, tp:2] still works properly, and the loss converges.
I took a closer look at test_generate. The error comes from resharding [dp:4] -> [tp:4] when initializing the parallel dims here:
torchtitan/scripts/generate/test_generate.py
Lines 121 to 130 in ed2bbc0
dist_utils.init_distributed(config)
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=-1,
    cp=1,
    tp=world_size,
    pp=1,
    world_size=world_size,
    enable_loss_parallel=False,
)
For DCP, it will throw an error when loading the weight dicts in the optional_checkpoint test.
(force-pushed cacbd5b to 3316395, then to 9200036)
Does SimpleFSDP change FQNs?
As titled, this PR fixes the incorrect test name in CI testing per this previous discussion: #1231 (comment)
I think SimpleFSDP only parametrizes the parameters but doesn't explicitly change FQNs (@tianyu-l can confirm this). However, I found the model weight key is renamed to model.tok_embeddings.parametrizations.weight.original after being wrapped by SimpleFSDP. This should be the reason causing the [dp:4] -> [tp:4] resharding error.
LGTM!
However, I found the model weight key is renamed to model.tok_embeddings.parametrizations.weight.original after being wrapped by SimpleFSDP. When the weight is loaded by [tp:4], SimpleFSDP is not called here and the weight is the normal model.tok_embeddings.weight.
Yes, general DCP compatibility is broken by parametrization. IIRC @fmassa has a fix/improvement to parametrization so it does not interfere with parameter names for DCP. Let's follow up on the fix later.
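The renaming can be reproduced with plain torch parametrization, independent of SimpleFSDP. A minimal sketch (the `Identity` parametrization is just an illustration, not what SimpleFSDP registers):

```python
# Minimal repro of the key renaming: registering any parametrization moves
# the original tensor under `parametrizations.<name>.original`, which is
# why checkpoint FQNs stop matching a non-parametrized model.
import torch.nn as nn
from torch.nn.utils import parametrize


class Identity(nn.Module):
    def forward(self, x):
        return x


emb = nn.Embedding(8, 4)
print(sorted(emb.state_dict().keys()))  # ['weight']

parametrize.register_parametrization(emb, "weight", Identity())
print(sorted(emb.state_dict().keys()))  # ['parametrizations.weight.original']
```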
This PR adds CI for SimpleFSDP. It mainly adapts the FSDP-related integration tests from this file: https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py
A few things to note:
- Currently, DDP/HSDP + TP (SimpleFSDP's `replicate` and `hybrid_shard` modes) is not supported. I'll add support for them in follow-up PRs.
- DCP + SimpleFSDP failed in the `test_generate` test on my end. The output seems to indicate some of the model weights are not correctly checkpointed when resharding from [dp:4] -> [tp:4]. [dp:4] -> [dp:2, tp:2] works out.