
Commit 791e2df

feat: add _actual_distributed_type constant to decide parallel DistributedType, add data parallel of native mindspore to mindnlp.Trainer.base

1 parent ef64a3b, commit 791e2df

File tree

12 files changed, +133 -31 lines

mindnlp/accelerate/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@
     # DDPCommunicationHookType,
     # DeepSpeedPlugin,
     # DistributedDataParallelKwargs,
-    # DistributedType,
     # FullyShardedDataParallelPlugin,
     # GradScalerKwargs,
     # InitProcessGroupKwargs,

mindnlp/accelerate/accelerator.py

Lines changed: 15 additions & 7 deletions

@@ -1,20 +1,20 @@
 """accelerate"""
 import os
+import mindspore
+import numpy
+
 from contextlib import contextmanager
 from typing import Optional
-
-import mindspore
 from mindspore import nn
 from mindspore.communication import init

 from .state import AcceleratorState
 from .utils import (
-    DistributedType,
     MindFormersPlugin,
     is_mindformers_available,
     wait_for_everyone
 )
-from ..utils import logging
+from ..utils import _actual_distributed_type, logging, DistributedType

 if is_mindformers_available():
     from .utils import (

@@ -45,7 +45,7 @@ def __init__(
         # init mindformers_plugin from env variables
         if mindformers_plugin is None:
             mindformers_plugin = (
-                MindFormersPlugin() if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true" else None
+                MindFormersPlugin() if _actual_distributed_type == DistributedType.MINDFORMERS else None
             )
         else:
             os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"

@@ -104,12 +104,20 @@ def prepare(self, *args):
         """
         result = []

-        # Only support mindsormers now
+        # Only support mindsormers and MULTI_NPU_DATA_PARALLEL now
         if self.distributed_type == DistributedType.MINDFORMERS:
            result = self._prepare_mindformers(*args)
-
+        elif self.distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+            result = self._prepare_data_parallel_native_minspore(*args)
        return result

+    def _prepare_data_parallel_native_minspore(self, *args):
+        # initialize data parallel for native mindspore
+        mindspore.set_context(mode=mindspore.GRAPH_MODE)
+        mindspore.set_auto_parallel_context(parallel_mode=mindspore.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        mindspore.communication.init()
+        mindspore.set_seed(numpy.random.seed())
+
     def _prepare_mindformers(self, *args):
        mindformers_plugin = self.state.mindformers_plugin

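For reference, the MULTI_NPU_DATA_PARALLEL branch of prepare() amounts to the standard native-MindSpore data-parallel setup. Below is a minimal standalone sketch of that setup; it is my own illustration rather than the commit's code, and it assumes the script is started with a distributed launcher (e.g. msrun) so that HCCL communication can be initialized.

# Hedged sketch: what the new MULTI_NPU_DATA_PARALLEL path boils down to when
# run standalone. Assumes a distributed launcher has started the process group.
import mindspore
from mindspore.communication import init, get_rank, get_group_size

mindspore.set_context(mode=mindspore.GRAPH_MODE)
mindspore.set_auto_parallel_context(
    parallel_mode=mindspore.ParallelMode.DATA_PARALLEL,
    gradients_mean=True,   # average gradients across devices
)
init()                     # initialize HCCL communication on Ascend
mindspore.set_seed(42)     # any fixed or per-run seed

print(f"rank {get_rank()} of {get_group_size()} ready for data parallel")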
mindnlp/accelerate/state.py

Lines changed: 15 additions & 6 deletions

@@ -4,14 +4,16 @@
 from contextlib import contextmanager
 from typing import Callable, Any
 from mindspore import communication
+
 try:
     from mindspore.communication.comm_func import barrier
 except:
     barrier = None

 from .utils import (
-    DistributedType, is_mindformers_available
+    is_mindformers_available
 )
+from ..utils import _actual_distributed_type, DistributedType

 SharedDict = dict

@@ -341,11 +343,14 @@ def print(self, *args, **kwargs):
            print(*args, **kwargs)

    def _prepare_backend(self):
-        # now mindformers only
-        if is_mindformers_available():
+        # now mindformers and mindspore data parallel only
+        if _actual_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
            self.backend = "hccl"
            self.distributed_type = DistributedType.MINDFORMERS
-
+        elif _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+            self.backend = "hccl"
+            self.distributed_type = DistributedType.MULTI_NPU_DATA_PARALLEL
+
    @num_processes.setter
    def num_processes(self, value):
        self._num_processes = value

@@ -366,10 +371,14 @@ def __init__(self, mindformers_plugin=None, **kwargs):
        if PartialState._shared_state:
            PartialState(**kwargs)
            self.__dict__.update(PartialState._shared_state)
-
-        if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        # set distributed_type
+        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+            self.distributed_type = DistributedType.MULTI_NPU_DATA_PARALLEL
+        elif _actual_distributed_type == DistributedType.MINDFORMERS:
            self.distributed_type = DistributedType.MINDFORMERS
            self.mindformers_plugin = mindformers_plugin
+        else:
+            self.distributed_type = DistributedType.NO

        PartialState._shared_state["distributed_type"] = self.distributed_type

mindnlp/accelerate/utils/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,6 +1,5 @@
 """accelerate utils"""
 from .dataclasses import (
-    DistributedType,
     MindFormersPlugin
 )
 from .environment import (

mindnlp/accelerate/utils/dataclasses.py

Lines changed: 0 additions & 12 deletions

@@ -11,18 +11,6 @@
 )


-class DistributedType(str, enum.Enum):
-    """
-    Represents a type of distributed environment.
-
-    Values:
-        - **MINDFORMERS** -- Using mindformers
-    """
-
-    MINDFORMERS = "MINDFORMERS"
-    NO = "NO"
-
-
 @dataclass
 class MindFormersPlugin:
     """

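The relocated DistributedType enum itself (now under mindnlp/utils/dataclasses.py) is not shown in this section of the diff. Judging from the removed definition above and the members referenced elsewhere in the commit, it presumably looks roughly like the following sketch; treat it as a reconstruction, not the committed file.

# Presumed shape of the relocated enum in mindnlp/utils/dataclasses.py,
# reconstructed from the removed definition plus the new member used elsewhere.
import enum

class DistributedType(str, enum.Enum):
    """Represents a type of distributed environment."""

    MINDFORMERS = "MINDFORMERS"
    MULTI_NPU_DATA_PARALLEL = "MULTI_NPU_DATA_PARALLEL"
    NO = "NO"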
mindnlp/dataset/load.py

Lines changed: 12 additions & 2 deletions

@@ -23,6 +23,8 @@
 from datasets import Dataset, IterableDataset, Split, Features, \
     DownloadConfig, DownloadMode, VerificationMode, Version
 from mindnlp.configs import DEFAULT_ROOT
+from ..utils.constants import _actual_distributed_type
+from ..utils.dataclasses import DistributedType

 class TransferIterableDataset():
     """TransferDataset for Huggingface Dataset."""

@@ -331,8 +333,16 @@ def load_dataset(
        column_names = list(raw_ds.features.keys())
        source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \
            else TransferIterableDataset(raw_ds, column_names)
-        ms_ds = GeneratorDataset(
-            source=source,
+        ms_ds = ms_ds
+        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+            from mindspore.communication import get_rank, get_group_size
+            ms_ds = GeneratorDataset(source=source,
+                                     column_names=column_names,
+                                     shuffle=shuffle,
+                                     num_parallel_workers=num_proc if num_proc else 1,
+                                     num_shards=get_group_size(), shard_id=get_rank())
+        else:
+            ms_ds = GeneratorDataset(source=source,
                                     column_names=column_names,
                                     shuffle=shuffle,
                                     num_parallel_workers=num_proc if num_proc else 1)

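The num_shards=get_group_size() and shard_id=get_rank() arguments are what make the loaded dataset data-parallel: each rank reads only its own 1/world-size slice of the samples. The snippet below is a self-contained illustration of the same GeneratorDataset sharding pattern, with toy data and fixed stand-in values so it runs on a single device; the column name and values are mine, not from the commit.

# Hedged illustration of the num_shards/shard_id pattern used above.
# Outside a real multi-NPU launch, get_group_size()/get_rank() are replaced
# by fixed values so the example stays runnable on one device.
import numpy as np
from mindspore.dataset import GeneratorDataset

def source():
    for i in range(8):
        yield (np.array(i, dtype=np.int32),)

group_size, rank = 2, 0            # stand-ins for get_group_size(), get_rank()
ds = GeneratorDataset(source, column_names=["value"],
                      shuffle=False,
                      num_shards=group_size, shard_id=rank)

print([int(row["value"]) for row in ds.create_dict_iterator(output_numpy=True)])
# rank 0 sees one half of the samples; rank 1 would see the other half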
mindnlp/engine/trainer/base.py

Lines changed: 4 additions & 1 deletion

@@ -45,6 +45,8 @@
     WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
 from ...dataset import BaseMapFunction
 from ...utils import logging, find_labels, can_return_loss
+from ...utils.constants import _actual_distributed_type
+from ...utils.dataclasses import DistributedType
 from ...utils.import_utils import is_safetensors_available
 from ...transformers.modeling_utils import PreTrainedModel
 from ...transformers.configuration_utils import PretrainedConfig

@@ -88,6 +90,7 @@
     TrainerControl,
     TrainerState,
 )
+from ..utils import _get_learning_rate


 logger = logging.get_logger(__name__)

@@ -124,7 +127,6 @@ class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers.
    """
-    from ..utils import _get_learning_rate
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,

@@ -284,6 +286,7 @@ def __init__(
        # Internal variables to help with automatic batch size reduction
        self._train_batch_size = args.train_batch_size
        self._created_lr_scheduler = False
+        self.actual_distributed_type = _actual_distributed_type

    def _activate_neftune(self, model):
        r"""

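With the new attribute, code that already holds a Trainer can branch on the decided parallel mode instead of re-reading environment variables. A small hedged helper as an illustration; the helper name is mine, not part of the commit.

# Sketch (not from the commit): branch on the parallel mode recorded on the
# Trainer. `trainer` is any already-constructed mindnlp Trainer instance.
from mindnlp.utils import DistributedType

def is_native_data_parallel(trainer) -> bool:
    """True when the Trainer was set up for native MindSpore data parallelism."""
    return trainer.actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL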
mindnlp/engine/trainer/default_func.py

Lines changed: 8 additions & 1 deletion

@@ -15,10 +15,12 @@
 """
 utils for trainer.
 """
-from mindspore import ops, value_and_grad
+from mindspore import nn, ops, value_and_grad
 from mindspore.amp import all_finite

 from mindnlp.utils import ModelOutput
+from ...utils.constants import _actual_distributed_type
+from ...utils.dataclasses import DistributedType

 def get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler):
     """get default forward function with loss function"""

@@ -64,6 +66,9 @@ def get_default_train_step_fn(forward_fn, optimizer, loss_scaler, check_gradient
    def default_run_step(labels, *args, **kwargs):
        """Core process of each step, including the forward propagation process and back propagation of data."""
        loss, grads = grad_fn(labels, *args, **kwargs)
+        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+            grads = nn.DistributedGradReducer(optimizer.parameters)
+            grad_reducer = nn.DistributedGradReducer(optimizer.parameters)
        loss = loss_scaler.unscale(loss)
        if check_gradients:
            is_finite = all_finite(grads)

@@ -78,6 +83,8 @@
    def default_run_step_for_obj_net(*args, **kwargs):
        """Core process of each step, including the forward propagation process and back propagation of data."""
        loss, grads = grad_fn(*args, **kwargs)
+        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+            grads = nn.DistributedGradReducer(optimizer.parameters)
        loss = loss_scaler.unscale(loss)
        if check_gradients:
            is_finite = all_finite(grads)

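For context, the usual MindSpore data-parallel pattern is to construct nn.DistributedGradReducer once from the optimizer's parameters and then apply it to each step's gradients, so they are averaged across devices before the optimizer update. Below is a minimal sketch of that pattern; it is my own illustration (not the commit's code) and assumes the DATA_PARALLEL context and communication have already been initialized.

# Hedged sketch of the standard DistributedGradReducer usage in MindSpore
# data parallelism: build the reducer once, then apply it to each step's grads.
from mindspore import nn, value_and_grad

def make_data_parallel_step(forward_fn, optimizer):
    grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
    grad_reducer = nn.DistributedGradReducer(optimizer.parameters)  # built once

    def run_step(*args, **kwargs):
        loss, grads = grad_fn(*args, **kwargs)
        grads = grad_reducer(grads)   # all-reduce / average gradients across devices
        optimizer(grads)              # update with the reduced gradients
        return loss

    return run_step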
mindnlp/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,8 @@
 from .download import *
 from .compatibility import *
 from .chat_template_utils import *
+from .dataclasses import DistributedType
+from .constants import _actual_distributed_type
 from .import_utils import requires_backends, is_mindspore_available, OptionalDependencyNotAvailable, is_sentencepiece_available, \
     is_tokenizers_available, direct_transformers_import, is_protobuf_available, is_safetensors_available, \
     is_cython_available, is_pretty_midi_available, is_essentia_available, is_librosa_available, is_scipy_available, is_pyctcdecode_available, \

mindnlp/utils/constants.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+"""global constants for mindnlp"""
+import os
+import psutil
+
+# from .devices import _is_Ascend_npu_avaliable, _avaliable_Ascend_npus_count  # TODO: if use acl
+from .dataclasses import DistributedType
+
+
+
+def detect_actual_distributed_type():
+    """
+    The actual_distributed_type is not necessarily the distributed_type the user asked for in the startup command, for example:
+    1. NPU is available and 'msrun' is specified ==> NPU
+    2. msrun requests parallel NPU, but no NPU is available ==> CPU execution (reasonable)
+    3. NPU is available, but the user starts with `python x.py` without specifying the device/port information needed to initialize communication, so the actual_distributed_type is CPU
+    etc.
+
+    Returns:
+        _type_: Based on factors such as the parallel software and hardware environment available on the current system
+        and the user-specified parallel scheme, the optimal parallel strategy decided for each situation.
+    """
+    if os.environ.get("MULTI_NPU_DATA_PARALLEL", None) == "true":
+        # TODO: for now the MULTI_NPU_DATA_PARALLEL environment variable is used as the switch; discuss whether it should replace DistributedType.MINDFORMERS as the fallback strategy
+        return DistributedType.MULTI_NPU_DATA_PARALLEL
+    if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        # TODO: the original logic defaulted to DistributedType.MINDFORMERS when the environment variable was not set; decide whether this branch should be removed
+        return DistributedType.MINDFORMERS
+    else:
+        return DistributedType.NO
+
+_actual_distributed_type = detect_actual_distributed_type()
+

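Usage note: _actual_distributed_type is computed once at module import time, so the MULTI_NPU_DATA_PARALLEL switch has to be set before mindnlp is imported. A short hedged sketch of the switch in use:

# Hedged usage sketch: select native MindSpore data parallelism via the new
# environment-variable switch. It must be exported before mindnlp is imported,
# because detect_actual_distributed_type() runs once at import time.
import os
os.environ["MULTI_NPU_DATA_PARALLEL"] = "true"

from mindnlp.utils import _actual_distributed_type, DistributedType

assert _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL
# Left unset (and with ACCELERATE_USE_MINDFORMERS also unset), the detector
# falls back to DistributedType.NO.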