[WIP][RFC] Always flatten model state_dict (#1347)
The model state_dict is unique compared to other state dictionaries
(e.g., optimizer). It's the only one that will be exported outside of
TorchTitan and imported from other sources. To ensure FQN consistency,
we previously removed the prefix during the first checkpoint load and
last checkpoint save. However, this approach has caused confusion among
users, despite the options available to control the behavior.
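For illustration, here is a minimal sketch of the kind of fix-up this required; the function name and details are hypothetical, not the actual TorchTitan implementation:
```
# Hypothetical sketch, not the actual TorchTitan code: on the last save,
# keys such as "MODEL.tok_embeddings.weight" had to be rewritten to plain
# FQNs like "tok_embeddings.weight" before export.
def strip_model_prefix(state_dict: dict, prefix: str = "MODEL.") -> dict:
    return {
        key[len(prefix):] if key.startswith(prefix) else key: value
        for key, value in state_dict.items()
    }
```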
This PR aims to resolve the issue by always flattening the model state
dictionary, eliminating the `"MODEL."` prefix from its keys. We decided
not to flatten all components due to the risk of key collisions between
different components. Instead, this PR only flattens the model
state_dict, which is a special case.
While this solution isn't perfect, as it introduces different handling
for different components, it's a good compromise given the unique nature
of the model state_dict.
Also see the discussion in
#1321 (comment)
This is the pseudocode for the current state:
```
if model_only:
    state_dict = model.state_dict()
else:
    state_dict = {
        "MODEL": model,
        "OPTIMIZER": optimizer,
        ...
    }
```
This is the pseudocode after this PR lands:
```
state_dict = model.state_dict()
if not model_only:
    state_dict.update({
        "OPTIMIZER": optimizer,
        ...
    })
```
FSDP4 vs. FSDP4 TP2 loss curves with a seed checkpoint and
`--training.seed=42`: