@@ -161,79 +161,6 @@ def __init__(self, job_config: JobConfig):
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
 
-
-        def llama3_autoparallel_init_fn(model):
-            # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
-            # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
-            # the new FQN structures in autoparallel.
-            # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
-            def param(name):
-                return model.get_parameter(f"params.{name}")
-
-            from torchtitan.models.llama3.model import precompute_freqs_cis
-
-            model.buffers_.get_buffer("freqs_cis").copy_(
-                DTensor.from_local(
-                    precompute_freqs_cis(
-                        model_args.dim // model_args.n_heads,
-                        model_args.max_seq_len,
-                        model_args.rope_theta,
-                    ),
-                    device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh,
-                )
-            )
-
-            torch.nn.init.normal_(param("tok_embeddings/weight"))
-
-            def init_layer(i):
-                for norm in ("attention_norm", "ffn_norm"):
-                    torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))
-
-                if model_args.depth_init:
-                    weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
-                else:
-                    weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5
-
-                for linear in ("wq", "wk", "wv"):
-                    torch.nn.init.trunc_normal_(
-                        param(f"layers/{i}/attention/{linear}/weight"),
-                        mean=0.0,
-                        std=0.02,
-                    )
-                torch.nn.init.trunc_normal_(
-                    param(f"layers/{i}/attention/wo/weight"),
-                    mean=0.0,
-                    std=weight_init_std,
-                )
-
-                torch.nn.init.trunc_normal_(
-                    param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
-                )
-                for linear in ("w2", "w3"):
-                    torch.nn.init.trunc_normal_(
-                        param(f"layers/{i}/feed_forward/{linear}/weight"),
-                        mean=0.0,
-                        std=weight_init_std,
-                    )
-
-            for i in range(model_args.n_layers):
-                init_layer(i)
-
-            if param("norm/weight") is not None:
-                torch.nn.init.ones_(param("norm/weight"))
-
-            final_out_std = model_args.dim**-0.5
-            cutoff_factor = 3
-
-            if param("output/weight") is not None:
-                torch.nn.init.trunc_normal_(
-                    param("output/weight"),
-                    mean=0.0,
-                    std=final_out_std,
-                    a=-cutoff_factor * final_out_std,
-                    b=cutoff_factor * final_out_std,
-                )
-
         with torch.device("meta"):
             model = model_cls.from_model_args(model_args)
         # Build the collection of model converters. No-op if `model.converters` empty
@@ -343,9 +270,7 @@ def init_layer(i):
 
             model.to_empty(device=init_device)
             with torch.no_grad():
-                # TODO(whc) make model.init_weights work with autoparallel
-                llama3_autoparallel_init_fn(model)
-                # model.init_weights(buffer_device=buffer_device)
+                model.init_weights(buffer_device=buffer_device)
             model.train()
 
             self.model_parts = [model]
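
Note on the pattern used by the removed initializer: it filled the distributed "freqs_cis" buffer by computing the table locally and wrapping it with DTensor.from_local on the buffer's existing device mesh. Below is a minimal, hedged sketch of that general pattern only; the mesh construction, tensor shape, and variable names are illustrative assumptions, it presumes a process group launched via torchrun, and it is not code from this commit.

import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Assumption: launched with torchrun, one process per GPU.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Stand-in for a locally computed table such as the output of precompute_freqs_cis(...).
local_table = torch.randn(2048, 64, device="cuda")

# Wrap the local tensor as a replicated DTensor on the mesh; the removed code
# passed the target buffer's own .device_mesh here instead of building a new mesh.
dist_table = DTensor.from_local(local_table, device_mesh=mesh, placements=[Replicate()])

# A DTensor buffer on the same mesh can then be filled in place, e.g.
# buffer.copy_(dist_table), which is how the removed function wrote freqs_cis.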