@@ -148,8 +148,79 @@ def __init__(self, job_config: JobConfig):
        )

        def model_fn():
+            # WHC - allow auto_p to construct the model object under its own fake_mode.
+            # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode
            return model_cls.from_model_args(model_args).cuda()

+        def init_fn(model):
+            # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
+            # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
+            # the new FQN structures in autoparallel.
+            # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
+            def param(name):
+                return model.get_parameter(f"params.{name}")
+
+            from torchtitan.models.llama3.model import precompute_freqs_cis
+
+            model.buffers_.get_buffer("freqs_cis").copy_(
+                precompute_freqs_cis(
+                    model_args.dim // model_args.n_heads,
+                    model_args.max_seq_len,
+                    model_args.rope_theta,
+                )
+            )
+
+            torch.nn.init.normal_(param("tok_embeddings/weight"))
+
+            def init_layer(i):
+                for norm in ("attention_norm", "ffn_norm"):
+                    torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))
+
+                if model_args.depth_init:
+                    weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
+                else:
+                    weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5
+
+                for linear in ("wq", "wk", "wv"):
+                    torch.nn.init.trunc_normal_(
+                        param(f"layers/{i}/attention/{linear}/weight"),
+                        mean=0.0,
+                        std=0.02,
+                    )
+                torch.nn.init.trunc_normal_(
+                    param(f"layers/{i}/attention/wo/weight"),
+                    mean=0.0,
+                    std=weight_init_std,
+                )
+
+                torch.nn.init.trunc_normal_(
+                    param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
+                )
+                for linear in ("w2", "w3"):
+                    torch.nn.init.trunc_normal_(
+                        param(f"layers/{i}/feed_forward/{linear}/weight"),
+                        mean=0.0,
+                        std=weight_init_std,
+                    )
+
+            for i in range(model_args.n_layers):
+                init_layer(i)
+
+            if param("norm/weight") is not None:
+                torch.nn.init.ones_(param("norm/weight"))
+
+            final_out_std = model_args.dim**-0.5
+            cutoff_factor = 3
+
+            if param("output/weight") is not None:
+                torch.nn.init.trunc_normal_(
+                    param("output/weight"),
+                    mean=0.0,
+                    std=final_out_std,
+                    a=-cutoff_factor * final_out_std,
+                    b=cutoff_factor * final_out_std,
+                )
+
        # with torch.device("meta"):
        #     model = model_fn()
        # Build the collection of model converters. No-op if `model.converters` empty
@@ -254,12 +325,12 @@ def model_fn():
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(
-                model_fn, world_mesh, parallel_dims, job_config
+                model_fn, init_fn, world_mesh, parallel_dims, job_config
            )

-        model.to_empty(device=init_device)
-        with torch.no_grad():
-            model.init_weights(buffer_device=buffer_device)
+        # model.to_empty(device=init_device)
+        # with torch.no_grad():
+        #     model.init_weights(buffer_device=buffer_device)
        model.train()

        self.model_parts = [model]
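
Note: the diff changes parallelize_fn to take both model_fn and init_fn and comments out the old to_empty/init_weights path. A minimal sketch of how a parallelize_fn with that new signature might consume the two callables is shown below; it is not part of this diff, and every name in it is an assumption for illustration, not the autoparallel API.

    import torch

    def parallelize_fn(model_fn, init_fn, world_mesh, parallel_dims, job_config):
        # Construct the model; autoparallel would do this under its own fake mode
        # and then materialize/shard parameters according to its chosen plan.
        model = model_fn()
        # ... autoparallel planning / sharding passes would run here ...
        # Initialize the real (sharded) parameters and buffers in place, using the
        # bespoke init_fn instead of the module's original init_weights().
        with torch.no_grad():
            init_fn(model)
        return model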