"""
Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""

import logging

from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
from transformers.training_args import TrainingArguments

from llmcompressor import initialize
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe

from ..base import BasePlugin
from .args import LLMCompressorArgs  # pylint: disable=unused-import  # noqa: F401

LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")


class SFTCallbackHandler(TrainerCallback):
    """
    Transformers trainer callback for Sparse Finetuning.
    Maintains sparsity patterns during training by applying masks after optimization
    steps, so that optimizer updates to pruned (zero) weights are canceled out.
    """

    def __init__(self, trainer: object, recipe: object):
        """
        Initialize the callback handler.

        Args:
            trainer (object): The trainer instance.
            recipe (object): The sparse finetuning recipe to be applied.
        """
        super().__init__()
        self.trainer = trainer
        self.recipe = Recipe.model_validate(recipe)

        # Wrap the trainer's loss computation so every batch loss is reported
        # to the active CompressionSession (see compute_loss_wrapper below).
        if hasattr(self.trainer, "compute_loss"):
            self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)

    def on_train_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        """
        Event triggered at the beginning of training.
        Initializes the compression session with the trainer's model and optimizer,
        accommodating model changes due to wrappers like FSDP.
        """
        super().on_train_begin(args, state, control, **kwargs)
        initialize(
            model=self.trainer.model,
            optimizer=self.trainer.optimizer,
            start=state.epoch,
            recipe=self.recipe,
        )

    def on_step_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        """
        Event triggered at the beginning of a training step.
        Calls batch_start in the active CompressionSession.
        """
        super().on_step_begin(args, state, control, **kwargs)
        session_callbacks.batch_start()

    def on_step_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        """
        Event triggered at the end of a training step.
        Calls the optimizer pre-step, post-step, and batch_end callbacks; the
        post-step hook is where the sparsity masks are re-applied to the weights.
        """
        super().on_step_end(args, state, control, **kwargs)
        session_callbacks.optim_pre_step()
        session_callbacks.optim_post_step()
        session_callbacks.batch_end()

    def on_substep_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        """
        Event triggered at the end of a substep during gradient accumulation.
        Calls batch_end in the active CompressionSession.
        """
        super().on_substep_end(args, state, control, **kwargs)
        session_callbacks.batch_end()

    # def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    #     super().on_prediction_step(args, state, control, **kwargs)
    #     session_callbacks.loss_calculated()
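
# Illustrative sketch (not part of the plugin): the mask-reapplication idea that
# the callbacks above drive. llmcompressor's pruning modifiers handle this
# internally; the tensors below are purely hypothetical.
#
#     import torch
#
#     weight = torch.tensor([0.5, 0.0, -0.3, 0.0])  # pruned entries are exact zeros
#     mask = weight != 0                            # sparsity mask captured up front
#     update = torch.tensor([0.1, 0.2, 0.1, 0.2])   # a hypothetical optimizer step
#     weight = (weight - update) * mask             # post-step: pruned weights stay zero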
82
+
83
+ class SFTPlugin (BasePlugin ):
84
+ """
85
+ Plugin for Sparse Finetuning integration with Axolotl.
86
+ """
87
+
88
+ def get_input_args (self ) -> str :
89
+ """
90
+ Returns the input argument path for the plugin.
91
+ """
92
+ return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"
93
+
94
+ def add_callbacks_post_trainer (self , cfg , trainer ):
95
+ """
96
+ Adds Sparse Finetuning callback to the trainer.
97
+
98
+ Args:
99
+ cfg (object): Configuration object containing the recipe.
100
+ trainer (object): Trainer instance to which the callback is added.
101
+
102
+ Returns:
103
+ list: A list containing the Sparse Finetuning callback.
104
+ """
105
+ LOG .info ("Adding Sparse Finetuning callback to the trainer" )
106
+ callback = SFTCallbackHandler (
107
+ trainer = trainer ,
108
+ recipe = cfg .recipe ,
109
+ )
110
+ return [callback ]
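
# Example usage (illustrative; exact config keys depend on your axolotl and
# llmcompressor versions — the recipe below assumes llmcompressor's
# ConstantPruningModifier, which holds an existing sparsity pattern fixed):
#
#     plugins:
#       - axolotl.integrations.llmcompressor_sft.SFTPlugin
#     recipe:
#       finetuning_stage:
#         pruning_modifiers:
#           ConstantPruningModifier:
#             targets: ["re:.*q_proj.weight", "re:.*k_proj.weight"]
#             start: 0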


def compute_loss_wrapper(compute_loss_func):
    """
    Wraps the loss computation function to integrate with the active CompressionSession.

    Args:
        compute_loss_func (function): The original loss computation function.

    Returns:
        function: Wrapped function that reports the computed loss.
    """

    def wrapper(*args, **kwargs):
        loss = compute_loss_func(*args, **kwargs)
        # Report the raw loss to the CompressionSession before any reduction.
        session_callbacks.loss_calculated(loss=loss)
        # Take the mean across multiple GPUs; the parent trainer normally does
        # this outside compute_loss, so it is reproduced here after reporting.
        loss = loss.mean()
        return loss

    return wrapper
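
# Minimal sketch of the wrapper in isolation (the compute_loss stand-in below is
# hypothetical):
#
#     def compute_loss(model, inputs, return_outputs=False):
#         return some_per_device_loss(model, inputs)
#
#     trainer.compute_loss = compute_loss_wrapper(trainer.compute_loss)
#     # Each call now reports the batch loss to the CompressionSession, then
#     # reduces it with .mean() before returning it to the training loop.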