Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,6 @@ def __post_init__(self):

# use_hybrid_parallel
if self.use_hybrid_parallel:

if ShardingOption.OFFLOAD in self.sharding:
warnings.warn("`offload` is not supported NOW!")

Expand Down Expand Up @@ -2327,6 +2326,17 @@ def sharding_parallel_rank(self):
else:
return 0

@property
def moe_sharding_parallel_rank(self):
if self.use_hybrid_parallel:
hcg = fleet.get_hybrid_communicate_group()
if hasattr(hcg, "get_moe_sharding_parallel_group"):
return max(hcg.get_moe_sharding_parallel_group().rank, 0)
else:
return 0
else:
return 0

@property
def tensor_parallel_rank(self):
if self.use_hybrid_parallel:
Expand Down Expand Up @@ -2405,6 +2415,8 @@ def weight_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
if self.use_expert_parallel and self.expert_parallel_degree > 1:
name.append(self._format_name("moe_sharding", self.expert_parallel_rank, self.expert_parallel_degree))
return "_".join(name)

else:
Expand Down Expand Up @@ -2534,7 +2546,9 @@ def should_save_model_state(self):
return False
elif self.use_hybrid_parallel:
# save on dataset rank 0
return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
return (
self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
) or (self.expert_parallel_degree > 1 and self.moe_sharding_parallel_rank == 0)
else:
return self.process_index == 0 or self.use_expert_parallel

Expand Down
Loading