@@ -12,7 +12,6 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
-
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -25,7 +24,7 @@
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 
-# from torchtitan.protocols.model_converter import build_model_converters
+from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.profiling import (
@@ -147,12 +146,8 @@ def __init__(self, job_config: JobConfig):
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
 
-        def model_fn():
-            # WHC - allow auto_p to construct the model object under its own fake_mode.
-            # TODO: let us pass in meta model, and internally hook it up to the auto_p fake mode
-            return model_cls.from_model_args(model_args).cuda()
 
-        def init_fn(model):
+        def llama3_autoparallel_init_fn(model):
             # WHC - horrible hack to make auto-parallel work. basically, create a bespoke init_fn for llama3 by copying
             # code from the llama3 init_weights functions throughout the model components, and adjusting them to use
             # the new FQN structures in autoparallel.
@@ -221,11 +216,11 @@ def init_layer(i):
                 b=cutoff_factor * final_out_std,
             )
 
-        # with torch.device("meta"):
-        #     model = model_fn()
-        # Build the collection of model converters. No-op if `model.converters` empty
-        # model_converters = build_model_converters(job_config, parallel_dims)
-        # model_converters.convert(model)
+        with torch.device("meta"):
+            model = model_cls.from_model_args(model_args)
+        # Build the collection of model converters. No-op if `model.converters` empty
+        model_converters = build_model_converters(job_config, parallel_dims)
+        model_converters.convert(model)
 
         # metrics logging
         build_metrics_processor_fn = (
@@ -239,15 +234,15 @@ def init_layer(i):
         color = self.metrics_processor.color
 
         # calculate model size and flops per token
-        # (
-        #     model_param_count,
-        #     self.metrics_processor.num_flops_per_token,
-        # ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
+        (
+            model_param_count,
+            self.metrics_processor.num_flops_per_token,
+        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
 
-        # logger.info(
-        #     f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
-        #     f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
-        # )
+        logger.info(
+            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
+            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
+        )
 
         # move sharded model to CPU/GPU and initialize weights via DTensor
         if job_config.checkpoint.create_seed_checkpoint:
@@ -325,12 +320,14 @@ def init_layer(i):
         else:
             # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
             model = self.train_spec.parallelize_fn(
-                model_fn, init_fn, world_mesh, parallel_dims, job_config
+                model, world_mesh, parallel_dims, job_config
            )
 
-            # model.to_empty(device=init_device)
-            # with torch.no_grad():
-            #     model.init_weights(buffer_device=buffer_device)
+            model.to_empty(device=init_device)
+            with torch.no_grad():
+                # TODO(whc) make model.init_weights work with autoparallel
+                llama3_autoparallel_init_fn(model)
+                # model.init_weights(buffer_device=buffer_device)
             model.train()
 
         self.model_parts = [model]
@@ -362,11 +359,11 @@ def init_layer(i):
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # where it issues a single all-reduce for all parameters at once for better performance
-        # self.optimizers.register_step_post_hook(
-        #     lambda *args, **kwargs: model_converters.post_optimizer_hook(
-        #         self.model_parts
-        #     )
-        # )
+        self.optimizers.register_step_post_hook(
+            lambda *args, **kwargs: model_converters.post_optimizer_hook(
+                self.model_parts
+            )
+        )
         self.metrics_processor.optimizers = self.optimizers
 
         # Initialize trainer states that will be saved in checkpoint.
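
For context, the change above swaps the eager `.cuda()` construction for the standard meta-device flow: build the model on the meta device, let `parallelize_fn` shard it, then materialize storage with `to_empty` and run an init function. Below is a minimal sketch of that pattern, using a stand-in `ToyModel` (its name and `init_weights` method are illustrative, not part of the PR) rather than the actual torchtitan Llama model:

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    # Stand-in for model_cls.from_model_args(model_args); any nn.Module works.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

    def init_weights(self) -> None:
        nn.init.trunc_normal_(self.proj.weight, mean=0.0, std=0.02)


# 1. Construct under the meta device: parameters carry shapes/dtypes but no
#    storage, so a large model can be described before it is sharded.
with torch.device("meta"):
    model = ToyModel()

# (the real code hands `model` to self.train_spec.parallelize_fn here)

# 2. Materialize real, uninitialized storage on the target device.
init_device = "cuda" if torch.cuda.is_available() else "cpu"
model.to_empty(device=init_device)

# 3. Fill in weights in place; the PR uses llama3_autoparallel_init_fn because
#    autoparallel changes the FQN structure and the model's own init_weights
#    does not yet work on the transformed module.
with torch.no_grad():
    model.init_weights()
```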
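The re-enabled converter hook relies on an optimizer step post-hook. The sketch below shows the general mechanism with plain `torch.optim` rather than torchtitan's optimizer container, and the `print` is a hypothetical stand-in for `model_converters.post_optimizer_hook(self.model_parts)`:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


def post_step_hook(optim: torch.optim.Optimizer, args, kwargs) -> None:
    # Stand-in for per-step converter work, e.g. recomputing float8
    # amax/scales for all parameters in one pass.
    print("optimizer step finished; run per-step conversion work here")


# Fires after every optimizer.step(); the diff registers the converters'
# lambda the same way on self.optimizers.
handle = optimizer.register_step_post_hook(post_step_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()  # triggers post_step_hook
handle.remove()
```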