Commit 42d5da6

Hack an init_fn for llama3 and observe loss decreasing with autoparallel
""" [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """
1 parent 5e93263 commit 42d5da6

File tree

3 files changed: +89 -6 lines changed

torchtitan/components/metrics.py

Lines changed: 1 addition & 0 deletions
@@ -405,6 +405,7 @@ def log(
             f"{color.blue}tps: {round(tps):,} "
             # f"{color.cyan}tflops: {tflops:,.2f} "
             # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
+            f"{color.reset}"
         )

         self.ntokens_since_last_log = 0
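
The added f"{color.reset}" matters because the terminating reset previously lived in the now-commented-out mfu segment; without it, the blue escape from the tps field keeps coloring everything the terminal prints next. A minimal sketch of the effect, using raw ANSI codes rather than torchtitan's color helper (an assumption for illustration only):

BLUE = "\x1b[34m"   # assumed ANSI escape for blue foreground
RESET = "\x1b[0m"   # assumed ANSI reset

tps = 13891
# Without a trailing reset, the color "bleeds" into subsequent output:
print(f"{BLUE}tps: {round(tps):,} ")
# With the reset appended, as in this commit, later output is uncolored:
print(f"{BLUE}tps: {round(tps):,} {RESET}")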

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 13 additions & 2 deletions
@@ -21,6 +21,7 @@

 def parallelize_llama(
     model_fn,
+    init_fn,  # TODO(whc) hack to pass stuff to autoparallel
     world_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
@@ -32,9 +33,14 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    # model = model.to_empty(device="cuda")
-
     # TODO: make auto-p work with already created model object?
+    # wherever the auto-parallel code that creates a FakeTensorMode is...
+    # fake_mode = ...
+    # for k, v in m.named_parameters():
+    #     # swap each param in your model with a fake tensor version
+    #     # warning - we probably need to do this before initializing the optimizer?
+    #     setattr(m, k, fake_mode.from_tensor(v))
+    # # also do the same for named_buffers

     def input_fn():
         global_batch_size = job_config.training.global_batch_size
@@ -56,6 +62,10 @@ def input_fn():
     assert parallel_dims.cp_enabled is False, "CP not supported yet"
     assert parallel_dims.pp_enabled is False, "PP not supported yet"

+    # bail out
+    # model = model_fn()
+    # return model
+
     autop = AutoParallel(model_fn, input_fn, world_mesh)
     autop.add_parameter_memory_constraint(low=None, high=None)

@@ -73,4 +83,5 @@ def input_fn():
     torch._inductor.config.reorder_for_peak_memory = False
     parallel_mod = torch.compile(parallel_mod, fullgraph=True)

+    init_fn(parallel_mod)
     return parallel_mod
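
The commented-out TODO above sketches an alternative path where an already-constructed model is handed to autoparallel by swapping its real parameters and buffers for fake tensors. A hedged, standalone sketch of that idea follows; it is not what this commit does (the commit lets AutoParallel build the model itself via model_fn), and the helper name fakeify_module_ is hypothetical:

import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode


def fakeify_module_(m: nn.Module, fake_mode: FakeTensorMode) -> None:
    # Replace every parameter and buffer in-place with a FakeTensor version,
    # walking submodules so the attribute that actually owns each tensor is reassigned.
    for submod in m.modules():
        for name, p in list(submod.named_parameters(recurse=False)):
            setattr(submod, name, nn.Parameter(fake_mode.from_tensor(p.detach())))
        for name, b in list(submod.named_buffers(recurse=False)):
            setattr(submod, name, fake_mode.from_tensor(b))


# Usage (assumed): run before the optimizer is created, per the warning in the TODO above.
# fake_mode = FakeTensorMode()
# fakeify_module_(model, fake_mode)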

torchtitan/train.py

Lines changed: 75 additions & 4 deletions
@@ -148,8 +148,79 @@ def __init__(self, job_config: JobConfig):
         )

         def model_fn():
+            # WHC - allow auto_p to construct the model object under its own fake_mode.
+            # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode
             return model_cls.from_model_args(model_args).cuda()

+        def init_fn(model):
+            # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
+            # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
+            # the new FQN structures in autoparallel.
+            # TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
+            def param(name):
+                return model.get_parameter(f"params.{name}")
+
+            from torchtitan.models.llama3.model import precompute_freqs_cis
+
+            model.buffers_.get_buffer("freqs_cis").copy_(
+                precompute_freqs_cis(
+                    model_args.dim // model_args.n_heads,
+                    model_args.max_seq_len,
+                    model_args.rope_theta,
+                )
+            )
+
+            torch.nn.init.normal_(param("tok_embeddings/weight"))
+
+            def init_layer(i):
+                for norm in ("attention_norm", "ffn_norm"):
+                    torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))
+
+                if model_args.depth_init:
+                    weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
+                else:
+                    weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5
+
+                for linear in ("wq", "wk", "wv"):
+                    torch.nn.init.trunc_normal_(
+                        param(f"layers/{i}/attention/{linear}/weight"),
+                        mean=0.0,
+                        std=0.02,
+                    )
+                torch.nn.init.trunc_normal_(
+                    param(f"layers/{i}/attention/wo/weight"),
+                    mean=0.0,
+                    std=weight_init_std,
+                )
+
+                torch.nn.init.trunc_normal_(
+                    param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
+                )
+                for linear in ("w2", "w3"):
+                    torch.nn.init.trunc_normal_(
+                        param(f"layers/{i}/feed_forward/{linear}/weight"),
+                        mean=0.0,
+                        std=weight_init_std,
+                    )
+
+            for i in range(model_args.n_layers):
+                init_layer(i)
+
+            if param("norm/weight") is not None:
+                torch.nn.init.ones_(param("norm/weight"))
+
+            final_out_std = model_args.dim**-0.5
+            cutoff_factor = 3
+
+            if param("output/weight") is not None:
+                torch.nn.init.trunc_normal_(
+                    param("output/weight"),
+                    mean=0.0,
+                    std=final_out_std,
+                    a=-cutoff_factor * final_out_std,
+                    b=cutoff_factor * final_out_std,
+                )
+
         # with torch.device("meta"):
         #     model = model_fn()
         # Build the collection of model converters. No-op if `model.converters` empty
@@ -254,12 +325,12 @@ def model_fn():
         else:
             # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
             model = self.train_spec.parallelize_fn(
-                model_fn, world_mesh, parallel_dims, job_config
+                model_fn, init_fn, world_mesh, parallel_dims, job_config
             )

-            model.to_empty(device=init_device)
-            with torch.no_grad():
-                model.init_weights(buffer_device=buffer_device)
+            # model.to_empty(device=init_device)
+            # with torch.no_grad():
+            #     model.init_weights(buffer_device=buffer_device)
             model.train()

             self.model_parts = [model]
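
The bespoke init_fn exists because the module returned by AutoParallel exposes weights under flattened names such as "params.layers/0/attention/wq/weight", accessed via model.get_parameter(f"params.{name}"), rather than the original "layers.0.attention.wq.weight", so llama3's stock init_weights cannot be called directly. A hedged sketch of a small adapter that translates original FQNs into that assumed naming, which would let per-tensor init logic be shared instead of copied (autop_param is a hypothetical helper, not part of this commit):

import torch


def autop_param(parallel_mod: torch.nn.Module, orig_fqn: str) -> torch.nn.Parameter:
    # Assumption: autoparallel prefixes parameters with "params." and joins the
    # original path components with "/" instead of "." (the convention init_fn relies on).
    return parallel_mod.get_parameter("params." + orig_fqn.replace(".", "/"))


# Usage (assumed), mirroring one line of the bespoke init_fn above:
# torch.nn.init.trunc_normal_(
#     autop_param(model, "layers.0.attention.wq.weight"), mean=0.0, std=0.02
# )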
