
Commit 9a4b724

Merge pull request #1 from Tridu33/openmind
Openmind to master
2 parents ef64a3b + 59fab8b commit 9a4b724

File tree

11 files changed: +124 −18 lines changed


mindnlp/accelerate/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,8 +5,9 @@
     # DDPCommunicationHookType,
     # DeepSpeedPlugin,
     # DistributedDataParallelKwargs,
-    # DistributedType,
     # FullyShardedDataParallelPlugin,
+    accelerate_distributed_type,
+    DistributedType,
     # GradScalerKwargs,
     # InitProcessGroupKwargs,
     # ProfileKwargs,

mindnlp/accelerate/accelerator.py

Lines changed: 15 additions & 6 deletions
@@ -1,19 +1,20 @@
 """accelerate"""
 import os
+import mindspore
+import numpy
+
 from contextlib import contextmanager
 from typing import Optional
-
-import mindspore
 from mindspore import nn
 from mindspore.communication import init

 from .state import AcceleratorState
 from .utils import (
-    DistributedType,
     MindFormersPlugin,
     is_mindformers_available,
     wait_for_everyone
 )
+from .utils import DistributedType,accelerate_distributed_type
 from ..utils import logging

 if is_mindformers_available():
@@ -45,7 +46,7 @@ def __init__(
         # init mindformers_plugin from env variables
         if mindformers_plugin is None:
             mindformers_plugin = (
-                MindFormersPlugin() if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true" else None
+                MindFormersPlugin() if accelerate_distributed_type == DistributedType.MINDFORMERS else None
             )
         else:
             os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"
@@ -104,12 +105,20 @@ def prepare(self, *args):
         """
         result = []

-        # Only support mindsormers now
+        # Only support mindsormers and MULTI_NPU_DP now
         if self.distributed_type == DistributedType.MINDFORMERS:
             result = self._prepare_mindformers(*args)
-
+        elif self.distributed_type == DistributedType.MULTI_NPU_DP:
+            result = self._prepare_data_parallel_native_minspore(*args)
         return result

+    def _prepare_data_parallel_native_minspore(self, *args):
+        # initialize data parallel for native mindspore
+        mindspore.set_context(mode=mindspore.GRAPH_MODE)
+        mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        mindspore.communication.init()
+        mindspore.set_seed(numpy.random.seed())
+
     def _prepare_mindformers(self, *args):
         mindformers_plugin = self.state.mindformers_plugin
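
The new MULTI_NPU_DP branch of Accelerator.prepare delegates to _prepare_data_parallel_native_minspore, which configures native MindSpore data parallelism. A minimal standalone sketch of that setup (not part of this commit), assuming an NPU environment with HCCL available and a distributed launcher such as msrun or mpirun:

import numpy
import mindspore
from mindspore.communication import init

# Graph mode plus a data-parallel auto-parallel context, mirroring the new method.
mindspore.set_context(mode=mindspore.GRAPH_MODE)
mindspore.set_auto_parallel_context(
    parallel_mode=mindspore.ParallelMode.DATA_PARALLEL,
    gradients_mean=True,   # average, rather than sum, gradients across devices
)
init()                     # join the HCCL communication group created by the launcher

# The commit seeds with mindspore.set_seed(numpy.random.seed()); note that
# numpy.random.seed() returns None, so this sketch uses a fixed integer instead.
mindspore.set_seed(42)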

mindnlp/accelerate/state.py

Lines changed: 15 additions & 6 deletions
@@ -4,14 +4,16 @@
 from contextlib import contextmanager
 from typing import Callable, Any
 from mindspore import communication
+
 try:
     from mindspore.communication.comm_func import barrier
 except:
     barrier = None

 from .utils import (
-    DistributedType, is_mindformers_available
+    is_mindformers_available
 )
+from ..accelerate.utils import accelerate_distributed_type, DistributedType

 SharedDict = dict

@@ -341,11 +343,14 @@ def print(self, *args, **kwargs):
             print(*args, **kwargs)

     def _prepare_backend(self):
-        # now mindformers only
-        if is_mindformers_available():
+        # now mindformers and mindspore data parallel only
+        if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
             self.backend = "hccl"
             self.distributed_type = DistributedType.MINDFORMERS
-
+        elif accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            self.backend = "hccl"
+            self.distributed_type = DistributedType.MULTI_NPU_DP
+
     @num_processes.setter
     def num_processes(self, value):
         self._num_processes = value
@@ -366,10 +371,14 @@ def __init__(self, mindformers_plugin=None, **kwargs):
         if PartialState._shared_state:
             PartialState(**kwargs)
             self.__dict__.update(PartialState._shared_state)
-
-        if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        # set distributed_type
+        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            self.distributed_type = DistributedType.MULTI_NPU_DP
+        elif accelerate_distributed_type == DistributedType.MINDFORMERS:
             self.distributed_type = DistributedType.MINDFORMERS
             self.mindformers_plugin = mindformers_plugin
+        else:
+            self.distributed_type = DistributedType.NO

         PartialState._shared_state["distributed_type"] = self.distributed_type

mindnlp/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 """accelerate utils"""
+from .constants import accelerate_distributed_type
 from .dataclasses import (
     DistributedType,
     MindFormersPlugin

mindnlp/accelerate/utils/constants.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+"""constants"""
+import os
+from .dataclasses import DistributedType
+
+def detect_accelerate_distributed_type():
+    """
+    detect distributed_type
+
+    Returns:
+        _type_: According to the factors such as the available parallel software and hardware environment of the current system and the user-specified parallel scheme,
+        the optimal parallel strategy is comprehensively decided in different situations.
+    """
+    if os.environ.get("MULTI_NPU_DP", None) == "true":
+        return DistributedType.MULTI_NPU_DP
+    if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        return DistributedType.MINDFORMERS
+    else:
+        return DistributedType.NO
+
+accelerate_distributed_type = detect_accelerate_distributed_type()
+
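
A small usage sketch (not from the commit): detection runs at import time, so the environment variables must be set before mindnlp.accelerate.utils.constants is first imported, or the module must be reloaded as below. MULTI_NPU_DP takes precedence over ACCELERATE_USE_MINDFORMERS.

import os
import importlib

# Only the exact string "true" enables a mode; anything else falls through.
os.environ["MULTI_NPU_DP"] = "true"

from mindnlp.accelerate.utils import constants
importlib.reload(constants)  # re-run detect_accelerate_distributed_type() with the new env

print(constants.accelerate_distributed_type)  # expected: DistributedType.MULTI_NPU_DP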

mindnlp/accelerate/utils/dataclasses.py

Lines changed: 3 additions & 0 deletions
@@ -17,8 +17,11 @@ class DistributedType(str, enum.Enum):

     Values:
         - **MINDFORMERS** -- Using mindformers
+        - **NO** -- Not a distributed environment, just a single process.
+        - **MULTI_NPU_DP** -- Distributed data parallel on multiple NPUs.
     """

+    MULTI_NPU_DP = "MULTI_NPU_DP"
     MINDFORMERS = "MINDFORMERS"
     NO = "NO"
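
Since DistributedType mixes in str, members also compare equal to their string values, so the new entry can be matched either by enum member or by plain string; a quick sketch:

from mindnlp.accelerate.utils import DistributedType

assert DistributedType.MULTI_NPU_DP == "MULTI_NPU_DP"               # str mixin: equal to its value
assert DistributedType("MULTI_NPU_DP") is DistributedType.MULTI_NPU_DP
assert DistributedType.NO != DistributedType.MULTI_NPU_DP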

mindnlp/dataset/load.py

Lines changed: 14 additions & 3 deletions
@@ -23,6 +23,10 @@
 from datasets import Dataset, IterableDataset, Split, Features, \
     DownloadConfig, DownloadMode, VerificationMode, Version
 from mindnlp.configs import DEFAULT_ROOT
+from mindspore.communication import get_rank, get_group_size
+from ..accelerate import DistributedType
+from ..accelerate.utils import accelerate_distributed_type
+

 class TransferIterableDataset():
     """TransferDataset for Huggingface Dataset."""
@@ -331,12 +335,19 @@ def load_dataset(
         column_names = list(raw_ds.features.keys())
         source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \
             else TransferIterableDataset(raw_ds, column_names)
-        ms_ds = GeneratorDataset(
-            source=source,
+        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            ms_ds = GeneratorDataset(source=source,
+                                     column_names=column_names,
+                                     shuffle=shuffle,
+                                     num_parallel_workers=num_proc if num_proc else 1,
+                                     num_shards=get_group_size(), shard_id=get_rank())
+            datasets_dict[key] = ms_ds
+        else:
+            ms_ds = GeneratorDataset(source=source,
             column_names=column_names,
             shuffle=shuffle,
             num_parallel_workers=num_proc if num_proc else 1)
-        datasets_dict[key] = ms_ds
+            datasets_dict[key] = ms_ds

     if len(datasets_dict) == 1:
         return datasets_dict.popitem()[1]
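
Under MULTI_NPU_DP the dataset is sharded across devices through num_shards/shard_id, so each rank reads a distinct slice. A standalone sketch of that sharding with a toy random-access source (the shard ids are hard-coded here so it runs in a single process; the commit obtains them from get_group_size()/get_rank()):

import numpy as np
from mindspore.dataset import GeneratorDataset

class ToySource:
    """Random-access source of 8 single-integer rows."""
    def __init__(self):
        self.data = [np.array([i], dtype=np.int32) for i in range(8)]
    def __getitem__(self, index):
        return (self.data[index],)
    def __len__(self):
        return len(self.data)

# Pretend to be rank 0 of a 2-device job; rank 1 would pass shard_id=1 and
# receive the complementary samples.
ds = GeneratorDataset(ToySource(), column_names=["idx"], shuffle=False,
                      num_shards=2, shard_id=0)
print([row["idx"].item() for row in ds.create_dict_iterator(output_numpy=True)])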

mindnlp/engine/trainer/base.py

Lines changed: 12 additions & 1 deletion
@@ -45,6 +45,8 @@
     WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
 from ...dataset import BaseMapFunction
 from ...utils import logging, find_labels, can_return_loss
+from ...accelerate.utils import DistributedType
+from ...accelerate.utils import accelerate_distributed_type
 from ...utils.import_utils import is_safetensors_available
 from ...transformers.modeling_utils import PreTrainedModel
 from ...transformers.configuration_utils import PretrainedConfig
@@ -88,6 +90,7 @@
     TrainerControl,
     TrainerState,
 )
+from ..utils import _get_learning_rate


 logger = logging.get_logger(__name__)
@@ -124,7 +127,6 @@ class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers.
     """
-    from ..utils import _get_learning_rate
     def __init__(
         self,
         model: Union[PreTrainedModel, nn.Module] = None,
@@ -284,6 +286,7 @@ def __init__(
         # Internal variables to help with automatic batch size reduction
         self._train_batch_size = args.train_batch_size
         self._created_lr_scheduler = False
+        self.actual_distributed_type = accelerate_distributed_type

     def _activate_neftune(self, model):
         r"""
@@ -1373,6 +1376,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tens
         inputs = self._prepare_inputs(inputs)

         def forward(inputs):
+            if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+                from mindspore.communication import get_group_size
+                import mindspore.ops as msops
+                rank_size = get_group_size()
+                for parameter in model.parameters():
+                    all_reduce_sum = msops.AllReduce(msops.ReduceOp.SUM)
+                    new_grads_mean = all_reduce_sum(parameter.grad) / rank_size
+                    parameter.grad = new_grads_mean
             return self.compute_loss(model, inputs)

         if getattr(self, 'grad_fn', None) is None or self.model_reload:
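
The new branch in training_step averages parameter gradients across ranks with an AllReduce-sum divided by the device count. For comparison, MindSpore also provides nn.DistributedGradReducer, which performs the same mean reduction over the gradient tuple produced by value_and_grad; a hedged sketch of that alternative, not the commit's code, where net, loss_fn and optimizer are assumed to exist and init() to have been called with gradients_mean=True in the parallel context:

from mindspore import nn, value_and_grad

def build_data_parallel_step(net, loss_fn, optimizer):
    """Return a training-step closure that averages gradients across devices."""
    def forward(data, label):
        return loss_fn(net(data), label)

    grad_fn = value_and_grad(forward, None, optimizer.parameters)
    reducer = nn.DistributedGradReducer(optimizer.parameters)  # AllReduce over the grad tuple

    def step(data, label):
        loss, grads = grad_fn(data, label)
        grads = reducer(grads)   # reduce gradients across ranks (honors gradients_mean)
        optimizer(grads)         # apply the reduced gradients
        return loss

    return step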

mindnlp/engine/trainer/default_func.py

Lines changed: 3 additions & 1 deletion
@@ -15,10 +15,12 @@
 """
 utils for trainer.
 """
-from mindspore import ops, value_and_grad
+from mindspore import nn, ops, value_and_grad
 from mindspore.amp import all_finite

 from mindnlp.utils import ModelOutput
+from ...accelerate.utils import DistributedType
+from ...accelerate.utils import accelerate_distributed_type

 def get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler):
     """get default forward function with loss function"""
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+
+
+def test_AllReduce_mean():
+    import numpy as np
+    from mindspore.communication import init, get_rank, get_group_size
+    import mindspore as ms
+    import mindspore.nn as nn
+    import mindspore.ops as ops
+
+    init()
+    rank_size = get_group_size()
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.all_reduce_sum = ops.AllReduce(ops.ReduceOp.SUM)
+
+        def construct(self, x):
+            new_grads_mean = self.all_reduce_sum(x) / rank_size
+            new_grad = new_grads_mean
+            return new_grad
+
+    rank_id_value = get_rank()  # Current NPU number 0,...,7
+    print('rank_id_value=', rank_id_value)
+    input_x = ms.Tensor(np.array([[rank_id_value]]).astype(np.float32))
+    print('input_x=', input_x)
+    net = Net()
+    output = net(input_x)
+    print("mean:", output)  # sum(0, rank_size) / rank_size
+
+
+
+
+if __name__ == '__main__':
+    test_AllReduce_mean()
+
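
For reference, each rank feeds its own rank id into the AllReduce, so the printed mean should equal the average of 0..rank_size-1; the test has to be started through a distributed launcher (for example msrun or mpirun) so that init() can build the communication group. A tiny sanity check of the expected value:

# Expected output of test_AllReduce_mean() on a job with rank_size devices,
# since each rank contributes its own rank id: sum(0..rank_size-1) / rank_size.
rank_size = 8
expected_mean = sum(range(rank_size)) / rank_size
print(expected_mean)  # 3.5 on an 8-NPU job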
