diff --git a/examples/quantization_2of4_sparse_w4a16/README.md b/examples/quantization_2of4_sparse_w4a16/README.md
index 51e04dd98..c1e34c280 100644
--- a/examples/quantization_2of4_sparse_w4a16/README.md
+++ b/examples/quantization_2of4_sparse_w4a16/README.md
@@ -45,37 +45,79 @@ It contains instructions to prune the model to 2:4 sparsity, run one epoch of re
 and quantize to 4 bits in one shot using GPTQ.
 
 ```python
+from pathlib import Path
+
 import torch
-from transformers import AutoModelForCausalLM
+from loguru import logger
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot, train
 
+# load the model in bfloat16 to save on memory and compute
 model_stub = "neuralmagic/Llama-2-7b-ultrachat200k"
 model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained(model_stub)
 
+# uses LLM Compressor's built-in preprocessing for UltraChat
 dataset = "ultrachat-200k"
-splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
 
+# select the recipe for 2:4 sparsity and 4-bit weight quantization (W4A16)
 recipe = "2of4_w4a16_recipe.yaml"
+
+# save location of the compressed model
+output_dir = "output_llama7b_2of4_w4a16_channel"
+output_path = Path(output_dir)
+
+# set dataset config parameters
+splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
+max_seq_length = 512
+num_calibration_samples = 512
+
+# set training parameters for finetuning
+num_train_epochs = 0.5
+logging_steps = 500
+save_steps = 5000
+gradient_checkpointing = True  # saves memory during training
+learning_rate = 0.0001
+bf16 = False  # using full precision for training
+lr_scheduler_type = "cosine"
+warmup_ratio = 0.1
+preprocessing_num_workers = 8
 ```
 
-## Step 2: Run sparsification using `apply`
-The `apply` function applies the given recipe to our model and dataset.
-The hardcoded kwargs may be altered based on each model's needs.
-After running, the sparsified model will be saved to `output_llama7b_2of4_w4a16_channel`.
+## Step 2: Run sparsification, fine-tuning, and quantization
+The compression process runs in three stages: sparsification, fine-tuning, and quantization.
+Each stage writes its intermediate model to a subdirectory of `output_llama7b_2of4_w4a16_channel` named after the stage.
 
 ```python
-from llmcompressor.transformers import apply
+from llmcompressor import oneshot, train
+from pathlib import Path
 
 output_dir = "output_llama7b_2of4_w4a16_channel"
+output_path = Path(output_dir)
 
-apply(
+# 1. Oneshot sparsification: apply 2:4 pruning
+oneshot(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
-    bf16=False,  # use full precision for training
+    splits=splits,
+    num_calibration_samples=512,
+    preprocessing_num_workers=8,
     output_dir=output_dir,
+    stage="sparsity_stage",
+)
+
+# 2. Sparse fine-tuning: recover accuracy of the pruned model
+train(
+    model=output_path / "sparsity_stage",
+    dataset=dataset,
+    recipe=recipe,
     splits=splits,
-    max_seq_length=512,
     num_calibration_samples=512,
+    preprocessing_num_workers=8,
+    bf16=False,
+    max_seq_length=512,
     num_train_epochs=0.5,
     logging_steps=500,
     save_steps=5000,
@@ -83,11 +125,34 @@ apply(
     learning_rate=0.0001,
     lr_scheduler_type="cosine",
     warmup_ratio=0.1,
+    output_dir=output_dir,
+    stage="finetuning_stage",
 )
+# 3. Oneshot quantization: compress the model weights to 4 bits with GPTQ
+quantized_model = oneshot(
+    model=output_path / "finetuning_stage",
+    dataset=dataset,
+    recipe=recipe,
+    splits=splits,
+    num_calibration_samples=512,
+    preprocessing_num_workers=8,
+    output_dir=output_dir,
+    stage="quantization_stage",
+)
+quantized_model.save_pretrained(
+    f"{output_dir}/quantization_stage", skip_sparsity_compression_stats=False
+)
+tokenizer.save_pretrained(f"{output_dir}/quantization_stage")
+
 ```
 
 ### Custom Quantization
-The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group` and `channel`.
-The above recipe (`2of4_w4a16_recipe.yaml`) uses channel-wise quantization specified by `strategy: "channel"` in its config group.
-To use quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. A group size quantization example is shown in `2of4_w4a16_group-128_recipe.yaml`.
+The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group`, and `channel`.
+
+The recipe (`2of4_w4a16_recipe.yaml`) uses channel-wise quantization (`strategy: "channel"`).
+To change the quantization strategy, edit the recipe file accordingly:
+
+- Use `tensor` for per-tensor quantization.
+- Use `group` for group-wise quantization and set a `group_size` (e.g., `group_size: 128`).
+- See `2of4_w4a16_group-128_recipe.yaml` for a complete group-size recipe; a sketch of the relevant section is shown below.
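+
+For example, the quantization section of a group-size recipe might look roughly like the following sketch. The stage, modifier, and field names mirror the channel-wise recipe used in this example and are illustrative only; `2of4_w4a16_group-128_recipe.yaml` is the authoritative version.
+
+```yaml
+# Illustrative sketch only -- consult 2of4_w4a16_group-128_recipe.yaml for the
+# exact, supported recipe.
+quantization_stage:
+  quantization_modifiers:
+    GPTQModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          targets: ["Linear"]
+          weights:
+            num_bits: 4
+            type: "int"
+            symmetric: true
+            strategy: "group"  # "channel" in 2of4_w4a16_recipe.yaml
+            group_size: 128    # required when strategy is "group"
+```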