@@ -100,7 +100,7 @@ def load_model_checkpoint(folder, model):
 
     If tensor parallel mode is isp, the saved weight is named:
     - folder
-        - model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt
+        - model_wp{wp_rank}_pp{pp_rank}.pt
 
     If fsdp is activated, the saved weight is named:
     - folder
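[Illustration, not part of the diff] With the new scheme an isp checkpoint folder holds one weight file per (wp, pp) pair and no longer encodes a tp rank. A minimal sketch of the expected listing, assuming hypothetical sizes wp_size=2 and pp_size=2:

# Hypothetical sizes for illustration; real values come from gpc.get_world_size(...).
wp_size, pp_size = 2, 2
expected_files = [f"model_wp{wp}_pp{pp}.pt" for pp in range(pp_size) for wp in range(wp_size)]
print(expected_files)
# ['model_wp0_pp0.pt', 'model_wp1_pp0.pt', 'model_wp0_pp1.pt', 'model_wp1_pp1.pt']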
@@ -122,19 +122,19 @@ def load_model_checkpoint(folder, model):
     fns = get_fns(folder)
 
     # avoid ckpt misuse between FSDP and no-FSDP
-    test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
+    _start_with = "model_w" if is_using_isp() else "model_t"
+    test_fn = list([f for f in fns if f.startswith(_start_with) and not f.endswith(".md5")]).pop()
     assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
         "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
     ), "FSDP model wants to load no-FSDP ckpts or reverse"
 
     max_pp, max_wp, max_tp, max_zo = 0, 0, 0, 0
     for fn in fns:
-        if fn.startswith("model_t") and not fn.endswith(".md5"):
+        if fn.startswith(_start_with) and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
             if is_using_isp():
                 max_pp = max(max_pp, int(segements[-1][2:]))
                 max_wp = max(max_wp, int(segements[-2][2:]))
-                max_tp = max(max_tp, int(segements[-3][2:]))
             elif gpc.config.parallel.zero1.fsdp:
                 max_zo = max(max_zo, int(segements[-1][2:]))
                 max_pp = max(max_pp, int(segements[-2][2:]))
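[Illustration, not part of the diff] The scan above recovers the saved parallel sizes by splitting each file stem on "_" and stripping the two-letter axis prefix from the relevant segments. A self-contained sketch of that parsing against the new isp names (the file list is made up):

import os

fns = ["model_wp0_pp0.pt", "model_wp1_pp0.pt", "model_wp0_pp1.pt", "model_wp1_pp1.pt"]
max_pp, max_wp = 0, 0
for fn in fns:
    segments = os.path.splitext(fn)[0].split("_")  # e.g. ["model", "wp1", "pp0"]
    max_pp = max(max_pp, int(segments[-1][2:]))    # drop the "pp" prefix
    max_wp = max(max_wp, int(segments[-2][2:]))    # drop the "wp" prefix
print(max_wp + 1, max_pp + 1)  # 2 2 -> the checkpoint was written with wp_size=2, pp_size=2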
@@ -149,16 +149,17 @@ def load_model_checkpoint(folder, model):
     assert (
         wp_size == max_wp + 1
     ), f"The weights are save for {max_wp + 1} parallelism, while current has {wp_size} weight parallelism"
-    assert (
-        tp_size == max_tp + 1
-    ), f"The weights are save for {max_tp + 1} parallelism, while current has {tp_size} tensor parallelism"
+    if not is_using_isp():
+        assert (
+            tp_size == max_tp + 1
+        ), f"The weights are save for {max_tp + 1} parallelism, while current has {tp_size} tensor parallelism"
     if gpc.config.parallel.zero1.fsdp:
         assert (
             dp_size == max_zo + 1
         ), f"The weights are save for {max_zo + 1} FSDP shards, while current has {dp_size} FSDP shards"
 
     if is_using_isp():
-        should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
+        should_load_name = f"model_wp{wp_rank}_pp{pp_rank}.pt"
     elif gpc.config.parallel.zero1.fsdp:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
     else:
@@ -205,7 +206,7 @@ def save_model_checkpoint(folder, model):
 
     If tensor parallel mode is isp, the saved weight is named:
     - folder
-        - model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt
+        - model_wp{wp_rank}_pp{pp_rank}.pt
 
     If fsdp is activated, the saved weight is named:
     - folder
@@ -243,11 +244,11 @@ def save_model_checkpoint(folder, model):
 
     # for tensor parallel mode with isp
     if is_using_isp():
-        if wdp_rank == 0 or dp_rank == 0:
-            fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
+        if wdp_rank == 0:
+            fn = f"model_wp{wp_rank}_pp{pp_rank}.pt"
             fp = os.path.join(folder, fn)
             llm_save(fp, saved_obj=states)
-            topo_fn = f"topo_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.json"
+            topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json"
             topo_fp = os.path.join(folder, topo_fn)
             llm_save(topo_fp, saved_obj=topo)
     else:
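[Illustration, not part of the diff] Under isp the model weights are presumably replicated across the weight-data-parallel group, and the shard filename only encodes (wp, pp), so a single writer per shard is enough; the old `wdp_rank == 0 or dp_rank == 0` condition could let two ranks with the same (wp, pp) coordinates write the same file. A toy comparison of the two guards with hypothetical rank pairs:

# Hypothetical (wdp_rank, dp_rank) pairs; the real values come from gpc.get_local_rank(...).
def writes_old(wdp_rank, dp_rank):
    return wdp_rank == 0 or dp_rank == 0  # old guard

def writes_new(wdp_rank, dp_rank):
    return wdp_rank == 0                  # new guard: one writer per (wp, pp) shard

for wdp_rank, dp_rank in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print((wdp_rank, dp_rank), writes_old(wdp_rank, dp_rank), writes_new(wdp_rank, dp_rank))
# (1, 0) passes the old guard but not the new one, so the duplicate write disappears.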
@@ -292,13 +293,12 @@ def load_optimizer_checkpoint(folder, optim):
     """
 
     fns = get_fns(folder)
-    max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0
+    max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0
     for fn in fns:
         if fn.startswith("optimizer_") and not fn.endswith(".md5"):
             if is_using_isp():
-                _, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_")
-                max_dp = max(max_dp, int(dp[2:]))
-                max_tp = max(max_tp, int(tp[2:]))
+                _, wp, pp, zero = os.path.splitext(fn)[0].split("_")
+                max_zero = max(max_zero, int(zero[2:]))
                 max_wp = max(max_wp, int(wp[2:]))
                 max_pp = max(max_pp, int(pp[2:]))
             else:
@@ -311,24 +311,18 @@ def load_optimizer_checkpoint(folder, optim):
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
-    dp_size = gpc.get_world_size(ParallelMode.DATA)
 
-    if is_using_isp():
-        assert dp_size == max_dp + 1, (
-            f"The optimizer states are save for {max_dp + 1} data parallelism, "
-            f"while current has {dp_size} data parallelism"
-        )
-    if not is_using_isp():
-        assert zero_size == max_zero + 1, (
-            f"The optimizer states are save for {max_zero + 1} zero parallel, "
-            f"while current has {zero_size} zero broadcast range."
-        )
+    assert zero_size == max_zero + 1, (
+        f"The optimizer states are save for {max_zero + 1} zero parallel, "
+        f"while current has {zero_size} zero broadcast range."
+    )
     assert (
         pp_size == max_pp + 1
     ), f"The optimizer states are save for {max_pp + 1} pipelines, while current has {pp_size} pipelines"
-    assert (
-        tp_size == max_tp + 1
-    ), f"The optimizer states are save for {max_tp + 1} parallelism, while current has {tp_size} tensor parallelism"
+    if not is_using_isp():
+        assert (
+            tp_size == max_tp + 1
+        ), f"The optimizer states are save for {max_tp + 1} parallelism, while current has {tp_size} tensor parallelism"
     assert (
         wp_size == max_wp + 1
     ), f"The optimizer states are save for {max_wp + 1} parallelism, while current has {wp_size} weight parallelism"
@@ -337,9 +331,8 @@ def load_optimizer_checkpoint(folder, optim):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
     if is_using_isp():
-        fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
+        fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
     else:
         fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
 
@@ -387,16 +380,17 @@ def save_optimizer_checkpoint(optim, state_path):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
     zero_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+    wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
     dp_size = gpc.get_world_size(ParallelMode.DATA)
 
     states = optim.state_dict()
     if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)):
         if is_using_isp():
-            fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
-            llm_save(os.path.join(state_path, fp), states)
+            fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
+            if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * wp_size:
+                llm_save(os.path.join(state_path, fp), states)
         else:
             fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
             if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * tp_size:
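[Illustration, not part of the diff] The new isp guard mirrors the existing non-isp guard just above: within every block of tp_size * dp_size consecutive global ranks, only the first zero_size * wp_size positions write an optimizer shard. A toy enumeration with assumed sizes (how global ranks map to (wp, zo) coordinates depends on InternEvo's process-group layout, which is not shown here):

# Assumed sizes for illustration only: 16 ranks, tp=1, dp=8, wp=2, zero1=2, pp=2.
tp_size, dp_size, wp_size, zero_size = 1, 8, 2, 2
writers = [r for r in range(16) if (r % (tp_size * dp_size)) < zero_size * wp_size]
print(writers)  # [0, 1, 2, 3, 8, 9, 10, 11] -> 8 writers for the 8 distinct optimizer_wp*_pp*_zo*.pt files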