
Commit bbfdeb4

refactor: rename data parallel dataclasses DistributedType to match the accelerate package

1 parent 67e7bd3, commit bbfdeb4

File tree

mindnlp/accelerate/accelerator.py
mindnlp/accelerate/state.py
mindnlp/accelerate/utils/constants.py
mindnlp/accelerate/utils/dataclasses.py
mindnlp/dataset/load.py

5 files changed: +24 / -19 lines

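The rename touches every place the data-parallel mode is spelled: the enum member, the shared accelerator state, the environment-variable switch, and the dataset loader. A minimal sketch of the new spelling as calling code sees it (the import path follows the file layout in this commit):

from mindnlp.accelerate.utils.dataclasses import DistributedType

# Removed in this commit:  DistributedType.MULTI_NPU_DP  (value "MULTI_NPU_DP")
# Renamed to:              DistributedType.MULTI_NPU     (value "MULTI_NPU")
print(DistributedType.MULTI_NPU.value)  # "MULTI_NPU"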

mindnlp/accelerate/accelerator.py

Lines changed: 3 additions & 10 deletions
@@ -105,20 +105,13 @@ def prepare(self, *args):
         """
         result = []

-        # Only support mindsormers and MULTI_NPU_DP now
+        # Only support mindsormers and MULTI_NPU now
         if self.distributed_type == DistributedType.MINDFORMERS:
             result = self._prepare_mindformers(*args)
-        elif self.distributed_type == DistributedType.MULTI_NPU_DP:
-            result = self._prepare_data_parallel_native_minspore(*args)
+        elif self.distributed_type == DistributedType.MULTI_NPU:
+            pass # nothing prepare for data parallel
         return result

-    def _prepare_data_parallel_native_minspore(self, *args):
-        # initialize data parallel for native mindspore
-        mindspore.set_context(mode=mindspore.GRAPH_MODE)
-        mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
-        mindspore.communication.init()
-        mindspore.set_seed(numpy.random.seed())
-
     def _prepare_mindformers(self, *args):
         mindformers_plugin = self.state.mindformers_plugin
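In the hunk above, the data-parallel branch becomes a no-op and _prepare_data_parallel_native_minspore moves out of the class: the same initialization now lives as a module-level helper in constants.py (see that file below) and runs when the distributed type is detected. A compact mirror of the remaining dispatch, assuming only what the diff shows (the real Accelerator carries more state):

from mindnlp.accelerate.utils.dataclasses import DistributedType

def prepare_dispatch(distributed_type, prepare_mindformers, *args):
    """Simplified mirror of Accelerator.prepare's branching after this commit."""
    result = []
    if distributed_type == DistributedType.MINDFORMERS:
        result = prepare_mindformers(*args)   # plugin-driven path, unchanged
    elif distributed_type == DistributedType.MULTI_NPU:
        pass  # data parallel: nothing to wrap; the HCCL context was set up at detection time
    return result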

mindnlp/accelerate/state.py

Lines changed: 4 additions & 4 deletions
@@ -347,9 +347,9 @@ def _prepare_backend(self):
         if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
             self.backend = "hccl"
             self.distributed_type = DistributedType.MINDFORMERS
-        elif accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+        elif accelerate_distributed_type == DistributedType.MULTI_NPU:
             self.backend = "hccl"
-            self.distributed_type = DistributedType.MULTI_NPU_DP
+            self.distributed_type = DistributedType.MULTI_NPU

     @num_processes.setter
     def num_processes(self, value):
@@ -372,8 +372,8 @@ def __init__(self, mindformers_plugin=None, **kwargs):
         PartialState(**kwargs)
         self.__dict__.update(PartialState._shared_state)
         # set distributed_type
-        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
-            self.distributed_type = DistributedType.MULTI_NPU_DP
+        if accelerate_distributed_type == DistributedType.MULTI_NPU:
+            self.distributed_type = DistributedType.MULTI_NPU
         elif accelerate_distributed_type == DistributedType.MINDFORMERS:
             self.distributed_type = DistributedType.MINDFORMERS
             self.mindformers_plugin = mindformers_plugin
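Both hunks rely on the shared-state pattern behind PartialState: constructing PartialState(**kwargs) populates a class-level _shared_state dict, and AcceleratorState.__init__ copies it into its own __dict__, so the backend and distributed_type chosen in _prepare_backend are visible on both objects. A minimal, self-contained illustration of that pattern (names simplified; not the actual mindnlp classes):

class PartialStateSketch:
    _shared_state = {}                       # shared across every instance

    def __init__(self, **kwargs):
        self.__dict__ = self._shared_state   # all instances alias one dict
        self.backend = "hccl"
        self.distributed_type = "MULTI_NPU"

class AcceleratorStateSketch:
    def __init__(self):
        PartialStateSketch()                                # ensure shared state is populated
        self.__dict__.update(PartialStateSketch._shared_state)

state = AcceleratorStateSketch()
print(state.distributed_type)   # "MULTI_NPU", picked up from the shared state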

mindnlp/accelerate/utils/constants.py

Lines changed: 14 additions & 2 deletions
@@ -1,7 +1,18 @@
 """constants"""
 import os
+import mindspore
+import numpy
 from .dataclasses import DistributedType

+
+def _prepare_data_parallel_native_minspore():
+    # initialize data parallel hcc backend for data_loader and Trainer API
+    mindspore.set_context(mode=mindspore.GRAPH_MODE)
+    mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+    mindspore.communication.init()
+    random_seed = numpy.random.randint(10000)
+    mindspore.set_seed(random_seed)
+
 def detect_accelerate_distributed_type():
     """
     detect distributed_type
@@ -10,8 +21,9 @@ def detect_accelerate_distributed_type():
         _type_: According to the factors such as the available parallel software and hardware environment of the current system and the user-specified parallel scheme,
             the optimal parallel strategy is comprehensively decided in different situations.
     """
-    if os.environ.get("MULTI_NPU_DP", None) == "true":
-        return DistributedType.MULTI_NPU_DP
+    if os.environ.get("MULTI_NPU", None) == "true":
+        _prepare_data_parallel_native_minspore()
+        return DistributedType.MULTI_NPU
     if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
         return DistributedType.MINDFORMERS
     else:
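The helper that used to live on Accelerator is now called as a side effect of detection, and its seed handling changes: the old version passed the return value of numpy.random.seed() (which is None) to mindspore.set_seed, while the new one draws an explicit integer with numpy.random.randint(10000). A hedged usage sketch of the new switch, based only on this diff (launching the distributed processes themselves is outside its scope):

import os
os.environ["MULTI_NPU"] = "true"   # opt in to native MindSpore data parallel (renamed from MULTI_NPU_DP)

from mindnlp.accelerate.utils.constants import detect_accelerate_distributed_type
from mindnlp.accelerate.utils.dataclasses import DistributedType

dist_type = detect_accelerate_distributed_type()
# As a side effect, the call above now also enables GRAPH_MODE, sets the DATA_PARALLEL
# auto-parallel context, initializes HCCL via mindspore.communication.init(), and
# seeds MindSpore with a random integer.
assert dist_type == DistributedType.MULTI_NPU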

mindnlp/accelerate/utils/dataclasses.py

Lines changed: 2 additions & 2 deletions
@@ -18,10 +18,10 @@ class DistributedType(str, enum.Enum):
     Values:
         - **MINDFORMERS** -- Using mindformers
         - **NO** -- Not a distributed environment, just a single process.
-        - **MULTI_NPU_DP** -- Distributed data parallel on multiple NPUs.
+        - **MULTI_NPU** -- Distributed data parallel on multiple NPUs.
     """

-    MULTI_NPU_DP = "MULTI_NPU_DP"
+    MULTI_NPU = "MULTI_NPU"
     MINDFORMERS = "MINDFORMERS"
     NO = "NO"
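Because DistributedType subclasses str, its members compare equal to plain strings, so the value string is renamed together with the member name; otherwise comparisons against strings coming from configuration or the environment would silently stop matching. A small self-contained illustration (standard library only):

import enum

class DistributedType(str, enum.Enum):
    MULTI_NPU = "MULTI_NPU"
    MINDFORMERS = "MINDFORMERS"
    NO = "NO"

assert DistributedType.MULTI_NPU == "MULTI_NPU"                    # str enums compare equal to their values
assert DistributedType.MULTI_NPU != "MULTI_NPU_DP"                 # the old spelling no longer matches
assert DistributedType("MULTI_NPU") is DistributedType.MULTI_NPU   # lookup by value still works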

mindnlp/dataset/load.py

Lines changed: 1 addition & 1 deletion
@@ -335,7 +335,7 @@ def load_dataset(
         column_names = list(raw_ds.features.keys())
         source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \
             else TransferIterableDataset(raw_ds, column_names)
-        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+        if accelerate_distributed_type == DistributedType.MULTI_NPU:
             ms_ds = GeneratorDataset(source=source,
                                      column_names=column_names,
                                      shuffle=shuffle,
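The GeneratorDataset call is truncated in this hunk, so its remaining arguments are not shown. For context, a hypothetical sketch of how a MULTI_NPU branch typically shards a GeneratorDataset across data-parallel ranks in MindSpore; whether load_dataset actually passes num_shards/shard_id is not visible in this diff:

from mindspore.communication import get_rank, get_group_size
from mindspore.dataset import GeneratorDataset

def build_sharded_dataset(source, column_names, shuffle=True):
    # Hypothetical helper: one shard per NPU, selected by this process's rank.
    return GeneratorDataset(source=source,
                            column_names=column_names,
                            shuffle=shuffle,
                            num_shards=get_group_size(),
                            shard_id=get_rank())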
