@@ -67,7 +67,10 @@ def init_multi_module(self) -> nn.Module:
         return module

     def init_transformer(
-        self, weight_tying: bool, dtype: Optional[torch.dtype] = None
+        self,
+        weight_tying: bool,
+        dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = True,
     ) -> nn.Module:
         torch.manual_seed(42)
         args = ModelArgs(
@@ -81,6 +84,13 @@ def init_transformer(
         module = Transformer(args).cuda()
         if dtype is not None:
             module = module.to(dtype=dtype)
+
+        # if requires_grad=False, just set requires_grad to False
+        # in the first layer to ensure we still train some params.
+        if requires_grad is False:
+            for param in module.layers[0].parameters():
+                param.requires_grad = requires_grad
+
         self.broadcast_module(module)
         return module

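A minimal sketch (not from this PR) of the freezing pattern the hunk above applies: only the first block is frozen, so the model keeps trainable parameters and a gradient parity comparison stays meaningful. The nn.Sequential model here is a stand-in for the test's Transformer, whose real counterpart freezes module.layers[0].

# Illustrative sketch only; nn.Sequential stands in for the test's Transformer.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))

# Freeze only the first block, mirroring the diff's module.layers[0] loop.
for param in model[0].parameters():
    param.requires_grad = False

loss = model(torch.randn(4, 8)).sum()
loss.backward()

assert all(p.grad is None for p in model[0].parameters())      # frozen: no grads
assert all(p.grad is not None for p in model[1].parameters())  # still trained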
@@ -107,6 +117,7 @@ def test_transformer_parity(self):
                 ],
                 "compile_transformer_block": [False, True],
                 "dtype": [torch.float32, torch.bfloat16],
+                "requires_grad": [True, False],
             },
             self._test_transformer_parity,
         )
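Assuming the usual behavior of the FSDP test helper used here, run_subtests expands the config dict into the cartesian product of its values, so adding "requires_grad": [True, False] doubles the number of subtests. A tiny illustration of that expansion (not the actual helper):

# Assumed expansion behavior, shown with a reduced config.
from itertools import product

config = {
    "compile_transformer_block": [False, True],
    "dtype": ["float32", "bfloat16"],
    "requires_grad": [True, False],
}
combos = [dict(zip(config, values)) for values in product(*config.values())]
print(len(combos))  # 8 combinations for this reduced config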
@@ -117,6 +128,7 @@ def _test_transformer_parity(
         precompute: bool,
         scaling_type_weight: ScalingType,
         compile_transformer_block: bool,
+        requires_grad: bool,
         dtype: Optional[torch.dtype] = None,
     ):
         if not enable_fsdp_float8_all_gather and precompute:
@@ -127,7 +139,10 @@ def _test_transformer_parity(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_float8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying, dtype=dtype)
+        module = self.init_transformer(
+            weight_tying=weight_tying, dtype=dtype, requires_grad=requires_grad
+        )
+
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
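A small self-contained check (not part of the PR) of a property the hunk above relies on: copy.deepcopy preserves requires_grad on parameters, so ref_module inherits the same frozen first layer as module without re-applying the freeze.

# Sketch with a stand-in model; the real test deepcopies the Transformer.
import copy
import torch.nn as nn

module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for param in module[0].parameters():
    param.requires_grad = False

ref_module = copy.deepcopy(module)
assert all(not p.requires_grad for p in ref_module[0].parameters())  # frozen copied
assert all(p.requires_grad for p in ref_module[1].parameters())      # rest trainable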