
Add: Sparse Finetuning Integration for Axolotl #2


Closed
wants to merge 9 commits

Conversation


@rahul-tuli rahul-tuli commented Mar 4, 2025

This PR introduces sparse fine-tuning support in Axolotl, using LLMCompressor as a plugin. The integration lets users efficiently fine-tune models with structured or unstructured sparsity.

Key Changes

  • New Plugin Implementation
    • Added src/axolotl/integrations/llmcompressor_sft/__init__.py for integrating sparse fine-tuning.
    • Introduced src/axolotl/integrations/llmcompressor_sft/args.py to configure sparse training.
  • Configuration Updates
    • Added examples/llama-3/sft.yaml to showcase sparse fine-tuning on LLaMA-3 models. (Note: a test with SparseLlama 8B 2:4 on gsm8k is currently running; the example and lm_eval results will be updated here once it completes.)
  • Enhancements & Modifications
    • Updated src/axolotl/utils/models.py for sparse model handling.
    • Extended plugin functionality in src/axolotl/integrations/llmcompressor_sft/args.py and __init__.py.
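
A minimal sketch of the new args module (the pydantic base class and the SFTArgs name are assumptions for illustration; only the recipe field and docstring come from the diff excerpts quoted in the review below):

# Sketch of src/axolotl/integrations/llmcompressor_sft/args.py (illustrative)
from typing import Any, Optional

from pydantic import BaseModel


class SFTArgs(BaseModel):
    """
    Input arguments for Sparse Finetuning.
    """

    # LLMCompressor recipe to apply during fine-tuning. Currently a full
    # recipe is expected (see the recipe-typing discussion in the review).
    recipe: Optional[Any] = None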

Installation

Expand for Installation Instructions
#!/bin/bash
# ./install-axolotl-llmcompressor.sh
# Exit on error, unset variable usage, and pipeline failures
set -euo pipefail

# Define repository URL
REPO_URL="https://github.com/axolotl-ai-cloud/axolotl.git"
REPO_NAME="axolotl"

# Function to check if a command exists
command_exists() {
    command -v "$1" &>/dev/null
}

# Ensure necessary dependencies are installed
check_dependencies() {
    local dependencies=("git" "python3" "pip3")
    for dep in "${dependencies[@]}"; do
        if ! command_exists "$dep"; then
            echo "Error: $dep is not installed. Please install it before running this script."
            exit 1
        fi
    done
}

# Clone the repository if it does not already exist
clone_repository() {
    if [ ! -d "$REPO_NAME" ]; then
        git clone "$REPO_URL"
    else
        echo "Repository '$REPO_NAME' already exists. Skipping clone."
    fi
}

# Create and activate Python virtual environment
setup_virtualenv() {
    cd "$REPO_NAME"
    python3 -m venv venv
    source ./venv/bin/activate
    python3 -m ensurepip
}

# Install necessary dependencies
install_dependencies() {
    pip install --upgrade pip setuptools wheel
    pip install --upgrade torch pydantic
    pip install "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"
    pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
    pip install --no-build-isolation -e ".[deepspeed,flash-attn]"
    pip install llmcompressor
}

# Main script execution
main() {
    check_dependencies
    clone_repository
    setup_virtualenv
    install_dependencies
    echo "Setup complete! Virtual environment activated."
}

main

Training Command

To train with sparse fine-tuning, use:

axolotl train examples/llama-3/sft.yaml

Test Run Output

Expand to View Output Logs
[2025-03-12 07:05:36] [INFO] PyTorch version 2.6.0 available.
[2025-03-12 07:05:40] [INFO] Plugin loaded successfully: axolotl.integrations.llmcompressor_sft.SFTPlugin
[2025-03-12 07:05:41] [WARNING] batch_size is not recommended. Use gradient_accumulation_steps instead.
[2025-03-12 07:05:49] [INFO] Maximum number of steps set at 1
[2025-03-12 07:05:51] [INFO] Training started...
[2025-03-12 07:05:57] [INFO] Training completed! Saving model to temp_debug/axolotl_outputs/model.

Sparsity Verification

To verify that 2:4 structured sparsity is maintained after fine-tuning, use the following script:

Script
import argparse
import re
import torch
from safetensors import safe_open

# Default path to your .safetensors file
default_file_path = "model.safetensors"

# Set up argument parser
parser = argparse.ArgumentParser(
    description="Print tensor names and shapes from a .safetensors file."
)
parser.add_argument(
    "--file_path",
    default=default_file_path,
    help="Path to the .safetensors file",
)
parser.add_argument("--regex", help="Optional regex to filter tensor names")
parser.add_argument("--check_sparsity", action="store_true", help="Check tensor sparsity percentage and 2:4 sparsity compliance")

args = parser.parse_args()

# Function to check sparsity
def check_sparsity(tensor):
    total_elements = tensor.numel()
    zero_elements = (tensor == 0).sum().item()
    sparsity_percentage = (zero_elements / total_elements) * 100
    
    # Check 2:4 sparsity: at least 2 zeros in every 4 values
    tensor_flat = tensor.view(-1, 4)  # Reshape to groups of 4
    is_2of4_sparse = ((tensor_flat == 0).sum(dim=1) >= 2).all().item()
    
    return sparsity_percentage, is_2of4_sparse

# Open the .safetensors file
with safe_open(args.file_path, framework="pt") as f:
    # Iterate over each tensor in the file
    for tensor_name in f.keys():
        if args.regex:
            if not re.match(args.regex, tensor_name):
                continue
        tensor = f.get_tensor(tensor_name)

        print(f"Tensor Name: {tensor_name}, \n\t Shape: {tensor.shape}, Type: {tensor.dtype}")
        
        if args.check_sparsity:
            sparsity, is_2of4 = check_sparsity(tensor)
            print(f"\t Sparsity: {sparsity:.2f}%, 2:4 Sparse: {is_2of4}")

Command:

python3 check_safetensors.py --file_path /path/to/model.safetensors --regex ".*layers.*[qkv]_proj.*" --check_sparsity
Expand for Sample Output
Tensor Name: model.layers.0.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.0.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.0.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.1.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.1.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.1.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.2.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.2.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.2.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.3.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.3.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.3.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.4.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.4.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.4.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.5.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.5.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.5.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.6.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.6.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.6.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.7.self_attn.k_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.7.self_attn.q_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True
Tensor Name: model.layers.7.self_attn.v_proj.weight, 
         Shape: torch.Size([512, 512]), Type: torch.bfloat16
         Sparsity: 50.00%, 2:4 Sparse: True

Running in vLLM

To load the fine-tuned sparse model in vLLM:

Script
import argparse
from vllm import LLM, SamplingParams


def run_inference(model_path, tensor_parallel_size, prompt="Hello my name is:"):
    """
    Loads a model and performs inference using LLM.
    """
    # Define sampling parameters
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    # Load the model
    model = LLM(
        model=model_path,
        enforce_eager=True,
        dtype="auto",
        tensor_parallel_size=tensor_parallel_size,
    )

    # Generate inference
    outputs = model.generate(prompt, sampling_params=sampling_params)
    return outputs[0].outputs[0].text


def main():
    """Main function to handle CLI and process the model."""
    # Argument parsing
    parser = argparse.ArgumentParser(
        description="Run inference on a single model and print results."
    )
    parser.add_argument(
        "model_path", type=str, help="Path to the model to perform inference."
    )
    parser.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=2,
        help="Tensor parallel size for the model. Default is 2.",
    )

    args = parser.parse_args()
    model_path = args.model_path
    tensor_parallel_size = args.tensor_parallel_size

    prompt = "Hello my name is:"

    # Load model and perform inference
    inference_result = run_inference(model_path, tensor_parallel_size, prompt)
    print("=" * 20)
    print("Model:", model_path)
    print(prompt, inference_result)


if __name__ == "__main__":
    main()
Command:

python run_model.py /path/to/model/ --tensor_parallel_size 1
Expand for Inference Output
(.venv) ➜  vllm git:(main) ✗ python run_model.py /home/rahul/axolotl/devtools/temp_debug/axolotl_outputs/model/ --tensor_parallel_size 1
INFO 03-12 07:27:49 [__init__.py:256] Automatically detected platform cuda.
INFO 03-12 07:27:54 [config.py:576] This model supports multiple tasks: {'generate', 'score', 'embed', 'classify', 'reward'}. Defaulting to 'generate'.
WARNING 03-12 07:27:55 [cuda.py:95] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 03-12 07:27:55 [llm_engine.py:235] Initializing a V0 LLM engine (v0.7.1.dev282+gc9e2d644e) with config: model='/home/rahul/axolotl/devtools/temp_debug/axolotl_outputs/model/', speculative_config=None, tokenizer='/home/rahul/axolotl/devtools/temp_debug/axolotl_outputs/model/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/home/rahul/axolotl/devtools/temp_debug/axolotl_outputs/model/, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}, use_cached_outputs=False, 
INFO 03-12 07:27:56 [cuda.py:277] Using Flash Attention backend.
INFO 03-12 07:27:57 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 03-12 07:27:57 [model_runner.py:1110] Starting to load model /home/rahul/axolotl/devtools/temp_debug/axolotl_outputs/model/...
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 56.25it/s]

INFO 03-12 07:27:57 [loader.py:422] Loading weights took 0.02 seconds
INFO 03-12 07:27:58 [model_runner.py:1117] Model loading took 0.0586 GB and 0.196761 seconds
INFO 03-12 07:27:58 [worker.py:267] Memory profiling takes 0.82 seconds
INFO 03-12 07:27:58 [worker.py:267] the current vLLM instance can use total_gpu_memory (79.22GiB) x gpu_memory_utilization (0.90) = 71.29GiB
INFO 03-12 07:27:58 [worker.py:267] model weights take 0.06GiB; non_torch_memory takes 0.16GiB; PyTorch activation peak memory takes 0.33GiB; the rest of the memory reserved for KV Cache is 70.75GiB.
INFO 03-12 07:27:59 [executor_base.py:111] # cuda blocks: 289799, # CPU blocks: 16384
INFO 03-12 07:27:59 [executor_base.py:116] Maximum concurrency for 4096 tokens per request: 1132.03x
INFO 03-12 07:28:00 [llm_engine.py:441] init engine (profile, create kv cache, warmup model) took 2.31 seconds
Processed prompts: 100%|████████████████| 1/1 [00:00<00:00, 14.30it/s, est. speed input: 85.85 toks/s, output: 228.90 toks/s]
====================
Model: /home/rahul/axolotl/devtools/temp_debug/axolotl_outputs/model/
Hello my name is:  It is my name. His name is his book. His name is it.
[rank0]:[W312 07:28:04.173295912 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
Flow diagram
[Configuration File (sft.yaml)] --- read config ---> [Axolotl Framework]
                                      |
                                      | initialize with config
                                      v
                              [Plugin Manager]
                                      |
                                      | load plugin
                                      v
                                [SFT Plugin]
                                      |
                                      | fetch recipe
                                      v
                                   [Model]
                                      |
                                      | load pre-trained sparse model
                                      v
                                  [Dataset]
                                      |
                                      | start training loop
                                      v
                       [on_train_begin] (SFTPlugin)
                                      |
                                      | initialize session with model, optimizer, recipe, start epoch
                                      v
                              [Training Loop]
                                      |
                                      | call training step hook
                                      v
                                [SFT Plugin]
                                      |
                                      | invoke appropriate callbacks based on lifecycle hooks
                                      v
                                   [Model]
                                      |
                                      | save model
                                      v
                        [Fine-tuned Sparse Model]

Order of callbacks:

        for gradient_batches in training loop:
            for batch in gradient_batches or [gradient_batches]:
                BATCH_START
                LOSS_CALCULATED

                if not last batch:
                    BATCH_END
                else:
                    OPTIM_PRE_STEP
                    OPTIM_POST_STEP
                    BATCH_END
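
The callback handler bridges Hugging Face Trainer hooks to these LLMCompressor lifecycle events. A rough sketch of that wiring (the llmcompressor.core import path and the exact hook-to-event mapping are assumptions for illustration, not the code in this PR):

from transformers import TrainerCallback

# Assumed interface: llmcompressor exposes lifecycle callbacks named after
# the events listed above (batch_start, optim_pre_step, ...); the import
# path below is an assumption.
from llmcompressor.core import callbacks as session_callbacks


class CompressorCallbackHandler(TrainerCallback):
    """Maps HF Trainer hooks onto LLMCompressor lifecycle events (sketch)."""

    def on_step_begin(self, args, state, control, **kwargs):
        session_callbacks.batch_start()  # a new batch begins

    def on_substep_end(self, args, state, control, **kwargs):
        # End of a micro-batch during gradient accumulation (not the last one).
        session_callbacks.batch_end()

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        session_callbacks.optim_pre_step()

    def on_optimizer_step(self, args, state, control, **kwargs):
        session_callbacks.optim_post_step()

    def on_step_end(self, args, state, control, **kwargs):
        session_callbacks.batch_end()  # last micro-batch of the step

# LOSS_CALCULATED has no dedicated TrainerCallback hook; it is typically
# emitted from the loss computation itself and is omitted in this sketch.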

@rahul-tuli rahul-tuli force-pushed the llmcompressor-sft branch 2 times, most recently from 14757d1 to 4438e5d on March 12, 2025 07:10
@rahul-tuli rahul-tuli changed the title Init Add: Sparse Finetuning Integration with llmcompressor Mar 12, 2025
@rahul-tuli rahul-tuli changed the title Add: Sparse Finetuning Integration with llmcompressor Add: Sparse Finetuning Integration for Axolotl Mar 12, 2025
@rahul-tuli rahul-tuli self-assigned this Mar 12, 2025
@rahul-tuli rahul-tuli marked this pull request as ready for review March 12, 2025 08:16

@brian-dellabetta brian-dellabetta left a comment

awesome! really nice and compact. Should we get feedback from MLR team? Maybe a demo in a call?

Input arguments for Sparse Finetuning.
"""

recipe: Optional[Any] = None
Member

See the note above in the example for namespacing this and setting up proper types. Along those lines, we should set expectations for the type of the recipe, specifically that we support a dictionary arg passed in or a string representing the path to a file or model stub

Member Author

Right now we only support specifying full recipes, can make that update if needed in a separate diff


@markurtz markurtz left a comment


A few minor things, but otherwise looks good to me

LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")


class SFTCallbackHandler(TrainerCallback):
Member

Can we rename this to something like CompressorCallbackHandler? I don't want to hardcode this on sparsity since the implementation is more generic and ideally will enable other pathways rather than just sparsity in the near future

Member Author

Done!

rahul-tuli and others added 2 commits April 4, 2025 10:36
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
@rahul-tuli
Member Author

@markurtz I've addressed all your changes in commits 46296bc3 and 5cf596c0; I'm closing this PR in favor of upstream PR axolotl-ai-cloud#2479.

@rahul-tuli rahul-tuli closed this Apr 4, 2025