Commit 315191e

[SimpleFSDP] Note on DCP checkpoint saving/loading (#1280)
Per the earlier discussion in #1273 (comment), we found that saving/loading a DCP checkpoint in SimpleFSDP incurs additional all-gathers. I'm leaving a note here and will update it after @fegin's checkpointing upgrade. ![Screenshot 2025-06-10 at 10 53 14 AM](https://github.com/user-attachments/assets/748b3b68-4062-4840-b1e0-318a42d813ca)
1 parent b7c7ed7 commit 315191e

1 file changed, +3 -1 lines changed

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 3 additions & 1 deletion
@@ -128,9 +128,11 @@ def _register_parametrization(
     module: nn.Module, param_names: List[str], parametrization: nn.Module
 ):
     """
-    it works with state_dict without incurring parametrization calls because
+    It works with state_dict without incurring parametrization calls because
     state_dict accesses parameters directly from self._parameters, not from getters
     https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2141
+    TODO: In checkpoint saving/loading, avoid parametrization calls when calling
+    get_model_state_dict func in torchtitan's torchtitan/components/checkpoint.py.
     """
     param_name_to_property = {
         param_name: property(lambda self: parametrization(self._parameters[param_name]))
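
The docstring's claim (reading `self._parameters` directly bypasses the property getters) can be illustrated with a minimal, torch-free sketch. `ToyModule`, `register_parametrization`, and `doubling` are hypothetical stand-ins, not torchtitan or PyTorch code; the sketch assumes only that parameters are stored in a `_parameters` dict and that the getters are attached through a dynamically created subclass.

```python
from typing import Callable, Dict, List


class ToyModule:
    """Hypothetical stand-in for nn.Module: values live in self._parameters."""

    def __init__(self, **params: float) -> None:
        self._parameters: Dict[str, float] = dict(params)

    def state_dict(self) -> Dict[str, float]:
        # Reads the raw storage directly, so property getters never run.
        return dict(self._parameters)


def register_parametrization(
    module: ToyModule,
    param_names: List[str],
    parametrization: Callable[[float], float],
) -> None:
    # Bind each name through a default argument so every getter keeps its own key.
    props = {
        name: property(lambda self, n=name: parametrization(self._parameters[n]))
        for name in param_names
    }
    # Assumption: getters are installed by swapping in a generated subclass.
    module.__class__ = type(
        f"Parametrized{type(module).__name__}", (type(module),), props
    )


calls: List[float] = []


def doubling(value: float) -> float:
    calls.append(value)  # record every parametrization call
    return value * 2


m = ToyModule(weight=1.5, bias=0.5)
register_parametrization(m, ["weight", "bias"], doubling)

raw = m.state_dict()  # direct dict access: no parametrization call
calls_after_state_dict = len(calls)
via_getter = m.weight  # attribute access: routed through the property
calls_after_getter = len(calls)
```

In this toy setup, `state_dict()` returns the stored values untouched, while `m.weight` invokes the parametrization once. The TODO above concerns a checkpointing path where, per the commit message, the parametrization (and its all-gather) is still triggered.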
