From 5ef4f2af08ae4c1828219c4c8a4ea7ff78926146 Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Mon, 23 Jun 2025 07:27:02 +0300
Subject: [PATCH 1/2] Gaudi support to bnb NF4 inference tests

---
 tests/quantization/test_bitsandbytes_hpu.py | 157 ++++++++++++++++++
 .../layers/quantization/bitsandbytes.py     |  12 +-
 vllm/model_executor/model_loader/loader.py  |   3 +-
 vllm/platforms/hpu.py                       |   3 +-
 4 files changed, 167 insertions(+), 8 deletions(-)
 create mode 100644 tests/quantization/test_bitsandbytes_hpu.py

diff --git a/tests/quantization/test_bitsandbytes_hpu.py b/tests/quantization/test_bitsandbytes_hpu.py
new file mode 100644
index 00000000000..dca84f2106b
--- /dev/null
+++ b/tests/quantization/test_bitsandbytes_hpu.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests whether bitsandbytes computation is enabled correctly.
+
+Run `pytest tests/quantization/test_bitsandbytes_hpu.py`.
+"""
+
+import gc
+
+import pytest
+import torch
+from transformers import BitsAndBytesConfig
+
+from tests.quantization.utils import is_quant_method_supported
+
+from ..utils import compare_two_settings, create_new_process_for_each_test
+
+models_4bit_to_test = [
+    (
+        "mistralai/Mistral-7B-Instruct-v0.3",
+        "quantize_inflight_model_with_both_HF_and_Mistral_format_weights",
+    ),
+    ("meta-llama/Llama-3.2-1B", "quantize_llama_model_inflight"),
+]
+
+models_pre_quant_4bit_to_test = [("hugging-quants/Meta-Llama-3.1-8B-BNB-NF4",
+                                  "read_pre-quantized_4-bit_NF4_opt_model")]
+
+
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@create_new_process_for_each_test()
+def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    ))
+    validate_generated_texts(
+        hf_runner,
+        vllm_runner,
+        example_prompts[:1],
+        model_name,
+        False,
+        hf_model_kwargs,
+    )
+
+
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_quant_4bit_to_test)
+@create_new_process_for_each_test()
+def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                        model_name, description) -> None:
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name, True)
+
+
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@create_new_process_for_each_test()
+def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                model_name, description) -> None:
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    ))
+    validate_generated_texts(
+        hf_runner,
+        vllm_runner,
+        example_prompts[:1],
+        model_name,
+        False,
+        hf_model_kwargs,
+        vllm_tp_size=2,
+    )
+
+
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@create_new_process_for_each_test()
+def test_load_pp_4bit_bnb_model(model_name, description) -> None:
+    common_args = [
+        "--disable-log-stats",
+        "--disable-log-requests",
+        "--dtype",
+        "bfloat16",
+        "--enable-prefix-caching",
+        "--quantization",
+        "bitsandbytes",
+        "--gpu-memory-utilization",
+        "0.7",
+    ]
+    pp_args = [
+        *common_args,
+        "--pipeline-parallel-size",
+        "2",
+    ]
+    compare_two_settings(model_name, common_args, pp_args)
+
+
+def log_generated_texts(prompts, outputs,
+                        runner_name):
+    logged_texts = []
+    for i, (_, generated_text) in enumerate(outputs):
+        log_entry = {
+            "prompt": prompts[i],
+            "runner_name": runner_name,
+            "generated_text": generated_text,
+        }
+        logged_texts.append(log_entry)
+    return logged_texts
+
+
+def validate_generated_texts(
+    hf_runner,
+    vllm_runner,
+    prompts,
+    model_name,
+    pre_quant=False,
+    hf_model_kwargs=None,
+    vllm_tp_size=1,
+):
+    # NOTE: run vLLM first, as it requires a clean process
+    # when using distributed inference
+    with vllm_runner(
+            model_name,
+            quantization=None if pre_quant else "bitsandbytes",
+            tensor_parallel_size=vllm_tp_size,
+            enforce_eager=False,
+    ) as llm:
+        vllm_outputs = llm.generate_greedy(prompts, 8)
+        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
+
+    # Clean up the GPU memory for the next test
+    gc.collect()
+
+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    gc.collect()
+
+    # Compare the generated strings
+    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
+        hf_str = hf_log["generated_text"]
+        vllm_str = vllm_log["generated_text"]
+        prompt = hf_log["prompt"]
+
+        assert hf_str == vllm_str, (f"Model: {model_name}"
+                                    f"Mismatch between HF and vLLM outputs:\n"
+                                    f"Prompt: {prompt}\n"
+                                    f"HF Output: '{hf_str}'\n"
+                                    f"vLLM Output: '{vllm_str}'")
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index a472779d930..f4762280a1d 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -10,6 +10,7 @@
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -385,12 +386,11 @@ def _apply_bnb_4bit_fake(
 
 
 try:
-    direct_register_custom_op(
-        op_name="apply_bnb_4bit",
-        op_func=_apply_bnb_4bit,
-        mutates_args=["out"],
-        fake_impl=_apply_bnb_4bit_fake,
-    )
+    direct_register_custom_op(op_name="apply_bnb_4bit",
+                              op_func=_apply_bnb_4bit,
+                              mutates_args=["out"],
+                              fake_impl=_apply_bnb_4bit_fake,
+                              dispatch_key=current_platform.dispatch_key)
     apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
 
 except AttributeError as error:
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index ac984e8f2da..a9ddb1acff3 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1108,7 +1108,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                                       ...]
 
                 # bitsandbytes requires data in GPU
-                if weight_sub_tensor.is_cuda:
+                if (weight_sub_tensor.is_cuda
+                        or weight_sub_tensor.device.type == "hpu"):
                     loaded_weight = weight_sub_tensor
                 else:
                     loaded_weight = weight_sub_tensor.cuda()
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 4777a02b1ff..e2a9cf3e27f 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -30,7 +30,8 @@ class HpuPlatform(Platform):
     simple_compile_backend: str = "hpu_backend" if not is_fake_hpu(
     ) else "inductor"
     supported_quantization: list[str] = [
-        "compressed-tensors", "fp8", "inc", "awq_hpu", "gptq_hpu"
+        "compressed-tensors", "fp8", "inc", "awq_hpu", "gptq_hpu",
+        "bitsandbytes"
     ]
 
     @classmethod

From 2f4e982a5effc6db3bc6183d63e6b34f5b3051e0 Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Mon, 23 Jun 2025 12:09:17 +0300
Subject: [PATCH 2/2] Added bnb NF4 tests to CI

---
 .jenkins/test_config_t_compile.yaml         | 17 +++++++++++++++++
 tests/quantization/test_bitsandbytes_hpu.py |  4 ++--
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/.jenkins/test_config_t_compile.yaml b/.jenkins/test_config_t_compile.yaml
index 615d1f15145..9a32a0e4dda 100644
--- a/.jenkins/test_config_t_compile.yaml
+++ b/.jenkins/test_config_t_compile.yaml
@@ -267,3 +267,20 @@ stages:
         command: >-
           export PT_HPU_LAZY_MODE=0 && export VLLM_T_COMPILE_FULLGRAPH=True &&
           bash .jenkins/benchmark/run-benchmark.sh -fp8
+  - name: tests_bnb_nf4
+    steps:
+      - name: test_load_4bit_bnb_model
+        flavor: g2
+        command: >-
+          PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
+          pytest -v tests/quantization/test_bitsandbytes_hpu.py::test_load_4bit_bnb_model
+      - name: test_load_pre_quant_4bit_bnb_model
+        flavor: g2
+        command: >-
+          PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
+          pytest -v tests/quantization/test_bitsandbytes_hpu.py::test_load_pre_quant_4bit_bnb_model
+      - name: test_load_tp_4bit_bnb_model
+        flavor: g2
+        command: >-
+          PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
+          pytest -v tests/quantization/test_bitsandbytes_hpu.py::test_load_tp_4bit_bnb_model
diff --git a/tests/quantization/test_bitsandbytes_hpu.py b/tests/quantization/test_bitsandbytes_hpu.py
index dca84f2106b..e79573b2b2a 100644
--- a/tests/quantization/test_bitsandbytes_hpu.py
+++ b/tests/quantization/test_bitsandbytes_hpu.py
@@ -22,8 +22,8 @@
     ("meta-llama/Llama-3.2-1B", "quantize_llama_model_inflight"),
 ]
 
-models_pre_quant_4bit_to_test = [("hugging-quants/Meta-Llama-3.1-8B-BNB-NF4",
-                                  "read_pre-quantized_4-bit_NF4_opt_model")]
+models_pre_quant_4bit_to_test = [("unsloth/Llama-3.2-1B-bnb-4bit",
+                                  "read_pre-quantized_4-bit_NF4_model")]
 
 
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
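
Usage note (not part of the patch series): below is a minimal sketch of how the HPU bitsandbytes NF4 path added above can be exercised with vLLM's offline API, assuming a Gaudi host running this fork with PT_HPU_LAZY_MODE=0 and VLLM_SKIP_WARMUP=true set as in the CI config. The model names come from the test lists; the prompt and the 8-token greedy decode are illustrative stand-ins for example_prompts[:1] and the tests' settings.

```python
# Sketch only: mirrors what vllm_runner does inside validate_generated_texts.
from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]  # stand-in for example_prompts[:1]
greedy = SamplingParams(temperature=0.0, max_tokens=8)  # greedy, 8 tokens as in the tests

# In-flight NF4 quantization of an unquantized checkpoint
# (quantization="bitsandbytes", as in test_load_4bit_bnb_model).
llm = LLM(model="meta-llama/Llama-3.2-1B",
          quantization="bitsandbytes",
          dtype="bfloat16",
          enforce_eager=False)
print(llm.generate(prompts, greedy)[0].outputs[0].text)

# Pre-quantized NF4 checkpoint: quantization is picked up from the model
# config, so no explicit quantization argument is needed
# (as in test_load_pre_quant_4bit_bnb_model).
llm = LLM(model="unsloth/Llama-3.2-1B-bnb-4bit",
          dtype="bfloat16",
          enforce_eager=False)
print(llm.generate(prompts, greedy)[0].outputs[0].text)
```

In practice the two loads should run in separate processes, which is what the tests do via create_new_process_for_each_test().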