
Commit 7cd091c

feat(*): re-impl embedding/head of isp version (#261)
Co-authored-by: huangting4201 <1538303371@qq.com>
1 parent 2c6df5c commit 7cd091c

22 files changed: +504 -140 lines changed

internlm/checkpoint/components.py

Lines changed: 29 additions & 35 deletions
@@ -100,7 +100,7 @@ def load_model_checkpoint(folder, model):
 
     If tensor parallel mode is isp, the saved weight is named:
     - folder
-        - model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt
+        - model_wp{wp_rank}_pp{pp_rank}.pt
 
     If fsdp is activated, the saved weight is named:
     - folder
@@ -122,19 +122,19 @@ def load_model_checkpoint(folder, model):
     fns = get_fns(folder)
 
     # avoid ckpt misuse between FSDP and no-FSDP
-    test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
+    _start_with = "model_w" if is_using_isp() else "model_t"
+    test_fn = list([f for f in fns if f.startswith(_start_with) and not f.endswith(".md5")]).pop()
     assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
         "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
     ), "FSDP model wants to load no-FSDP ckpts or reverse"
 
     max_pp, max_wp, max_tp, max_zo = 0, 0, 0, 0
     for fn in fns:
-        if fn.startswith("model_t") and not fn.endswith(".md5"):
+        if fn.startswith(_start_with) and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
             if is_using_isp():
                 max_pp = max(max_pp, int(segements[-1][2:]))
                 max_wp = max(max_wp, int(segements[-2][2:]))
-                max_tp = max(max_tp, int(segements[-3][2:]))
             elif gpc.config.parallel.zero1.fsdp:
                 max_zo = max(max_zo, int(segements[-1][2:]))
                 max_pp = max(max_pp, int(segements[-2][2:]))
@@ -149,16 +149,17 @@ def load_model_checkpoint(folder, model):
     assert (
         wp_size == max_wp + 1
     ), f"The weights are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism"
-    assert (
-        tp_size == max_tp + 1
-    ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
+    if not is_using_isp():
+        assert (
+            tp_size == max_tp + 1
+        ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
     if gpc.config.parallel.zero1.fsdp:
         assert (
             dp_size == max_zo + 1
         ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
 
     if is_using_isp():
-        should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
+        should_load_name = f"model_wp{wp_rank}_pp{pp_rank}.pt"
     elif gpc.config.parallel.zero1.fsdp:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
     else:
@@ -205,7 +206,7 @@ def save_model_checkpoint(folder, model):
 
     If tensor parallel mode is isp, the saved weight is named:
     - folder
-        - model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt
+        - model_wp{wp_rank}_pp{pp_rank}.pt
 
     If fsdp is activated, the saved weight is named:
     - folder
@@ -243,11 +244,11 @@ def save_model_checkpoint(folder, model):
 
     # for tensor parallel mode with isp
     if is_using_isp():
-        if wdp_rank == 0 or dp_rank == 0:
-            fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
+        if wdp_rank == 0:
+            fn = f"model_wp{wp_rank}_pp{pp_rank}.pt"
             fp = os.path.join(folder, fn)
             llm_save(fp, saved_obj=states)
-            topo_fn = f"topo_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.json"
+            topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json"
             topo_fp = os.path.join(folder, topo_fn)
             llm_save(topo_fp, saved_obj=topo)
     else:
@@ -292,13 +293,12 @@ def load_optimizer_checkpoint(folder, optim):
     """
 
     fns = get_fns(folder)
-    max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0
+    max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0
     for fn in fns:
         if fn.startswith("optimizer_") and not fn.endswith(".md5"):
             if is_using_isp():
-                _, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_")
-                max_dp = max(max_dp, int(dp[2:]))
-                max_tp = max(max_tp, int(tp[2:]))
+                _, wp, pp, zero = os.path.splitext(fn)[0].split("_")
+                max_zero = max(max_zero, int(zero[2:]))
                 max_wp = max(max_wp, int(wp[2:]))
                 max_pp = max(max_pp, int(pp[2:]))
             else:
@@ -311,24 +311,18 @@ def load_optimizer_checkpoint(folder, optim):
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
-    dp_size = gpc.get_world_size(ParallelMode.DATA)
 
-    if is_using_isp():
-        assert dp_size == max_dp + 1, (
-            f"The optimizer states are save for {max_dp+1} data parallelism, "
-            f"while current has {dp_size} data parallelism"
-        )
-    if not is_using_isp():
-        assert zero_size == max_zero + 1, (
-            f"The optimizer states are save for {max_zero+1} zero parallel, "
-            f"while current has {zero_size} zero broadcast range."
-        )
+    assert zero_size == max_zero + 1, (
+        f"The optimizer states are save for {max_zero+1} zero parallel, "
+        f"while current has {zero_size} zero broadcast range."
+    )
     assert (
         pp_size == max_pp + 1
     ), f"The optimizer states are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
-    assert (
-        tp_size == max_tp + 1
-    ), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
+    if not is_using_isp():
+        assert (
+            tp_size == max_tp + 1
+        ), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
     assert (
         wp_size == max_wp + 1
     ), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism"
@@ -337,9 +331,8 @@ def load_optimizer_checkpoint(folder, optim):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
     if is_using_isp():
-        fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
+        fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
     else:
         fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
 
@@ -387,16 +380,17 @@ def save_optimizer_checkpoint(optim, state_path):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
     zero_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+    wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
     dp_size = gpc.get_world_size(ParallelMode.DATA)
 
     states = optim.state_dict()
     if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)):
         if is_using_isp():
-            fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
-            llm_save(os.path.join(state_path, fp), states)
+            fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
+            if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * wp_size:
+                llm_save(os.path.join(state_path, fp), states)
         else:
             fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
             if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * tp_size:
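
Taken together, these hunks drop the tensor-parallel component from ISP checkpoint names and gate optimizer saving on weight/zero ranks rather than data ranks. A minimal standalone sketch of the resulting naming and save condition follows; the rank and size values are hypothetical and only illustrate the guard added in save_optimizer_checkpoint:

    # Sketch with hypothetical ranks/sizes; mirrors the naming and guard in the hunks above.
    wp_rank, pp_rank, zero_rank = 1, 0, 2              # example local ranks
    tp_size, dp_size, zero_size, wp_size = 1, 8, 4, 2  # example world sizes

    model_ckpt = f"model_wp{wp_rank}_pp{pp_rank}.pt"   # was model_tp{..}_wp{..}_pp{..}.pt
    optim_ckpt = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"

    def should_save_optimizer(global_rank: int) -> bool:
        # Only the first zero_size * wp_size ranks within each tp_size * dp_size block
        # write an optimizer shard, matching the new save_optimizer_checkpoint guard.
        return (global_rank % (tp_size * dp_size)) < zero_size * wp_size

    print(model_ckpt, optim_ckpt, should_save_optimizer(3))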

internlm/core/context/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,6 +1,5 @@
 from .parallel_context import (
     IS_REPLICA_ZERO_PARALLEL,
-    IS_TENSOR_DATA_PARALLEL,
     IS_TENSOR_EXPERT_DATA_PARALLEL,
     IS_TENSOR_ZERO_PARALLEL,
     IS_WEIGHT_ZERO_PARALLEL,
@@ -32,7 +31,6 @@
 __all__ = [
     "Config",
     "IS_TENSOR_ZERO_PARALLEL",
-    "IS_TENSOR_DATA_PARALLEL",
     "IS_REPLICA_ZERO_PARALLEL",
     "IS_WEIGHT_ZERO_PARALLEL",
     "IS_TENSOR_EXPERT_DATA_PARALLEL",

internlm/core/context/parallel_context.py

Lines changed: 5 additions & 3 deletions
@@ -24,12 +24,13 @@
 from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode
 
+# for layernorm
 IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel"
-# for isp, with optimizer split in dp group
-IS_TENSOR_DATA_PARALLEL = "is_tensor_data_parallel"
-# for mtp/msp/fsp, with optimizer split in zero1 group
+# for mtp/msp/fsp with tensor parallel, and optimizer split in zero1 group
 IS_TENSOR_ZERO_PARALLEL = "is_tensor_zero_parallel"
+# for isp with weight parallel, and optimizer split in zero1 group
 IS_WEIGHT_ZERO_PARALLEL = "is_weight_zero_parallel"
+# for moe
 IS_TENSOR_EXPERT_DATA_PARALLEL = "is_tensor_expert_data_parallel"
 
 logger = get_logger(__file__)
@@ -564,6 +565,7 @@ def init_parallel_groups(self):
         initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
+        initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
         if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
             initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
         else:
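
The reworded comments clarify which parameter category each flag marks (replicated layernorm weights, mtp/msp/fsp tensor-parallel weights, isp weight-parallel weights, MoE expert weights), and the new Initializer_ISP_Data group is registered alongside the existing initializers. As a hedged illustration only (not taken from this diff), such string flags are typically attached to parameters as boolean attributes and queried later when grouping parameters:

    # Illustrative sketch; tag_param is hypothetical, not an InternLM API.
    import torch

    IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel"

    def tag_param(param: torch.nn.Parameter, flag: str) -> None:
        setattr(param, flag, True)  # the flag lives as a plain attribute on the parameter

    layernorm_weight = torch.nn.Parameter(torch.ones(8))
    tag_param(layernorm_weight, IS_REPLICA_ZERO_PARALLEL)  # layernorm weights are replicated
    print(getattr(layernorm_weight, IS_REPLICA_ZERO_PARALLEL, False))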

internlm/core/context/process_group_initializer.py

Lines changed: 63 additions & 0 deletions
@@ -60,6 +60,9 @@ class ParallelMode(Enum):
     # sequence parallel
     SEQUENCE = "sequence"
 
+    # real data parallel for isp
+    ISP_DATA = "isp_data"
+
     # grouped query attention
     GQA = "gqa"
 
@@ -854,6 +857,66 @@ def init_dist_group(self, use_cpu: bool = False):
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
+class Initializer_ISP_Data(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for real data parallel group in isp.
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        weight_parallel_size (int): Size of model weight parallel.
+        weight_data_parallel_size (int): Size of data parallel for common weight.
+        sequence_parallel_size (int): Size of data sequence parallel.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+        nettest_parallel_size (int): Size of net testing parallel.
+        expert_parallel_size (int): Size of expert parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.isp_data_parallel_size = self.tensor_parallel_size * self.data_parallel_size
+        self.num_isp_data_parallel_group = self.world_size // self.isp_data_parallel_size
+
+        assert self.world_size % self.isp_data_parallel_size == 0
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize real data parallel groups for isp, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A real data parallelism's information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.ISP_DATA
+
+        for i in range(self.num_isp_data_parallel_group):
+            ranks = [i * self.isp_data_parallel_size + j for j in range(self.isp_data_parallel_size)]
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
+            if use_cpu:
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
+            else:
+                group_cpu = None
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
 class Initializer_GQA(ProcessGroupInitializer):
     """A ProcessGroupInitializer for allreduce kv gradients with common attention head.
 
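
Initializer_ISP_Data partitions the global ranks into contiguous blocks of tensor_parallel_size * data_parallel_size, one ISP_DATA group per block. A small standalone sketch of that rank enumeration with toy sizes (hypothetical values, no torch.distributed calls):

    # Reproduces the rank grouping of Initializer_ISP_Data with toy sizes.
    world_size = 16
    tensor_parallel_size = 2  # hypothetical
    data_parallel_size = 4    # hypothetical

    isp_data_parallel_size = tensor_parallel_size * data_parallel_size
    assert world_size % isp_data_parallel_size == 0
    num_groups = world_size // isp_data_parallel_size

    groups = [
        [i * isp_data_parallel_size + j for j in range(isp_data_parallel_size)]
        for i in range(num_groups)
    ]
    print(groups)  # [[0, 1, ..., 7], [8, 9, ..., 15]] -> each inner list is one ISP_DATA group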
