
Commit 4438e5d

Add: SFTPlugin with llmcompressor
1 parent f0072f3 commit 4438e5d

File tree

4 files changed: +237 -0 lines changed


examples/llama-3/sft.yaml

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
# TODO: change to
# base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4

plugins:
  - axolotl.integrations.llmcompressor_sft.SFTPlugin

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
recipe:
  finetuning_stage:
    finetuning_modifiers:
      ConstantPruningModifier:
        targets: [
          're:.*q_proj.weight',
          're:.*k_proj.weight',
          're:.*v_proj.weight',
          're:.*o_proj.weight',
          're:.*gate_proj.weight',
          're:.*up_proj.weight',
          're:.*down_proj.weight',
        ]
        start: 0
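The recipe block at the end of this config is handed to llmcompressor by the plugin below. The following is a rough sketch of that hand-off (assuming llmcompressor is installed; the dict literal simply mirrors the YAML above, and the validation call is the same Recipe.model_validate the SFTCallbackHandler uses):

# Hedged sketch: validate the recipe block from sft.yaml the same way the
# plugin does. The dict mirrors the `recipe:` section above.
from llmcompressor.recipe import Recipe

recipe_dict = {
    "finetuning_stage": {
        "finetuning_modifiers": {
            "ConstantPruningModifier": {
                # keep the already-zero (pruned) weights pinned at zero
                "targets": [
                    "re:.*q_proj.weight",
                    "re:.*k_proj.weight",
                    "re:.*v_proj.weight",
                    "re:.*o_proj.weight",
                    "re:.*gate_proj.weight",
                    "re:.*up_proj.weight",
                    "re:.*down_proj.weight",
                ],
                "start": 0,
            }
        }
    }
}

recipe = Recipe.model_validate(recipe_dict)  # raises if the recipe is malformed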
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
"""
Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""

import logging
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
from transformers.training_args import TrainingArguments

from ..base import BasePlugin
from .args import LLMCompressorArgs  # pylint: disable=unused-import  # noqa: F401
from llmcompressor import initialize
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe

LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")

class SFTCallbackHandler(TrainerCallback):
    """
    Transformer trainer callback for Sparse Finetuning.
    Maintains sparsity patterns during training by applying masks after optimization steps.
    This ensures that optimizer updates to zero weights are canceled out.
    """

    def __init__(self, trainer: object, recipe: object):
        """
        Initialize the callback handler.

        Args:
            trainer (object): The trainer instance.
            recipe (object): The sparse finetuning recipe to be applied.
        """
        super().__init__()
        self.trainer = trainer
        self.recipe = Recipe.model_validate(recipe)

        if hasattr(self.trainer, "compute_loss"):
            self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event triggered at the beginning of training.
        Updates the session reference to the model, accommodating changes due to wrappers like FSDP.
        """
        super().on_train_begin(args, state, control, **kwargs)
        initialize(
            model=self.trainer.model,
            optimizer=self.trainer.optimizer,
            start=state.epoch,
            recipe=self.recipe,
        )

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event triggered at the beginning of a training step.
        Calls batch_start in the active CompressionSession.
        """
        super().on_step_begin(args, state, control, **kwargs)
        session_callbacks.batch_start()

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event triggered at the end of a training step.
        Calls optimizer pre-step, post-step, and batch_end callbacks.
        """
        super().on_step_end(args, state, control, **kwargs)
        session_callbacks.optim_pre_step()
        session_callbacks.optim_post_step()
        session_callbacks.batch_end()

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event triggered at the end of a substep during gradient accumulation.
        Calls batch_end in the active CompressionSession.
        """
        super().on_substep_end(args, state, control, **kwargs)
        session_callbacks.batch_end()

    # def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    #     super().on_prediction_step(args, state, control, **kwargs)
    #     session_callbacks.loss_calculated()

class SFTPlugin(BasePlugin):
    """
    Plugin for Sparse Finetuning integration with Axolotl.
    """

    def get_input_args(self) -> str:
        """
        Returns the input argument path for the plugin.
        """
        return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"

    def add_callbacks_post_trainer(self, cfg, trainer):
        """
        Adds Sparse Finetuning callback to the trainer.

        Args:
            cfg (object): Configuration object containing the recipe.
            trainer (object): Trainer instance to which the callback is added.

        Returns:
            list: A list containing the Sparse Finetuning callback.
        """
        LOG.info("Adding Sparse Finetuning callback to the trainer")
        callback = SFTCallbackHandler(
            trainer=trainer,
            recipe=cfg.recipe,
        )
        return [callback]


def compute_loss_wrapper(compute_loss_func):
    """
    Wraps the loss computation function to integrate with the active CompressionSession.

    Args:
        compute_loss_func (function): The original loss computation function.

    Returns:
        function: Wrapped function that reports the computed loss.
    """
    def wrapper(*args, **kwargs):
        loss = compute_loss_func(*args, **kwargs)
        session_callbacks.loss_calculated(loss=loss)
        # take the mean across multiple GPUs
        # this is done outside the compute_loss function in the parent
        loss = loss.mean()
        return loss
    return wrapper
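Taken together, the hooks above bracket each optimizer step of the HF Trainer. The sketch below is illustrative only, not real Trainer internals: one_step is a hypothetical stand-in that shows where each session callback fires relative to the update (it reuses the session_callbacks import from the module above).

# Illustrative sketch of one training step, assuming the hypothetical helper
# receives a trainer whose compute_loss has been wrapped by the plugin.
def one_step(trainer, batch):
    session_callbacks.batch_start()                    # on_step_begin
    loss = trainer.compute_loss(trainer.model, batch)  # wrapped: reports loss_calculated(loss=...)
    loss.backward()
    trainer.optimizer.step()
    # on_step_end fires after the optimizer update, which is where the
    # ConstantPruningModifier masks cancel any update to zeroed weights.
    session_callbacks.optim_pre_step()
    session_callbacks.optim_post_step()
    session_callbacks.batch_end()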
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
"""
Pydantic model for accepting `llmcompressor` specific arguments.
"""
from typing import Optional, Any
from pydantic import BaseModel


class LLMCompressorArgs(BaseModel):
    """
    Input arguments for Sparse Finetuning.
    """

    recipe: Optional[Any] = None
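Because recipe is typed Optional[Any], the plugin accepts whatever structure the YAML provides and defers validation to llmcompressor's Recipe.model_validate (see the callback handler above). A minimal usage sketch under that assumption:

# Minimal sketch: LLMCompressorArgs passes the raw recipe dict through untouched.
args = LLMCompressorArgs(recipe={"finetuning_stage": {"finetuning_modifiers": {}}})
assert args.recipe["finetuning_stage"] == {"finetuning_modifiers": {}}

no_recipe = LLMCompressorArgs()  # recipe defaults to None
assert no_recipe.recipe is None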

src/axolotl/utils/models.py

Lines changed: 18 additions & 0 deletions
@@ -103,6 +103,24 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
        hasattr(model_config, "quantization_config")
        and model_config.quantization_config
    )

    # TODO: Use a better fix to handle
    # config.json produced by compressed-tensors
    # sparse-only model -> will also have a quantization_config

    is_sparse_only_quant_config = bool(
        not quant_config_exists
        or (
            quant_config_exists
            and model_config.quantization_config["quant_method"] == "compressed-tensors"
            and not model_config.quantization_config.get("config_groups", False)
            and model_config.quantization_config.get("sparsity_config", False)
        )
    )

    if is_sparse_only_quant_config:
        quant_config_exists = False

    quant_config_method_is_gptq = (
        quant_config_exists
        and "quant_method" in model_config.quantization_config
