
[WIP][RFC] Always flatten model state_dict #1347


Merged
fegin merged 5 commits into main from fegin/flatten_checkpoint on Jul 3, 2025

Conversation

@fegin fegin (Contributor) commented Jun 26, 2025

The model state_dict is unique compared to other state dictionaries (e.g., optimizer). It's the only one that is exported outside of TorchTitan and imported from other sources. To ensure FQN consistency, we previously removed the prefix during the first checkpoint load and the last checkpoint save. However, this approach has caused confusion among users, despite the available options to control the behavior.

This PR aims to resolve the issue by always flattening the model state dictionary, eliminating the "MODEL." prefix from its keys. We decided not to flatten all components due to the risk of key collisions between different components. Instead, this PR only flattens the model state_dict, which is a special case.

While this solution isn't perfect, as it introduces different handling for different components, it's a good compromise given the unique nature of the model state_dict.

Also see the discussion in #1321 (comment)

This is the pseudocode for the current state:

```
if model_only:
    state_dict = model.state_dict()
else:
    state_dict = {
        "MODEL": model,
        "OPTIMIZER": optimizer,
        ...
    }
```

This is the pseudocode after this PR lands:

```
state_dict = model.state_dict()
if not model_only:
    state_dict.update(
        {
            "OPTIMIZER": optimizer,
            ...
        }
    )
```
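
To make the effect on checkpoint keys concrete, here is a small, self-contained sketch of how the exported keys change. Plain dicts stand in for the real module and optimizer state; this is illustrative only, not TorchTitan's actual code.

```
# Illustrative only: plain dicts stand in for module / optimizer state_dicts.
model_sd = {"tok_embeddings.weight": 0, "layers.0.attention.wq.weight": 1}
optim_sd = {"state": {}, "param_groups": []}

# Before this PR: the model state_dict is wrapped under a "MODEL" key, so a
# full (non model_only) checkpoint exposes model weights as "MODEL.<fqn>".
nested = {"MODEL": model_sd, "OPTIMIZER": optim_sd}

# After this PR: model keys sit at the top level and only non-model
# components keep a wrapper key, so model FQNs are identical between
# model_only and full checkpoints.
flat = dict(model_sd)
flat.update({"OPTIMIZER": optim_sd})

assert "tok_embeddings.weight" in flat
assert "MODEL" not in flat
```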

FSDP 4 vs. FSDP 4 TP 2 loss curves with a seed checkpoint and --training.seed=42

[Screenshot: FSDP 4 vs. FSDP 4 TP 2 loss curve comparison]

@fegin fegin requested review from tianyu-l and wwwjn as code owners June 26, 2025 18:07
@facebook-github-bot facebook-github-bot added the CLA Signed label (this label is managed by the Meta Open Source bot) Jun 26, 2025
@tianyu-l tianyu-l (Contributor) left a comment:

Could you help verify that "save seed checkpoint -> load from seed checkpoint with different parallelism" still yields identical loss, as an additional integration test?

```
@@ -573,8 +590,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            return sd
+            return self.states[MODEL].state_dict()
```
Contributor:

Does this mean "model" still exists in state_dict as a key -- i.e., we only flatten it in the checkpoint (and during its load and save)?

@fegin (Contributor Author):

In Checkpointer, we still keep separate keys in self.states, like MODEL and OPTIMIZER. This allows us to manipulate the different state_dicts. This line uses MODEL to access only the model state_dict, but it does not wrap the model state_dict, so there will be no `model.` prefix.
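
For readers following the thread, here is a minimal sketch of the structure being described. The names (`CheckpointerSketch`, the MODEL/OPTIMIZER constants, the simplified non-model_only branch) are illustrative assumptions, not the exact TorchTitan implementation.

```
from typing import Any

MODEL = "model"
OPTIMIZER = "optimizer"

class CheckpointerSketch:
    """Illustrative stand-in: components stay under separate keys in
    self.states, but the model state_dict itself is never re-wrapped."""

    def __init__(self, model: Any, optimizer: Any) -> None:
        # Components are kept under separate keys so they can be
        # manipulated individually.
        self.states = {MODEL: model, OPTIMIZER: optimizer}

    def _states_to_load(self, model_only: bool) -> dict[str, Any]:
        if model_only:
            # Access via the MODEL key, but return the raw state_dict,
            # so loaded keys carry no "model." prefix.
            return self.states[MODEL].state_dict()
        # Simplified: the real code does more filtering here.
        return self.states
```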

@fegin fegin requested a review from wconstab as a code owner June 27, 2025 22:21
@fegin fegin (Contributor Author) commented Jul 1, 2025

@tianyu-l The loss curves don't match, with or without freqs_cis in the seed checkpoint. The seed checkpoint may have been broken before freqs_cis was removed from the checkpoint.

@tianyu-l tianyu-l (Contributor) left a comment:

In terms of loss curve matching:

local_batch_size is 2 when TP is 2; otherwise local_batch_size is 1.

To make sure the dataloader behaves consistently across multiple runs, we need to fix the DP degree (dp_replicate * dp_shard).
So a typical comparison you may have (keeping the overall DP degree at 4) could be

  • FSDP 4
  • DP 2, FSDP 2, TP 2
  • DP 2, FSDP 2, CP 2, PP 2

For details and more examples, please see
https://github.com/pytorch/torchtitan/blob/main/docs/converging.md
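
As a tiny illustration of the bookkeeping above (a hypothetical helper, not part of TorchTitan), the quantity to hold fixed across compared runs is the overall data-parallel degree:

```
def dp_degree(dp_replicate: int, dp_shard: int) -> int:
    """Overall data-parallel degree, i.e. the number of distinct data shards.

    Keep this fixed across the runs being compared so the dataloader
    behaves identically; TP/CP/PP degrees do not change it.
    """
    return dp_replicate * dp_shard

# The comparisons listed above all keep the overall DP degree at 4:
assert dp_degree(dp_replicate=1, dp_shard=4) == 4  # FSDP 4
assert dp_degree(dp_replicate=2, dp_shard=2) == 4  # DP 2, FSDP 2 (+ TP 2, or + CP 2, PP 2)
```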

@wwwjn we probably should include this in #1363

@fegin fegin force-pushed the fegin/flatten_checkpoint branch from 4050014 to f0d0dd7 Compare July 2, 2025 05:52
@fegin fegin (Contributor Author) commented Jul 2, 2025

@tianyu-l I had also tried FSDP 4 vs. FSDP 4 TP 2 before trying this setting, and the loss curves didn't match.

@fegin fegin (Contributor Author) commented Jul 2, 2025


To correct my earlier statement: FSDP 4 vs. FSDP 4 TP 2 do match, with or without fixing `training.seed`.

@tianyu-l tianyu-l (Contributor) left a comment:

SGTM!

@fegin fegin merged commit 7d5f3cc into main Jul 3, 2025
7 checks passed
@tianyu-l tianyu-l deleted the fegin/flatten_checkpoint branch July 5, 2025 00:18
H-Huang pushed a commit to H-Huang/torchtitan that referenced this pull request Jul 8, 2025
H-Huang pushed a commit to H-Huang/torchtitan that referenced this pull request Jul 8, 2025
mori360 pushed a commit to mori360/torchtitan that referenced this pull request Jul 8, 2025