Skip to content

Commit 8850c5a

Browse files
authored
Fix the paramterization of simple fsdp (#1326)
Some regression since #1273 The issue is python closure could not capture the param name correctly unless you pass it as a input to the lambda function.
1 parent f4048f8 commit 8850c5a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def _register_parametrization(
135135
get_model_state_dict func in torchtitan's torchtitan/components/checkpoint.py.
136136
"""
137137
param_name_to_property = {
138-
param_name: property(lambda self: parametrization(self._parameters[param_name]))
138+
param_name: property(
139+
lambda self, pn=param_name: parametrization(self._parameters[pn])
140+
)
139141
for param_name in param_names
140142
}
141143
module_cls = type(

0 commit comments

Comments
 (0)