90 | 90 | TrainerControl,
91 | 91 | TrainerState,
92 | 92 | )
93 | | -from ..utils import _get_learning_rate
94 | 93 |
95 | 94 |
96 | 95 | logger = logging.get_logger(__name__)
@@ -126,7 +125,7 @@ def _is_peft_model(model):
126 | 125 | class Trainer:
127 | 126 | """
128 | 127 | Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers.
129 | | - """
| 128 | + """
130 | 129 | def __init__(
131 | 130 | self,
132 | 131 | model: Union[PreTrainedModel, nn.Module] = None,
@@ -288,6 +287,29 @@ def __init__(
288 | 287 | self._created_lr_scheduler = False
289 | 288 | self.actual_distributed_type = accelerate_distributed_type
290 | 289 |
| 290 | +
| 291 | + def _get_learning_rate(self):
| 292 | + r"""
| 293 | + This function retrieves the learning rate used by the optimizer.
| 294 | +
| 295 | + Args:
| 296 | + self: An instance of the class containing the optimizer and learning rate scheduler.
| 297 | +
| 298 | + Returns:
| 299 | + The learning rate value (float) used by the optimizer.
| 300 | +
| 301 | + Raises:
| 302 | + None.
| 303 | + """
| 304 | + if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
| 305 | + last_lr = self.optimizer.param_groups[0]["lr"]
| 306 | + else:
| 307 | + last_lr = self.lr_scheduler.get_last_lr()[0]
| 308 | + if ops.is_tensor(last_lr):
| 309 | + last_lr = last_lr.item()
| 310 | + return last_lr
| 311 | +
| 312 | +
291 | 313 | def _activate_neftune(self, model):
292 | 314 | r"""
293 | 315 | Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
@@ -1136,6 +1158,7 @@ def _inner_training_loop(
1136 | 1158 | model.parameters(),
1137 | 1159 | args.max_grad_norm,
1138 | 1160 | )
| 1161 | +
1139 | 1162 | # Optimizer step
1140 | 1163 | self.optimizer.step()
1141 | 1164 |
@@ -1376,7 +1399,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tens
1376 | 1399 | inputs = self._prepare_inputs(inputs)
1377 | 1400 |
1378 | 1401 | def forward(inputs):
1379 | | - if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
| 1402 | + if accelerate_distributed_type == DistributedType.MULTI_NPU:
1380 | 1403 | from mindspore.communication import get_group_size
1381 | 1404 | import mindspore.ops as msops
1382 | 1405 | rank_size = get_group_size()
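The main change above moves `_get_learning_rate` onto the Trainer itself (new lines 291–310) instead of importing it from `..utils`: when the scheduler is `ReduceLROnPlateau` (which typically does not expose `get_last_lr()`), the current value is read from `optimizer.param_groups`; every other scheduler reports it via `get_last_lr()[0]`, and a scalar tensor is unwrapped to a Python float. Below is a minimal, framework-agnostic sketch of that branch logic; `DummyOptimizer`, `StepDecayScheduler`, and `PlateauScheduler` are hypothetical stand-ins for illustration, not MindSpore classes.

```python
# Minimal sketch of the branch logic in the new Trainer._get_learning_rate.
# DummyOptimizer, StepDecayScheduler and PlateauScheduler are hypothetical
# stand-ins, not MindSpore classes.

class DummyOptimizer:
    def __init__(self, lr):
        # Mirrors the param_groups layout the method reads from.
        self.param_groups = [{"lr": lr}]

class StepDecayScheduler:
    """Stands in for schedulers that expose get_last_lr()."""
    def __init__(self, optimizer):
        self.optimizer = optimizer
    def get_last_lr(self):
        return [self.optimizer.param_groups[0]["lr"]]

class PlateauScheduler:
    """Stands in for ReduceLROnPlateau, which keeps the LR on the optimizer."""
    def __init__(self, optimizer):
        self.optimizer = optimizer

def get_learning_rate(optimizer, scheduler):
    # Same branch structure as the method added in this commit.
    if isinstance(scheduler, PlateauScheduler):
        return optimizer.param_groups[0]["lr"]
    return scheduler.get_last_lr()[0]

opt = DummyOptimizer(lr=5e-5)
print(get_learning_rate(opt, StepDecayScheduler(opt)))  # 5e-05
print(get_learning_rate(opt, PlateauScheduler(opt)))    # 5e-05
```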
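The last hunk renames the `training_step` check from `DistributedType.MULTI_NPU_DP` to `DistributedType.MULTI_NPU` and fetches the group size inside the `forward` closure. The hunk cuts off before showing what `rank_size` is used for; a common pattern in that situation, sketched below purely as an assumption, is to scale the per-device loss by the number of data-parallel workers so that a summed gradient reduction averages out across devices. The model, loss function, and the division here are illustrative rather than the file's exact code, and the snippet only runs under a distributed MindSpore launch (e.g. msrun) where `init()` can join a communication group.

```python
# Sketch of the data-parallel loss-scaling pattern the hunk appears to set up.
# Only get_group_size() is visible in the diff; everything else is an assumed
# illustration. Requires a distributed launch so that init() succeeds.
from mindspore.communication import init, get_group_size

init()                        # join the NPU/GPU communication group
rank_size = get_group_size()  # number of data-parallel workers

def forward(model, loss_fn, inputs, labels):
    logits = model(inputs)
    loss = loss_fn(logits, labels)
    # Dividing by rank_size means a later SUM reduction over gradients
    # (or over the loss) yields an average across workers.
    return loss / rank_size
```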