Commit 59fab8b

feat: add _actual_distributed_type constant to decide parallel DistributedType, add data parallel of native mindspore to mindnlp.Trainer.base
1 parent 791e2df commit 59fab8b

File tree: 15 files changed, +110 -106 lines changed


mindnlp/accelerate/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,8 @@
     # DeepSpeedPlugin,
     # DistributedDataParallelKwargs,
     # FullyShardedDataParallelPlugin,
+    accelerate_distributed_type,
+    DistributedType,
     # GradScalerKwargs,
     # InitProcessGroupKwargs,
     # ProfileKwargs,
mindnlp/accelerate/accelerator.py
Lines changed: 5 additions & 4 deletions

@@ -14,7 +14,8 @@
     is_mindformers_available,
     wait_for_everyone
 )
-from ..utils import _actual_distributed_type, logging, DistributedType
+from .utils import DistributedType,accelerate_distributed_type
+from ..utils import logging
 
 if is_mindformers_available():
     from .utils import (
@@ -45,7 +46,7 @@ def __init__(
         # init mindformers_plugin from env variables
         if mindformers_plugin is None:
             mindformers_plugin = (
-                MindFormersPlugin() if _actual_distributed_type == DistributedType.MINDFORMERS else None
+                MindFormersPlugin() if accelerate_distributed_type == DistributedType.MINDFORMERS else None
             )
         else:
             os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"
@@ -104,10 +105,10 @@ def prepare(self, *args):
         """
         result = []
 
-        # Only support mindsormers and MULTI_NPU_DATA_PARALLEL now
+        # Only support mindsormers and MULTI_NPU_DP now
         if self.distributed_type == DistributedType.MINDFORMERS:
             result = self._prepare_mindformers(*args)
-        elif self.distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+        elif self.distributed_type == DistributedType.MULTI_NPU_DP:
             result = self._prepare_data_parallel_native_minspore(*args)
         return result
 
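The prepare() dispatch above is driven entirely by the module-level accelerate_distributed_type constant that this commit exports. A minimal illustrative sketch of how that selection looks from user code, using only the names added in this commit (the MULTI_NPU_DP variable must be set before mindnlp.accelerate is first imported):

import os
os.environ["MULTI_NPU_DP"] = "true"   # read once, at import time

from mindnlp.accelerate import accelerate_distributed_type, DistributedType

if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
    # Accelerator.prepare() routes to _prepare_data_parallel_native_minspore()
    print("native MindSpore data parallel")
elif accelerate_distributed_type == DistributedType.MINDFORMERS:
    # Accelerator.prepare() routes to _prepare_mindformers()
    print("MindFormers parallel")
else:
    print("single process, nothing to prepare")
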
mindnlp/accelerate/state.py
Lines changed: 7 additions & 7 deletions

@@ -13,7 +13,7 @@
 from .utils import (
     is_mindformers_available
 )
-from ..utils import _actual_distributed_type, DistributedType
+from ..accelerate.utils import accelerate_distributed_type, DistributedType
 
 SharedDict = dict
 
@@ -344,12 +344,12 @@ def print(self, *args, **kwargs):
 
     def _prepare_backend(self):
         # now mindformers and mindspore data parallel only
-        if _actual_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
+        if accelerate_distributed_type == DistributedType.MINDFORMERS and is_mindformers_available():
             self.backend = "hccl"
             self.distributed_type = DistributedType.MINDFORMERS
-        elif _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
+        elif accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
             self.backend = "hccl"
-            self.distributed_type = DistributedType.MULTI_NPU_DATA_PARALLEL
+            self.distributed_type = DistributedType.MULTI_NPU_DP
 
     @num_processes.setter
     def num_processes(self, value):
@@ -372,9 +372,9 @@ def __init__(self, mindformers_plugin=None, **kwargs):
         PartialState(**kwargs)
         self.__dict__.update(PartialState._shared_state)
         # set distributed_type
-        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
-            self.distributed_type = DistributedType.MULTI_NPU_DATA_PARALLEL
-        elif _actual_distributed_type == DistributedType.MINDFORMERS:
+        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            self.distributed_type = DistributedType.MULTI_NPU_DP
+        elif accelerate_distributed_type == DistributedType.MINDFORMERS:
             self.distributed_type = DistributedType.MINDFORMERS
             self.mindformers_plugin = mindformers_plugin
         else:
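Both branches resolve to the "hccl" backend, but _prepare_backend() only records that choice; the worker processes still have to be launched as a MindSpore distributed job and initialize the communication group themselves. A hedged sketch of the usual per-worker setup for the MULTI_NPU_DP case, using standard MindSpore APIs rather than code from this commit:

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication import init, get_rank, get_group_size

# bring up the HCCL collective-communication backend on this NPU worker
init("hccl")

# native data-parallel mode; gradients_mean matches the mean taken after AllReduce
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)

print(f"worker {get_rank()} of {get_group_size()} is ready")
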
mindnlp/accelerate/utils/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,7 @@
 """accelerate utils"""
+from .constants import accelerate_distributed_type
 from .dataclasses import (
+    DistributedType,
     MindFormersPlugin
 )
 from .environment import (
mindnlp/accelerate/utils/constants.py
Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+"""constants"""
+import os
+from .dataclasses import DistributedType
+
+def detect_accelerate_distributed_type():
+    """
+    detect distributed_type
+
+    Returns:
+        _type_: According to the factors such as the available parallel software and hardware environment of the current system and the user-specified parallel scheme,
+        the optimal parallel strategy is comprehensively decided in different situations.
+    """
+    if os.environ.get("MULTI_NPU_DP", None) == "true":
+        return DistributedType.MULTI_NPU_DP
+    if os.environ.get("ACCELERATE_USE_MINDFORMERS", "false") == "true":
+        return DistributedType.MINDFORMERS
+    else:
+        return DistributedType.NO
+
+accelerate_distributed_type = detect_accelerate_distributed_type()
+
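Because the MULTI_NPU_DP check runs first, it wins over ACCELERATE_USE_MINDFORMERS when both variables are set, and the result is frozen into accelerate_distributed_type at import time. A small illustrative check of that ordering, assuming only the module above:

import os

# with both flags set, the MULTI_NPU_DP branch is evaluated first and wins
os.environ["MULTI_NPU_DP"] = "true"
os.environ["ACCELERATE_USE_MINDFORMERS"] = "true"

from mindnlp.accelerate.utils import accelerate_distributed_type, DistributedType
from mindnlp.accelerate.utils.constants import detect_accelerate_distributed_type

assert detect_accelerate_distributed_type() == DistributedType.MULTI_NPU_DP
# the module-level constant was fixed when the module was first imported and
# does not change if the environment variables are modified afterwards
print(accelerate_distributed_type)
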
mindnlp/accelerate/utils/dataclasses.py
Lines changed: 15 additions & 0 deletions

@@ -11,6 +11,21 @@
 )
 
 
+class DistributedType(str, enum.Enum):
+    """
+    Represents a type of distributed environment.
+
+    Values:
+        - **MINDFORMERS** -- Using mindformers
+        - **NO** -- Not a distributed environment, just a single process.
+        - **MULTI_NPU_DP** -- Distributed data parallel on multiple NPUs.
+    """
+
+    MULTI_NPU_DP = "MULTI_NPU_DP"
+    MINDFORMERS = "MINDFORMERS"
+    NO = "NO"
+
+
 @dataclass
 class MindFormersPlugin:
     """
mindnlp/dataset/load.py
Lines changed: 7 additions & 6 deletions

@@ -23,8 +23,10 @@
 from datasets import Dataset, IterableDataset, Split, Features, \
     DownloadConfig, DownloadMode, VerificationMode, Version
 from mindnlp.configs import DEFAULT_ROOT
-from ..utils.constants import _actual_distributed_type
-from ..utils.dataclasses import DistributedType
+from mindspore.communication import get_rank, get_group_size
+from ..accelerate import DistributedType
+from ..accelerate.utils import accelerate_distributed_type
+
 
 class TransferIterableDataset():
     """TransferDataset for Huggingface Dataset."""
@@ -333,20 +335,19 @@ def load_dataset(
         column_names = list(raw_ds.features.keys())
         source = TransferDataset(raw_ds, column_names) if isinstance(raw_ds, Dataset) \
             else TransferIterableDataset(raw_ds, column_names)
-        ms_ds = ms_ds
-        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
-            from mindspore.communication import get_rank, get_group_size
+        if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
             ms_ds = GeneratorDataset(source=source,
                                      column_names=column_names,
                                      shuffle=shuffle,
                                      num_parallel_workers=num_proc if num_proc else 1,
                                      num_shards=get_group_size(), shard_id=get_rank())
+            datasets_dict[key] = ms_ds
         else:
             ms_ds = GeneratorDataset(source=source,
                                      column_names=column_names,
                                      shuffle=shuffle,
                                      num_parallel_workers=num_proc if num_proc else 1)
-        datasets_dict[key] = ms_ds
+            datasets_dict[key] = ms_ds
 
     if len(datasets_dict) == 1:
         return datasets_dict.popitem()[1]
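In the MULTI_NPU_DP branch each worker constructs its GeneratorDataset with num_shards and shard_id, so it only ever reads its own 1/N slice of the data. A standalone sketch of the same sharding pattern with plain MindSpore APIs; the source list below is a stand-in for the converted Hugging Face dataset, not code from load.py:

import mindspore.dataset as ds
from mindspore.communication import init, get_rank, get_group_size

init("hccl")  # assumes the script was launched as a MindSpore distributed job

# stand-in source; load_dataset() wraps the Hugging Face dataset the same way
source = [(i, i * i) for i in range(1000)]

sharded = ds.GeneratorDataset(source=source,
                              column_names=["x", "y"],
                              shuffle=True,
                              num_parallel_workers=1,
                              num_shards=get_group_size(),  # number of data-parallel workers
                              shard_id=get_rank())          # this worker's slice

print(sharded.get_dataset_size())  # roughly 1000 / get_group_size()
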
mindnlp/engine/trainer/base.py
Lines changed: 11 additions & 3 deletions

@@ -45,8 +45,8 @@
     WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
 from ...dataset import BaseMapFunction
 from ...utils import logging, find_labels, can_return_loss
-from ...utils.constants import _actual_distributed_type
-from ...utils.dataclasses import DistributedType
+from ...accelerate.utils import DistributedType
+from ...accelerate.utils import accelerate_distributed_type
 from ...utils.import_utils import is_safetensors_available
 from ...transformers.modeling_utils import PreTrainedModel
 from ...transformers.configuration_utils import PretrainedConfig
@@ -286,7 +286,7 @@ def __init__(
         # Internal variables to help with automatic batch size reduction
         self._train_batch_size = args.train_batch_size
         self._created_lr_scheduler = False
-        self.actual_distributed_type = _actual_distributed_type
+        self.actual_distributed_type = accelerate_distributed_type
 
     def _activate_neftune(self, model):
         r"""
@@ -1376,6 +1376,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tens
         inputs = self._prepare_inputs(inputs)
 
         def forward(inputs):
+            if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+                from mindspore.communication import get_group_size
+                import mindspore.ops as msops
+                rank_size = get_group_size()
+                for parameter in model.parameters():
+                    all_reduce_sum = msops.AllReduce(msops.ReduceOp.SUM)
+                    new_grads_mean = all_reduce_sum(parameter.grad) / rank_size
+                    parameter.grad = new_grads_mean
             return self.compute_loss(model, inputs)
 
         if getattr(self, 'grad_fn', None) is None or self.model_reload:
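The added branch uses msops.AllReduce to sum a tensor over every data-parallel worker and then divides by the worker count to obtain the mean. A self-contained sketch of that same sum-then-average pattern applied to the gradient tuple returned by mindspore.value_and_grad, which is the more common functional form in MindSpore; net, loss_fn and optimizer are placeholders, not objects from this commit:

import mindspore
import mindspore.ops as ops
from mindspore.communication import init, get_group_size

init("hccl")                                   # assumes an NPU/HCCL job
rank_size = get_group_size()
all_reduce_sum = ops.AllReduce(ops.ReduceOp.SUM)

def forward_fn(data, label):
    logits = net(data)                         # placeholder network
    return loss_fn(logits, label)              # placeholder loss

grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, label):
    loss, grads = grad_fn(data, label)
    # sum each gradient across workers, then divide by the worker count
    grads = tuple(all_reduce_sum(g) / rank_size for g in grads)
    optimizer(grads)                           # placeholder mindspore.nn optimizer
    return loss
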
mindnlp/engine/trainer/default_func.py
Lines changed: 2 additions & 7 deletions

@@ -19,8 +19,8 @@
 from mindspore.amp import all_finite
 
 from mindnlp.utils import ModelOutput
-from ...utils.constants import _actual_distributed_type
-from ...utils.dataclasses import DistributedType
+from ...accelerate.utils import DistributedType
+from ...accelerate.utils import accelerate_distributed_type
 
 def get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler):
     """get default forward function with loss function"""
@@ -66,9 +66,6 @@ def get_default_train_step_fn(forward_fn, optimizer, loss_scaler, check_gradient
     def default_run_step(labels, *args, **kwargs):
         """Core process of each step, including the forward propagation process and back propagation of data."""
         loss, grads = grad_fn(labels, *args, **kwargs)
-        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
-            grads = nn.DistributedGradReducer(optimizer.parameters)
-            grad_reducer = nn.DistributedGradReducer(optimizer.parameters)
         loss = loss_scaler.unscale(loss)
         if check_gradients:
             is_finite = all_finite(grads)
@@ -83,8 +80,6 @@ def default_run_step_for_obj_net(*args, **kwargs):
     def default_run_step_for_obj_net(*args, **kwargs):
         """Core process of each step, including the forward propagation process and back propagation of data."""
         loss, grads = grad_fn(*args, **kwargs)
-        if _actual_distributed_type == DistributedType.MULTI_NPU_DATA_PARALLEL:
-            grads = nn.DistributedGradReducer(optimizer.parameters)
        loss = loss_scaler.unscale(loss)
         if check_gradients:
             is_finite = all_finite(grads)
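The removed lines constructed nn.DistributedGradReducer but bound the reducer object itself to grads instead of calling it on the gradient tuple, which is why the data-parallel reduction now happens in the Trainer instead. For comparison, a hedged sketch of the conventional reducer pattern, reusing the optimizer and grad_fn already in scope inside get_default_train_step_fn; this is an illustration, not code from the repository:

from mindspore import nn

# build the reducer once, outside the per-step function
grad_reducer = nn.DistributedGradReducer(optimizer.parameters)

def run_step_with_reducer(labels, *args, **kwargs):
    loss, grads = grad_fn(labels, *args, **kwargs)
    grads = grad_reducer(grads)   # all-reduce the gradient tuple across workers
    optimizer(grads)
    return loss
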
mindnlp/utils/__init__.py
Lines changed: 0 additions & 2 deletions

@@ -21,8 +21,6 @@
 from .download import *
 from .compatibility import *
 from .chat_template_utils import *
-from .dataclasses import DistributedType
-from .constants import _actual_distributed_type
 from .import_utils import requires_backends, is_mindspore_available, OptionalDependencyNotAvailable, is_sentencepiece_available, \
     is_tokenizers_available, direct_transformers_import, is_protobuf_available, is_safetensors_available, \
     is_cython_available, is_pretty_midi_available, is_essentia_available, is_librosa_available, is_scipy_available, is_pyctcdecode_available, \