
Commit 8739c23

Drop hacky llama3_init_fn and use autop init_weights feature
Relying on pytorch-labs/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model.

Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`:

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 %
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 %
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 %
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 %
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 %
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 %
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 %
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 %
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 %
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```
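For orientation, here is a minimal sketch of the initialization path that remains in torchtitan/train.py after this change. It only restates what the diff below shows; names such as `model_cls`, `model_args`, `init_device`, and `buffer_device` are assumed to be set up by the surrounding trainer code.

```python
import torch

# Sketch of the surviving init path (see the diff below). Assumes the trainer
# has already built model_cls, model_args, init_device, and buffer_device.
with torch.device("meta"):
    # Construct the model on the meta device so no real memory is allocated yet.
    model = model_cls.from_model_args(model_args)

# (AutoParallel wrapping happens in between; with pytorch-labs/autoparallel#20
# the parallelized module keeps a usable init_weights, so no bespoke init fn
# is needed anymore.)

model.to_empty(device=init_device)  # materialize parameters on the target device
with torch.no_grad():
    # The user's own init_weights fn is applied to the autoparallel model directly.
    model.init_weights(buffer_device=buffer_device)
model.train()
```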
1 parent d8b9802 commit 8739c23

File tree

1 file changed: +1 −76 lines
torchtitan/train.py

Lines changed: 1 addition & 76 deletions
```diff
@@ -161,79 +161,6 @@ def __init__(self, job_config: JobConfig):
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
 
-        def llama3_autoparallel_init_fn(model):
-            # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
-            # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
-            # the new FQN structures in autoparallel.
-            # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
-            def param(name):
-                return model.get_parameter(f"params.{name}")
-
-            from torchtitan.models.llama3.model import precompute_freqs_cis
-
-            model.buffers_.get_buffer("freqs_cis").copy_(
-                DTensor.from_local(
-                    precompute_freqs_cis(
-                        model_args.dim // model_args.n_heads,
-                        model_args.max_seq_len,
-                        model_args.rope_theta,
-                    ),
-                    device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh,
-                )
-            )
-
-            torch.nn.init.normal_(param("tok_embeddings/weight"))
-
-            def init_layer(i):
-                for norm in ("attention_norm", "ffn_norm"):
-                    torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))
-
-                if model_args.depth_init:
-                    weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
-                else:
-                    weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5
-
-                for linear in ("wq", "wk", "wv"):
-                    torch.nn.init.trunc_normal_(
-                        param(f"layers/{i}/attention/{linear}/weight"),
-                        mean=0.0,
-                        std=0.02,
-                    )
-                torch.nn.init.trunc_normal_(
-                    param(f"layers/{i}/attention/wo/weight"),
-                    mean=0.0,
-                    std=weight_init_std,
-                )
-
-                torch.nn.init.trunc_normal_(
-                    param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
-                )
-                for linear in ("w2", "w3"):
-                    torch.nn.init.trunc_normal_(
-                        param(f"layers/{i}/feed_forward/{linear}/weight"),
-                        mean=0.0,
-                        std=weight_init_std,
-                    )
-
-            for i in range(model_args.n_layers):
-                init_layer(i)
-
-            if param("norm/weight") is not None:
-                torch.nn.init.ones_(param("norm/weight"))
-
-            final_out_std = model_args.dim**-0.5
-            cutoff_factor = 3
-
-            if param("output/weight") is not None:
-                torch.nn.init.trunc_normal_(
-                    param("output/weight"),
-                    mean=0.0,
-                    std=final_out_std,
-                    a=-cutoff_factor * final_out_std,
-                    b=cutoff_factor * final_out_std,
-                )
-
         with torch.device("meta"):
             model = model_cls.from_model_args(model_args)
         # Build the collection of model converters. No-op if `model.converters` empty
@@ -343,9 +270,7 @@ def init_layer(i):
 
         model.to_empty(device=init_device)
         with torch.no_grad():
-            # TODO(whc) make model.init_weights work with autoparallel
-            llama3_autoparallel_init_fn(model)
-            # model.init_weights(buffer_device=buffer_device)
+            model.init_weights(buffer_device=buffer_device)
         model.train()
 
         self.model_parts = [model]
```
