diff --git a/examples/quantizing_moe/README.md b/examples/quantizing_moe/README.md
index 70243caf1..8a9b257f4 100644
--- a/examples/quantizing_moe/README.md
+++ b/examples/quantizing_moe/README.md
@@ -17,17 +17,17 @@ pip install -e .
 The provided example script demonstrates an end-to-end process for applying the quantization algorithm:
 
 ```bash
-python3 mixtral_moe_w8a8_fp8.py
+python3 mixtral_example.py
 ```
 
 ## Creating a Quantized MoE Model
 
-This example leverages `llm-compressor` and `compressed-tensors` to create an FP8-quantized `Mixtral-8x7B-Instruct-v0.1` model. The model is calibrated and trained using the `open_platypus` dataset.
+This example leverages `llm-compressor` and `compressed-tensors` to create an FP8-quantized `Mixtral-8x7B-Instruct-v0.1` model. The model is calibrated using the `ultrachat_200k` dataset.
 
 You can follow the detailed steps below or simply run the example script with:
 
 ```bash
-python mixtral_moe_w8a8_fp8.py
+python mixtral_example.py
 ```
 
 ### Step 1: Select a Model, Dataset, and Recipe
@@ -74,7 +74,7 @@ NOTE: Only per-tensor quantization is supported in vLLM as of now (`vllm==0.6.1`
 
 The repository supports multiple quantization techniques configured via a recipe. Supported strategies include `tensor`, `group`, and `channel` quantization.
 
-In the above example, FP8 per-tensor quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.
+In the above example, quantization is specified by the `W4A16` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.
 
 A custom scheme can also be specified using `config_groups`:
 
@@ -84,18 +84,18 @@ A custom scheme can also be specified using `config_groups`:
 from llmcompressor.modifiers.quantization.gptq import GPTQModifier
 
 config_groups = {
-    "group_0": {
-        "targets": ["Linear"],
-        "input_activations": None,
-        "output_activations": None,
-        "weights": {
-            "num_bits": 8,
-            "type": "int",
-            "symmetric": true,
-            "strategy": "group",
-            "group_size": 128,
-        }
-    }
+    "group_0": {
+        "targets": ["Linear"],
+        "input_activations": None,
+        "output_activations": None,
+        "weights": {
+            "num_bits": 8,
+            "type": "int",
+            "symmetric": True,
+            "strategy": "group",
+            "group_size": 128,
+        }
+    }
 }
 
 recipe = GPTQModifier(config_groups=config_groups)
diff --git a/examples/quantizing_moe/deepseek_moe_w4a16.py b/examples/quantizing_moe/deepseek_moe_w4a16.py
deleted file mode 100644
index 9880e9248..000000000
--- a/examples/quantizing_moe/deepseek_moe_w4a16.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import torch
-from datasets import load_dataset
-from packaging.version import Version
-from transformers import AutoModelForCausalLM, AutoTokenizer, __version__
-
-from llmcompressor import oneshot
-from llmcompressor.utils import dispatch_for_generation
-
-# NOTE: transformers 4.49.0 has an attribute error with DeepSeek.
-# Please consider either downgrading your transformers version to a
-# previous version or upgrading to a version where this bug is fixed
-
-# select a Mixture of Experts model for quantization
-MODEL_ID = "deepseek-ai/DeepSeek-V2.5"
-
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
-)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
-# Select calibration dataset.
-DATASET_ID = "HuggingFaceH4/ultrachat_200k"
-DATASET_SPLIT = "train_sft"
-NUM_CALIBRATION_SAMPLES = 512
-MAX_SEQUENCE_LENGTH = 2048
-
-
-# Load dataset and preprocess.
-ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
-ds = ds.shuffle(seed=42)
-
-
-def preprocess(example):
-    return {
-        "text": tokenizer.apply_chat_template(
-            example["messages"],
-            tokenize=False,
-        )
-    }
-
-
-ds = ds.map(preprocess)
-
-
-# Tokenize inputs.
-def tokenize(sample):
-    return tokenizer(
-        sample["text"],
-        padding=False,
-        max_length=MAX_SEQUENCE_LENGTH,
-        truncation=True,
-        add_special_tokens=False,
-    )
-
-
-ds = ds.map(tokenize, remove_columns=ds.column_names)
-
-# define a llmcompressor recipe for W416 quantization
-# since the MoE gate layers are sensitive to quantization, we add them to the ignore
-# list so they remain at full precision
-recipe = "deepseek_recipe_w4a16.yaml"
-
-oneshot(
-    model=model,
-    dataset=ds,
-    recipe=recipe,
-    max_seq_length=MAX_SEQUENCE_LENGTH,
-    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    save_compressed=True,
-    trust_remote_code_model=True,
-)
-
-# Confirm generations of the quantized model look sane.
-# Generation is broken for deepseek models when using the latest transformers package
-if Version(__version__) < Version("4.48"):
-    print("========== SAMPLE GENERATION ==============")
-    dispatch_for_generation(model)
-    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-    output = model.generate(input_ids, max_new_tokens=20)
-    print(tokenizer.decode(output[0]))
-    print("==========================================")
-else:
-    print(
-        "WARNING: cannot perform sample generation of "
-        "deepseek models with transformers >= 4.48"
-    )
-
-# Save to disk in compressed-tensors format.
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16"
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
-
-
-# Run the model on vLLM
-try:
-    from vllm import LLM, SamplingParams
-
-    vllm_installed = True
-except ImportError:
-    vllm_installed = False
-
-if vllm_installed:
-    print("vLLM installed, running using vLLM")
-    sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
-    llm = LLM(
-        model=SAVE_DIR,
-        tensor_parallel_size=2,
-        trust_remote_code=True,
-        max_model_len=1042,
-        dtype=torch.half,
-    )
-    prompts = [
-        "The capital of France is",
-        "The president of the US is",
-        "My name is",
-    ]
-
-    outputs = llm.generate(prompts, sampling_params)
-    print("================= vLLM GENERATION ======================")
-    for output in outputs:
-        assert output
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print("PROMPT", prompt)
-        print("GENERATED TEXT", generated_text)
diff --git a/examples/quantizing_moe/deepseek_recipe_w4a16.yaml b/examples/quantizing_moe/deepseek_recipe_w4a16.yaml
deleted file mode 100644
index 23f276e2f..000000000
--- a/examples/quantizing_moe/deepseek_recipe_w4a16.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-quant_stage:
-    quant_modifiers:
-        GPTQModifier:
-            ignore: [lm_head, "re:.*mlp.gate$"]
-            config_groups:
-                group_0:
-                    weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false}
-                    targets: [Linear]
diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py b/examples/quantizing_moe/deepseekv2_5_example.py
similarity index 76%
rename from examples/quantizing_moe/deepseek_moe_w8a8_int8.py
rename to examples/quantizing_moe/deepseekv2_5_example.py
index 3ec506c34..c2b3b0305 100644
--- a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py
+++ b/examples/quantizing_moe/deepseekv2_5_example.py
@@ -12,7 +12,7 @@
 # previous version or upgrading to a version where this bug is fixed
 
 # select a Mixture of Experts model for quantization
-MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
+MODEL_ID = "deepseek-ai/DeepSeek-V2.5"
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
@@ -20,10 +20,9 @@
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 # Select calibration dataset.
-# its recommended to use more calibration samples for MoE models so each expert is hit
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
 DATASET_SPLIT = "train_sft"
-NUM_CALIBRATION_SAMPLES = 2048
+NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
 
 
@@ -57,16 +56,12 @@ def tokenize(sample):
 
 ds = ds.map(tokenize, remove_columns=ds.column_names)
 
-# define a llmcompressor recipe for INT8 W8A8 quantization
+# Configure the quantization algorithm to run.
 # since the MoE gate layers are sensitive to quantization, we add them to the ignore
 # list so they remain at full precision
-recipe = [
-    GPTQModifier(
-        targets="Linear",
-        scheme="W8A8",
-        ignore=["lm_head", "re:.*mlp.gate$"],
-    ),
-]
+recipe = GPTQModifier(
+    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
+)
 
 oneshot(
     model=model,
     dataset=ds,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     save_compressed=True,
     trust_remote_code_model=True,
 )
@@ -82,12 +77,10 @@ def tokenize(sample):
 if Version(__version__) < Version("4.48"):
     print("========== SAMPLE GENERATION ==============")
     dispatch_for_generation(model)
-    SAMPLE_INPUT = ["I love quantization because"]
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
-    output = model.generate(**inputs, max_length=50)
-    text_output = tokenizer.batch_decode(output)
-    print(text_output)
+    sample = tokenizer("Hello my name is", return_tensors="pt")
+    sample = {key: value.to("cuda") for key, value in sample.items()}
+    output = model.generate(**sample, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
     print("==========================================")
 else:
     print(
@@ -96,6 +89,6 @@ def tokenize(sample):
     )
 
 # Save to disk in compressed-tensors format.
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
 tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/quantizing_moe/deepseekv3_example.py b/examples/quantizing_moe/deepseekv3_example.py
new file mode 100644
index 000000000..1b4c334ff
--- /dev/null
+++ b/examples/quantizing_moe/deepseekv3_example.py
@@ -0,0 +1,88 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor.modeling import prepare_for_quantization
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers import oneshot
+from llmcompressor.utils import dispatch_for_generation
+
+# Select model and load it.
+# For DeepSeekv3, we require a full precision model in order to properly calibrate
+# `DeepSeek-V3-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16
+model_id = "RedHatAI/DeepSeek-V3-BF16"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = prepare_for_quantization(model)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# since the MoE gate layers are sensitive to quantization, we add them to the ignore
+# list so they remain at full precision
+recipe = GPTQModifier(
+    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
+)
+
+# Apply algorithms.
+# due to the large size of DeepSeekV3, we specify sequential targets such that
+# only one MLP is loaded into GPU memory at a time
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py b/examples/quantizing_moe/mixtral_example.py
similarity index 51%
rename from examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
rename to examples/quantizing_moe/mixtral_example.py
index 0bc9c24df..5021c7947 100644
--- a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
+++ b/examples/quantizing_moe/mixtral_example.py
@@ -1,28 +1,23 @@
+import torch
 from datasets import load_dataset
-from packaging.version import Version
-from transformers import AutoModelForCausalLM, AutoTokenizer, __version__
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.utils import dispatch_for_generation
 
-# NOTE: transformers 4.49.0 has an attribute error with DeepSeek.
-# Please consider either downgrading your transformers version to a
-# previous version or upgrading to a version where this bug is fixed
-
 # select a Mixture of Experts model for quantization
-MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
+MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, torch_dtype="auto", trust_remote_code=True
+    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 # Select calibration dataset.
-# its recommended to use more calibration samples for MoE models so each expert is hit
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
 DATASET_SPLIT = "train_sft"
-NUM_CALIBRATION_SAMPLES = 2048
+NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
 
 
@@ -56,16 +51,17 @@ def tokenize(sample):
 
 ds = ds.map(tokenize, remove_columns=ds.column_names)
 
-# define a llmcompressor recipe for FP8 W8A8 quantization
+# Configure the quantization algorithm to run.
 # since the MoE gate layers are sensitive to quantization, we add them to the ignore
 # list so they remain at full precision
-recipe = [
-    QuantizationModifier(
-        targets="Linear",
-        scheme="FP8",
-        ignore=["lm_head", "re:.*mlp.gate$"],
-    ),
-]
+recipe = QuantizationModifier(
+    scheme="FP8",
+    targets="Linear",
+    ignore=[
+        "lm_head",
+        "re:.*block_sparse_moe.gate",  # does not quantize well
+    ],
+)
 
 oneshot(
     model=model,
     dataset=ds,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
 )
@@ -76,22 +72,13 @@ def tokenize(sample):
-# Confirm generations of the quantized model look sane.
-# Generation is broken for deepseek models when using the latest transformers package
-if Version(__version__) < Version("4.48"):
-    print("========== SAMPLE GENERATION ==============")
-    dispatch_for_generation(model)
-    SAMPLE_INPUT = ["I love quantization because"]
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
-    output = model.generate(**inputs, max_length=50)
-    text_output = tokenizer.batch_decode(output)
-    print(text_output)
-else:
-    print(
-        "WARNING: cannot perform sample generation of "
-        "deepseek models with transformers >= 4.48"
-    )
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================")
 
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8"
diff --git a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py
deleted file mode 100644
index a17bf873d..000000000
--- a/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import List
-
-from packaging.version import Version
-from transformers import AutoModelForCausalLM, AutoTokenizer, __version__
-
-from llmcompressor import oneshot
-from llmcompressor.modifiers.quantization import QuantizationModifier
-from llmcompressor.utils import dispatch_for_generation
-
-MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
-
-# Dataset config parameters
-DATASET_ID = "open_platypus"
-DATASET_SPLIT = "train"
-MAX_SEQ_LENGTH = 2048
-NUM_CALIBRATION_SAMPLES = 512
-
-# Recipe
-layers_to_ignore: List[str] = [
-    "lm_head",
-    "re:.*block_sparse_moe.gate",  # does not quantize well
-]
-recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=layers_to_ignore)
-
-
-oneshot(
-    model=model,
-    tokenizer=tokenizer,
-    dataset=DATASET_ID,
-    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
-    recipe=recipe,
-    max_seq_length=MAX_SEQ_LENGTH,
-    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-)
-
-# Confirm generations of the quantized model look sane.
-# Generation is broken for deepseek models when using the latest transformers package
-if Version(__version__) < Version("4.48"):
-    print("========== SAMPLE GENERATION ==============")
-    dispatch_for_generation(model)
-    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-    output = model.generate(input_ids, max_new_tokens=20)
-    print(tokenizer.decode(output[0]))
-    print("==========================================")
-else:
-    print(
-        "WARNING: cannot perform sample generation of "
-        "deepseek models with transformers >= 4.48"
-    )
-
-# Save to disk in compressed-tensors format.
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8"
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/quantizing_moe/qwen_moe_w4a16.py b/examples/quantizing_moe/qwen_example.py
similarity index 89%
rename from examples/quantizing_moe/qwen_moe_w4a16.py
rename to examples/quantizing_moe/qwen_example.py
index 40a78a9b7..bb00b530e 100644
--- a/examples/quantizing_moe/qwen_moe_w4a16.py
+++ b/examples/quantizing_moe/qwen_example.py
@@ -56,7 +56,7 @@ def tokenize(sample):
 # list so they remain at full precision
 recipe = GPTQModifier(
     targets="Linear",
-    scheme="W4A16",
+    scheme="FP8",
     ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
 )
 
@@ -73,12 +73,13 @@ def tokenize(sample):
 # Confirm generations of the quantized model look sane.
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=20)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
 print("==========================================")
 
 # Save to disk in compressed-tensors format.
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-quantized.w4a16"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
 tokenizer.save_pretrained(SAVE_DIR)
diff --git a/tests/examples/test_quantizing_moe.py b/tests/examples/test_quantizing_moe.py
index 49686d25c..1f5a53a56 100644
--- a/tests/examples/test_quantizing_moe.py
+++ b/tests/examples/test_quantizing_moe.py
@@ -44,14 +44,15 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
         "script_filename",
         [
             pytest.param(
-                "deepseek_moe_w4a16.py",
-                marks=[
-                    pytest.mark.multi_gpu,
-                    pytest.mark.skip(reason="exceptionally long run time"),
-                ],
+                "deepseekv2_5_example.py",
+                marks=pytest.mark.skip(reason="exceptionally long run time"),
             ),
-            pytest.param("deepseek_moe_w8a8_fp8.py"),
-            pytest.param("deepseek_moe_w8a8_int8.py", marks=pytest.mark.multi_gpu),
+            pytest.param(
+                "deepseekv3_example.py",
+                marks=pytest.mark.skip(reason="exceptionally long run time"),
+            ),
+            pytest.param("mixtral_example.py"),
+            pytest.param("qwen_example.py"),
         ],
     )
     def test_deepseek_example_script(
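For reference, the vLLM smoke test that the deleted `deepseek_moe_w4a16.py` used to run can still be reproduced by hand against any checkpoint these scripts save. A minimal sketch, assuming vLLM is installed and using the `Mixtral-8x7B-Instruct-v0.1-FP8` directory written by `mixtral_example.py` (the directory name, prompt, and sampling settings below are illustrative, not part of this diff):

```python
from vllm import LLM, SamplingParams

# Directory produced by mixtral_example.py's save_pretrained call (illustrative path).
SAVE_DIR = "Mixtral-8x7B-Instruct-v0.1-FP8"

# vLLM reads the compressed-tensors quantization config from the checkpoint,
# so no extra quantization flags are needed when loading.
llm = LLM(model=SAVE_DIR)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print("PROMPT", output.prompt)
    print("GENERATED TEXT", output.outputs[0].text)
```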