Commit c3c4e2d
feat: add data parallel of native mindspore to mindnlp.Trainer.base (#1852)
1 parent 77e97d4 commit c3c4e2d

File tree: 13 files changed, +266 −40 lines changed
bert_imdb_finetune_cpu_mindnlp_trainer_npus_same.py

Lines changed: 85 additions & 0 deletions
#!/usr/bin/env python
# coding: utf-8
"""
unset MULTI_NPU && python bert_imdb_finetune_cpu_mindnlp_trainer_npus_same.py
bash bert_imdb_finetune_npu_mindnlp_trainer.sh
"""
import mindspore
from mindspore.dataset import transforms
from mindnlp.engine import Trainer
from mindnlp.dataset import load_dataset

from mindnlp.accelerate.utils.constants import accelerate_distributed_type
from mindnlp.accelerate.utils.dataclasses import DistributedType

def main():
    """Fine-tune BERT on IMDB with the mindnlp Trainer (demo)."""
    imdb_ds = load_dataset('imdb', split=['train', 'test'])
    imdb_train = imdb_ds['train']
    imdb_train.get_dataset_size()

    from mindnlp.transformers import AutoTokenizer
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

    def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False):
        is_ascend = mindspore.get_context('device_target') == 'Ascend'
        def tokenize(text):
            if is_ascend:
                tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
            else:
                tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
            return tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']

        if shuffle:
            dataset = dataset.shuffle(batch_size)

        # map dataset
        dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'token_type_ids', 'attention_mask'])
        dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
        # batch dataset: static shapes on Ascend, padded batches elsewhere
        if is_ascend:
            dataset = dataset.batch(batch_size)
        else:
            dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                                 'token_type_ids': (None, 0),
                                                                 'attention_mask': (None, 0)})
        return dataset

    dataset_train = process_dataset(imdb_train, tokenizer, shuffle=True)

    next(dataset_train.create_tuple_iterator())

    from mindnlp.transformers import AutoModelForSequenceClassification

    # set bert config and define parameters for training
    model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)

    from mindnlp.engine import TrainingArguments

    training_args = TrainingArguments(
        output_dir="bert_imdb_finetune_cpu",
        save_strategy="epoch",
        logging_strategy="epoch",
        num_train_epochs=2.0,
        learning_rate=2e-5
    )
    training_args = training_args.set_optimizer(name="adamw", beta1=0.8)  # manually pick the optimizer, e.g. OptimizerNames.SGD

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset_train,
    )
    print("Start training")
    trainer.train()

if __name__ == '__main__':
    main()
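Whether this script runs single-process or data parallel is decided purely by the environment at import time; the script itself never branches on it. A quick way to confirm what was detected, using the same imports the script already makes:

from mindnlp.accelerate.utils.constants import accelerate_distributed_type
from mindnlp.accelerate.utils.dataclasses import DistributedType

# `unset MULTI_NPU && python ...`            -> DistributedType.NO
# under `msrun ...` with MULTI_NPU="true"    -> DistributedType.MULTI_NPU
print(accelerate_distributed_type)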
bert_imdb_finetune_npu_mindnlp_trainer.sh

Lines changed: 26 additions & 0 deletions
#!/bin/bash

echo "=========================================="
echo "Please run the script as: "
echo "bash bert_imdb_finetune_npu_mindnlp_trainer.sh"
echo "==========================================="

EXEC_PATH=$(pwd)
if [ ! -d "${EXEC_PATH}/data" ]; then
    if [ ! -f "${EXEC_PATH}/emotion_detection.tar.gz" ]; then
        wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
    fi
    tar xvf emotion_detection.tar.gz
fi
export DATA_PATH=${EXEC_PATH}/data/

rm -rf bert_imdb_finetune_cpu_mindnlp_trainer_npus_same
mkdir bert_imdb_finetune_cpu_mindnlp_trainer_npus_same
echo "start training"

export MULTI_NPU="true"
export ASCEND_SLOG_PRINT_TO_STDOUT=1

msrun --worker_num=2 --local_worker_num=2 --master_port=8121 \
    --log_dir=bert_imdb_finetune_cpu_mindnlp_trainer_npus_same --join=True \
    --cluster_time_out=10 bert_imdb_finetune_cpu_mindnlp_trainer_npus_same.py
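Note that MULTI_NPU must be exported before the workers start: each msrun worker imports mindnlp, and the detection in mindnlp/accelerate/utils/constants.py reads the variable (and initializes HCCL) at import time, so setting it afterwards has no effect.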

mindnlp/accelerate/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,8 +5,9 @@
 # DDPCommunicationHookType,
 # DeepSpeedPlugin,
 # DistributedDataParallelKwargs,
-# DistributedType,
 # FullyShardedDataParallelPlugin,
+accelerate_distributed_type,
+DistributedType,
 # GradScalerKwargs,
 # InitProcessGroupKwargs,
 # ProfileKwargs,

mindnlp/accelerate/accelerator.py

Lines changed: 5 additions & 4 deletions
@@ -9,11 +9,11 @@

 from .state import AcceleratorState
 from .utils import (
-    DistributedType,
     MindFormersPlugin,
     is_mindformers_available,
     wait_for_everyone
 )
+from .utils import DistributedType, accelerate_distributed_type
 from ..utils import logging

 if is_mindformers_available():
@@ -45,7 +45,7 @@ def __init__(
         # init mindformers_plugin from env variables
         if mindformers_plugin is None:
             mindformers_plugin = (
-                MindFormersPlugin() if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true" else None
+                MindFormersPlugin() if accelerate_distributed_type == DistributedType.MINDFORMERS else None
             )
         else:
             os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"
@@ -104,10 +104,11 @@ def prepare(self, *args):
         """
         result = []

-        # Only support mindformers now
+        # Only support mindformers and MULTI_NPU now
         if self.distributed_type == DistributedType.MINDFORMERS:
             result = self._prepare_mindformers(*args)
-
+        elif self.distributed_type == DistributedType.MULTI_NPU:
+            pass  # nothing to prepare for data parallel
         return result

     def _prepare_mindformers(self, *args):
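Note that the MULTI_NPU branch leaves result as the empty list, so under native data parallelism prepare(*args) returns [] rather than echoing its inputs: model and optimizer are used as-is, and gradient synchronization happens later, in the Trainer.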

mindnlp/accelerate/state.py

Lines changed: 14 additions & 6 deletions
@@ -1,17 +1,18 @@
 """accelerate state"""
-import os
 from functools import partial
 from contextlib import contextmanager
 from typing import Callable, Any
 from mindspore import communication
+
 try:
     from mindspore.communication.comm_func import barrier
 except:
     barrier = None

 from .utils import (
-    DistributedType, is_mindformers_available
+    is_mindformers_available
 )
+from ..accelerate.utils import accelerate_distributed_type, DistributedType

 SharedDict = dict

@@ -341,10 +342,13 @@ def print(self, *args, **kwargs):
         print(*args, **kwargs)

     def _prepare_backend(self):
-        # now mindformers only
-        if is_mindformers_available():
+        # now mindformers and mindspore data parallel only
+        if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
             self.backend = "hccl"
             self.distributed_type = DistributedType.MINDFORMERS
+        elif accelerate_distributed_type == DistributedType.MULTI_NPU:
+            self.backend = "hccl"
+            self.distributed_type = DistributedType.MULTI_NPU

     @num_processes.setter
     def num_processes(self, value):
@@ -366,10 +370,14 @@ def __init__(self, mindformers_plugin=None, **kwargs):
         if PartialState._shared_state:
             PartialState(**kwargs)
             self.__dict__.update(PartialState._shared_state)
-
-        if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        # set distributed_type
+        if accelerate_distributed_type == DistributedType.MULTI_NPU:
+            self.distributed_type = DistributedType.MULTI_NPU
+        elif accelerate_distributed_type == DistributedType.MINDFORMERS:
             self.distributed_type = DistributedType.MINDFORMERS
             self.mindformers_plugin = mindformers_plugin
+        else:
+            self.distributed_type = DistributedType.NO

         PartialState._shared_state["distributed_type"] = self.distributed_type
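Both branches of _prepare_backend select the "hccl" backend; the two modes differ only in the distributed_type they record in the shared state. A hedged check, assuming PartialState can be constructed with no arguments as the PartialState(**kwargs) call above suggests:

from mindnlp.accelerate.state import PartialState

state = PartialState()
print(state.distributed_type)  # DistributedType.MULTI_NPU under msrun + MULTI_NPU="true"
print(state.backend)           # "hccl" in either distributed mode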

mindnlp/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 """accelerate utils"""
+from .constants import accelerate_distributed_type
 from .dataclasses import (
     DistributedType,
     MindFormersPlugin

mindnlp/accelerate/utils/constants.py

Lines changed: 34 additions & 0 deletions
"""constants"""
import os
import mindspore
import numpy
from .dataclasses import DistributedType


_random_seed = numpy.random.randint(1000)


def _prepare_data_parallel_native_minspore():
    # initialize the data parallel HCCL backend for data_loader and the Trainer API
    mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
    mindspore.communication.init()
    mindspore.set_seed(_random_seed)


def detect_accelerate_distributed_type():
    """
    Detect the distributed type.

    Returns:
        DistributedType: the parallel strategy decided from the available
        parallel software/hardware environment and the user-specified
        parallel scheme.
    """
    if os.environ.get("MULTI_NPU", None) == "true":
        _prepare_data_parallel_native_minspore()
        return DistributedType.MULTI_NPU
    if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
        return DistributedType.MINDFORMERS
    return DistributedType.NO


accelerate_distributed_type = detect_accelerate_distributed_type()
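Because the last line above runs at module import, detection (and its HCCL side effects) happens the first time mindnlp is imported, so the controlling environment variables must be exported beforehand. A minimal sketch, assuming an Ascend environment launched via msrun:

import os
os.environ["MULTI_NPU"] = "true"  # must be set before the first mindnlp import

from mindnlp.accelerate.utils import accelerate_distributed_type
from mindnlp.accelerate.utils.dataclasses import DistributedType

# Detection also calls mindspore.communication.init() and switches the
# auto-parallel context to DATA_PARALLEL as a side effect.
assert accelerate_distributed_type == DistributedType.MULTI_NPU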

mindnlp/accelerate/utils/dataclasses.py

Lines changed: 3 additions & 0 deletions
@@ -17,8 +17,11 @@ class DistributedType(str, enum.Enum):

     Values:
         - **MINDFORMERS** -- Using mindformers
+        - **NO** -- Not a distributed environment, just a single process.
+        - **MULTI_NPU** -- Distributed data parallel on multiple NPUs.
     """

+    MULTI_NPU = "MULTI_NPU"
     MINDFORMERS = "MINDFORMERS"
     NO = "NO"

mindnlp/dataset/load.py

Lines changed: 15 additions & 4 deletions
@@ -18,11 +18,15 @@
 """
 import os
 from typing import Union, Optional, Dict, Sequence, Mapping
-from mindspore.dataset import GeneratorDataset
 from datasets import load_dataset as hf_load
 from datasets import Dataset, IterableDataset, Split, Features, \
     DownloadConfig, DownloadMode, VerificationMode, Version
+from mindspore.dataset import GeneratorDataset
+from mindspore.communication import get_rank, get_group_size
 from mindnlp.configs import DEFAULT_ROOT
+from ..accelerate import DistributedType
+from ..accelerate.utils import accelerate_distributed_type


 class TransferIterableDataset():
     """TransferDataset for Huggingface Dataset."""
@@ -331,12 +335,19 @@ def load_dataset(
         column_names = list(raw_ds.features.keys())
         source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \
             else TransferIterableDataset(raw_ds, column_names)
-        ms_ds = GeneratorDataset(
-            source=source,
-            column_names=column_names,
-            shuffle=shuffle,
-            num_parallel_workers=num_proc if num_proc else 1)
-        datasets_dict[key] = ms_ds
+        if accelerate_distributed_type == DistributedType.MULTI_NPU:
+            ms_ds = GeneratorDataset(source=source,
+                                     column_names=column_names,
+                                     shuffle=shuffle,
+                                     num_parallel_workers=num_proc if num_proc else 1,
+                                     num_shards=get_group_size(), shard_id=get_rank())
+            datasets_dict[key] = ms_ds
+        else:
+            ms_ds = GeneratorDataset(source=source,
+                                     column_names=column_names,
+                                     shuffle=shuffle,
+                                     num_parallel_workers=num_proc if num_proc else 1)
+            datasets_dict[key] = ms_ds

         if len(datasets_dict) == 1:
             return datasets_dict.popitem()[1]
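The num_shards=get_group_size() / shard_id=get_rank() pair is what gives each worker a disjoint 1/world_size slice of every split. The mechanism can be tried standalone, without HCCL, by fixing the shard parameters (illustrative values only):

from mindspore.dataset import GeneratorDataset

# shard 0 of 2: this process iterates only half of the eight samples
ds = GeneratorDataset(source=[(i,) for i in range(8)], column_names=["x"],
                      num_shards=2, shard_id=0, shuffle=False)
print([int(row[0]) for row in ds.create_tuple_iterator()])  # e.g. [0, 2, 4, 6]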

mindnlp/engine/trainer/base.py

Lines changed: 42 additions & 2 deletions
@@ -45,6 +45,8 @@
     WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
 from ...dataset import BaseMapFunction
 from ...utils import logging, find_labels, can_return_loss
+from ...accelerate.utils import DistributedType
+from ...accelerate.utils import accelerate_distributed_type
 from ...utils.import_utils import is_safetensors_available
 from ...transformers.modeling_utils import PreTrainedModel
 from ...transformers.configuration_utils import PretrainedConfig
@@ -124,7 +126,6 @@ class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers.
     """
-    from ..utils import _get_learning_rate
     def __init__(
         self,
         model: Union[PreTrainedModel, nn.Module] = None,
@@ -284,6 +285,30 @@ def __init__(
         # Internal variables to help with automatic batch size reduction
         self._train_batch_size = args.train_batch_size
         self._created_lr_scheduler = False
+        self.actual_distributed_type = accelerate_distributed_type
+
+    def _get_learning_rate(self):
+        r"""
+        Retrieve the learning rate currently used by the optimizer.
+
+        Returns:
+            The learning rate value (float) used by the optimizer.
+        """
+        if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+            last_lr = self.optimizer.param_groups[0]["lr"]
+        else:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+        if ops.is_tensor(last_lr):
+            last_lr = last_lr.item()
+        return last_lr

@@ -1133,6 +1158,7 @@ def _inner_training_loop(
             model.parameters(),
             args.max_grad_norm,
         )
+
         # Optimizer step
         self.optimizer.step()

@@ -1351,6 +1377,20 @@ def _prepare_inputs(self, inputs: Dict[str, Union[mindspore.Tensor, Any]]) -> Di
         return inputs

+    def update_gradient_by_distributed_type(self, model: nn.Module) -> None:
+        """Update gradients according to the distributed_type."""
+        if accelerate_distributed_type == DistributedType.NO:
+            return
+        if accelerate_distributed_type == DistributedType.MULTI_NPU:
+            from mindspore.communication import get_group_size
+            from mindspore.communication.comm_func import all_reduce
+            rank_size = get_group_size()
+            for parameter in model.parameters():
+                new_grads_mean = all_reduce(parameter.grad) / rank_size
+                parameter.grad = new_grads_mean
+
     def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tensor, Any]]) -> Tuple[List[mindspore.Tensor], mindspore.Tensor]:
         """
         Perform a training step on a batch of inputs.
@@ -1382,7 +1422,7 @@ def forward(inputs):
         self.grad_fn = value_and_grad(forward, weights, attach_grads=True)

         loss = self.grad_fn(inputs)
-
+        self.update_gradient_by_distributed_type(model)
         return loss / self.args.gradient_accumulation_steps

     def compute_loss(self, model, inputs, return_outputs=False):
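update_gradient_by_distributed_type is the core of the data parallel support: after each local backward pass, every parameter's gradient is summed across ranks with all_reduce and divided by the world size, so all workers apply the same mean gradient. A standalone sketch of the same arithmetic, assuming an initialized HCCL communication group (e.g. inside msrun):

from mindspore.communication import get_group_size
from mindspore.communication.comm_func import all_reduce

def average_gradients(parameters):
    # all_reduce defaults to SUM; dividing by the group size yields the mean,
    # which matches the gradients_mean=True data-parallel convention
    world_size = get_group_size()
    for param in parameters:
        param.grad = all_reduce(param.grad) / world_size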
