Commit c8fb6b5

Adopt new autoparallel API with meta-init model
Allows reverting many of the hacks in the original integration, which were caused by never creating a model object in train.py because a model_fn builder was passed to autoparallel instead.
1 parent 42d5da6 commit c8fb6b5
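
The gist of the change: instead of handing autoparallel a model_fn builder (which it would call under its own FakeTensorMode), train.py now constructs the model once on the meta device and passes the object in. A minimal sketch of the call-site change, using only names that appear in the diffs below (surrounding setup elided):

    # Old API: autoparallel builds the model itself from a callable
    #   autop = AutoParallel(model_fn, input_fn, world_mesh)

    # New API: pass a meta-initialized model object plus the target device type
    with torch.device("meta"):
        model = model_cls.from_model_args(model_args)
    autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type)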

3 files changed: +35 −50 lines

torchtitan/components/metrics.py

Lines changed: 7 additions & 8 deletions
@@ -354,7 +354,7 @@ def log(
         global_max_loss: float,
         extra_metrics: dict[str, Any] | None = None,
     ):
-        # assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
+        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
 
         time_delta = time.perf_counter() - self.time_last_log
 
@@ -365,8 +365,8 @@ def log(
         # model FLOPS utilization
         # For its definition and calculation, please refer to the PaLM paper:
         # https://arxiv.org/abs/2204.02311
-        # mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
-        # tflops = self.num_flops_per_token * tps / 1e12
+        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
+        tflops = self.num_flops_per_token * tps / 1e12
 
         time_end_to_end = time_delta / self.job_config.metrics.log_freq
         time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times)
@@ -378,8 +378,8 @@ def log(
             "loss_metrics/global_avg_loss": global_avg_loss,
             "loss_metrics/global_max_loss": global_max_loss,
             "throughput(tps)": tps,
-            # "tflops": tflops,
-            # "mfu(%)": mfu,
+            "tflops": tflops,
+            "mfu(%)": mfu,
             "time_metrics/end_to_end(s)": time_end_to_end,
             "time_metrics/data_loading(s)": time_data_loading,
             "time_metrics/data_loading(%)": time_data_loading_pct,
@@ -403,9 +403,8 @@ def log(
             f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
             f"({device_mem_stats.max_reserved_pct:.2f}%) "
             f"{color.blue}tps: {round(tps):,} "
-            # f"{color.cyan}tflops: {tflops:,.2f} "
-            # f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
-            f"{color.reset}"
+            f"{color.cyan}tflops: {tflops:,.2f} "
+            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
         )
 
         self.ntokens_since_last_log = 0
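
With the assertion and the two metric lines restored, tflops and mfu are again computed directly from the measured throughput. A small self-contained example of the same arithmetic; the numbers below (a ~48 GFLOPs-per-token model, 5,000 tps, 989 TFLOPS peak) are illustrative assumptions, not values from this commit:

    # Hypothetical inputs, for illustration only.
    num_flops_per_token = 4.8e10   # FLOPs per trained token (assumed)
    tps = 5_000                    # measured tokens/second (assumed)
    gpu_peak_flops = 989e12        # accelerator peak FLOPS (assumed)

    tflops = num_flops_per_token * tps / 1e12                # achieved TFLOPS
    mfu = 100 * num_flops_per_token * tps / gpu_peak_flops   # model FLOPS utilization (%)
    print(f"tflops: {tflops:,.2f}  mfu: {mfu:.2f}%")         # tflops: 240.00  mfu: 24.27%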

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 2 additions & 13 deletions
@@ -20,8 +20,7 @@
 
 
 def parallelize_llama(
-    model_fn,
-    init_fn,  # TODO(whc) hack to pass stuff to autoparallel
+    model,
     world_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
@@ -33,15 +32,6 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    # TODO: make auto-p work with already created model object?
-    # wherever the auto-parallel code that creates a FakeTensorMode is...
-    # fake_mode = ...
-    # for k, v in m.named_parameters():
-    #     # swap each param in your model with a fake tensor version
-    #     # warning - we probably need to do this before initializing the optimizer?
-    #     setattr(m, k, fake_mode.from_tensor(v))
-    # # also do the same for named_buffers
-
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -66,7 +56,7 @@ def input_fn():
     #     model = model_fn()
     #     return model
 
-    autop = AutoParallel(model_fn, input_fn, world_mesh)
+    autop = AutoParallel(model, input_fn, world_mesh, device=world_mesh.device_type)
     autop.add_parameter_memory_constraint(low=None, high=None)
 
     x_sharding = (Shard(0), Replicate())
@@ -83,5 +73,4 @@ def input_fn():
     torch._inductor.config.reorder_for_peak_memory = False
     parallel_mod = torch.compile(parallel_mod, fullgraph=True)
 
-    init_fn(parallel_mod)
     return parallel_mod
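
The NOTE in the docstring is the crux of the new API: a meta-device model carries shapes and dtypes but allocates no storage, so arbitrarily large models can be described to autoparallel without first having to fit in GPU or CPU memory. A tiny self-contained illustration of the mechanism (plain torch.nn, not torchtitan code):

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        big = nn.Linear(32_768, 32_768)  # ~1B params, but nothing is allocated

    print(big.weight.device)  # meta
    # Size the weight *would* occupy once materialized: ~4 GiB in fp32.
    print(big.weight.nelement() * big.weight.element_size())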

torchtitan/train.py

Lines changed: 26 additions & 29 deletions
@@ -12,7 +12,6 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
-
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -25,7 +24,7 @@
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 
-# from torchtitan.protocols.model_converter import build_model_converters
+from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.profiling import (
@@ -147,12 +146,8 @@ def __init__(self, job_config: JobConfig):
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
 
-        def model_fn():
-            # WHC - allow auto_p to construct the model object under its own fake_mode.
-            # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode
-            return model_cls.from_model_args(model_args).cuda()
 
-        def init_fn(model):
+        def llama3_autoparallel_init_fn(model):
             # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
             # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
             # the new FQN structures in autoparallel.
@@ -221,11 +216,11 @@ def init_layer(i):
                 b=cutoff_factor * final_out_std,
             )
 
-        # with torch.device("meta"):
-        #     model = model_fn()
-        # Build the collection of model converters. No-op if `model.converters` empty
-        # model_converters = build_model_converters(job_config, parallel_dims)
-        # model_converters.convert(model)
+        with torch.device("meta"):
+            model = model_cls.from_model_args(model_args)
+        # Build the collection of model converters. No-op if `model.converters` empty
+        model_converters = build_model_converters(job_config, parallel_dims)
+        model_converters.convert(model)
 
         # metrics logging
         build_metrics_processor_fn = (
@@ -239,15 +234,15 @@ def init_layer(i):
         color = self.metrics_processor.color
 
         # calculate model size and flops per token
-        # (
-        #     model_param_count,
-        #     self.metrics_processor.num_flops_per_token,
-        # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
+        (
+            model_param_count,
+            self.metrics_processor.num_flops_per_token,
+        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
 
-        # logger.info(
-        #     f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
-        #     f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
-        # )
+        logger.info(
+            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
+            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
+        )
 
         # move sharded model to CPU/GPU and initialize weights via DTensor
         if job_config.checkpoint.create_seed_checkpoint:
@@ -325,12 +320,14 @@ def init_layer(i):
         else:
             # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
             model = self.train_spec.parallelize_fn(
-                model_fn, init_fn, world_mesh, parallel_dims, job_config
+                model, world_mesh, parallel_dims, job_config
            )
 
-            # model.to_empty(device=init_device)
-            # with torch.no_grad():
-            #     model.init_weights(buffer_device=buffer_device)
+            model.to_empty(device=init_device)
+            with torch.no_grad():
+                # TODO(whc) make model.init_weights work with autoparallel
+                llama3_autoparallel_init_fn(model)
+                # model.init_weights(buffer_device=buffer_device)
         model.train()
 
         self.model_parts = [model]
@@ -362,11 +359,11 @@ def init_layer(i):
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # where it issues a single all-reduce for all parameters at once for better performance
-        # self.optimizers.register_step_post_hook(
-        #     lambda *args, **kwargs: model_converters.post_optimizer_hook(
-        #         self.model_parts
-        #     )
-        # )
+        self.optimizers.register_step_post_hook(
+            lambda *args, **kwargs: model_converters.post_optimizer_hook(
+                self.model_parts
+            )
+        )
         self.metrics_processor.optimizers = self.optimizers
 
         # Initialize trainer states that will be saved in checkpoint.
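
The materialization pattern that replaces the old init_fn plumbing is the standard meta-device flow: parallelize the storage-free module, allocate real memory with to_empty, then fill the weights in under no_grad (here via the llama3_autoparallel_init_fn hack, until model.init_weights works with autoparallel). A minimal runnable sketch of that flow with a plain module, where the normal_ call stands in for the real init function:

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        m = nn.Linear(16, 16)        # meta params: shapes only, no storage

    m.to_empty(device="cpu")         # allocate real, uninitialized storage
    with torch.no_grad():
        for p in m.parameters():
            p.normal_(mean=0.0, std=0.02)  # stand-in for llama3_autoparallel_init_fn
    m.train()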
