
Commit 67e7bd3

refactor: move class function inside Trainer
1 parent 59fab8b commit 67e7bd3

File tree

2 files changed: +26 -25 lines changed


mindnlp/engine/trainer/base.py

Lines changed: 26 additions & 3 deletions
@@ -90,7 +90,6 @@
     TrainerControl,
     TrainerState,
 )
-from ..utils import _get_learning_rate


 logger = logging.get_logger(__name__)
@@ -126,7 +125,7 @@ def _is_peft_model(model):
 class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for MindSpore, optimized for 🤗 Transformers.
-    """
+    """
     def __init__(
         self,
         model: Union[PreTrainedModel, nn.Module] = None,
@@ -288,6 +287,29 @@ def __init__(
         self._created_lr_scheduler = False
         self.actual_distributed_type = accelerate_distributed_type

+
+    def _get_learning_rate(self):
+        r"""
+        This function retrieves the learning rate used by the optimizer.
+
+        Args:
+            self: An instance of the class containing the optimizer and learning rate scheduler.
+
+        Returns:
+            The learning rate value (float) used by the optimizer.
+
+        Raises:
+            None.
+        """
+        if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+            last_lr = self.optimizer.param_groups[0]["lr"]
+        else:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+        if ops.is_tensor(last_lr):
+            last_lr = last_lr.item()
+        return last_lr
+
+
     def _activate_neftune(self, model):
         r"""
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
@@ -1136,6 +1158,7 @@ def _inner_training_loop(
                         model.parameters(),
                         args.max_grad_norm,
                     )
+
                 # Optimizer step
                 self.optimizer.step()

@@ -1376,7 +1399,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[mindspore.Tensor, Any]]):
         inputs = self._prepare_inputs(inputs)

         def forward(inputs):
-            if accelerate_distributed_type == DistributedType.MULTI_NPU_DP:
+            if accelerate_distributed_type == DistributedType.MULTI_NPU:
                 from mindspore.communication import get_group_size
                 import mindspore.ops as msops
                 rank_size = get_group_size()
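
The net effect of the base.py changes above is that the learning-rate helper becomes a bound method, so call sites inside the Trainer can use self._get_learning_rate() instead of the helper previously imported from ..utils. A minimal sketch of what a logging call site could look like after this commit; the _maybe_log_save_evaluate name and the log structure are illustrative assumptions, not part of this diff:

    # Sketch only: shows the moved helper being called as a bound method on Trainer.
    # The method name and log contents below are assumptions, not taken from this diff.
    def _maybe_log_save_evaluate(self, tr_loss):
        logs = {}
        # Before this commit: from ..utils import _get_learning_rate; _get_learning_rate(self)
        # After this commit: the helper lives on the Trainer instance itself.
        logs["learning_rate"] = self._get_learning_rate()
        logs["loss"] = float(tr_loss)
        self.log(logs)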

mindnlp/engine/utils.py

Lines changed: 0 additions & 22 deletions
@@ -505,28 +505,6 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
     result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
     return result

-def _get_learning_rate(self):
-    r"""
-    This function retrieves the learning rate used by the optimizer.
-
-    Args:
-        self: An instance of the class containing the optimizer and learning rate scheduler.
-
-    Returns:
-        The learning rate value (float) used by the optimizer.
-
-    Raises:
-        None.
-    """
-    if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
-        last_lr = self.optimizer.param_groups[0]["lr"]
-    else:
-        last_lr = self.lr_scheduler.get_last_lr()[0]
-    if ops.is_tensor(last_lr):
-        last_lr = last_lr.item()
-    return last_lr
-
-
 def find_batch_size(tensors):
     """
     Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
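
Because the module-level copy above is deleted, code that imported the helper directly from mindnlp.engine.utils would no longer find it there. The helper is private, so external imports are unlikely; still, a hedged before/after sketch of such a call site, assuming a trainer variable holding a Trainer instance:

    # Before this commit (module-level helper, removed from mindnlp/engine/utils.py):
    #   from mindnlp.engine.utils import _get_learning_rate
    #   lr = _get_learning_rate(trainer)
    # After this commit (bound method on Trainer, added in mindnlp/engine/trainer/base.py):
    lr = trainer._get_learning_rate()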

0 commit comments