diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 7599c4d49cc8..b8407d212943 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1266,7 +1266,6 @@ def __post_init__(self):
 
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
-
            if ShardingOption.OFFLOAD in self.sharding:
                warnings.warn("`offload` is not supported NOW!")
 
@@ -2327,6 +2326,17 @@ def sharding_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def moe_sharding_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_moe_sharding_parallel_group"):
+                return max(hcg.get_moe_sharding_parallel_group().rank, 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def tensor_parallel_rank(self):
         if self.use_hybrid_parallel:
@@ -2405,6 +2415,8 @@ def weight_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree > 1:
+                name.append(self._format_name("moe_sharding", self.expert_parallel_rank, self.expert_parallel_degree))
             return "_".join(name)
 
         else:
@@ -2534,7 +2546,9 @@ def should_save_model_state(self):
                 return False
             elif self.use_hybrid_parallel:
                 # save on dataset rank 0
-                return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+                return (
+                    self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
+                ) or (self.expert_parallel_degree > 1 and self.moe_sharding_parallel_rank == 0)
             else:
                 return self.process_index == 0 or self.use_expert_parallel
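
The change to should_save_model_state is the behavioral core of this diff: with expert_parallel_degree > 1, rank 0 of each MoE sharding group now also writes model state, even when its plain sharding rank is non-zero. A minimal standalone sketch of that condition is below; it is not the PaddleNLP API, the free-function signature and its parameter names are hypothetical stand-ins for the corresponding TrainingArguments properties.

# Standalone sketch mirroring the updated should_save_model_state condition
# from this diff, so the saving rule can be checked in isolation.
def should_save_model_state(
    sharding_parallel_rank,
    data_parallel_rank,
    use_expert_parallel,
    expert_parallel_degree,
    moe_sharding_parallel_rank,
):
    return (
        sharding_parallel_rank == 0 and (data_parallel_rank == 0 or use_expert_parallel)
    ) or (expert_parallel_degree > 1 and moe_sharding_parallel_rank == 0)

# Without expert parallelism the old rule still applies: only sharding rank 0
# (on data-parallel rank 0) saves.
assert should_save_model_state(0, 0, False, 1, 0)
assert not should_save_model_state(1, 0, False, 1, 0)

# With expert_parallel_degree > 1, rank 0 of each MoE sharding group now also
# saves, even if its plain sharding rank is non-zero.
assert should_save_model_state(1, 0, True, 4, 0)
assert not should_save_model_state(1, 1, True, 4, 1)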