
Enabled BnB NF4 inference on Gaudi #1457

Draft: wants to merge 2 commits into base habana_main
17 changes: 17 additions & 0 deletions .jenkins/test_config_t_compile.yaml
@@ -267,3 +267,20 @@ stages:
        command: >-
          export PT_HPU_LAZY_MODE=0 && export VLLM_T_COMPILE_FULLGRAPH=True &&
          bash .jenkins/benchmark/run-benchmark.sh -fp8
  - name: tests_bnb_nf4
    steps:
      - name: test_load_4bit_bnb_model
        flavor: g2
        command: >-
          PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
          pytest -v tests/quantization/test_bitsandbytes_hpu.py::test_load_4bit_bnb_model
      - name: test_load_pre_quant_4bit_bnb_model
        flavor: g2
        command: >-
          PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
          pytest -v tests/quantization/test_bitsandbytes_hpu.py::test_load_pre_quant_4bit_bnb_model
      - name: test_load_tp_4bit_bnb_model
        flavor: g2
        command: >-
          PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
          pytest -v tests/quantization/test_bitsandbytes_hpu.py::test_load_tp_4bit_bnb_model
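
Reviewer note: the same CI step can be reproduced locally from Python. This is a hedged sketch, assuming an HPU host with this branch and the test dependencies installed; the environment variables must be set before vLLM is imported, exactly as the Jenkins commands above do.

# Hypothetical local reproduction of the CI step above (not part of this diff).
import os
import sys

import pytest

# Same settings as the Jenkins command: lazy mode off, HPU warmup skipped.
os.environ["PT_HPU_LAZY_MODE"] = "0"
os.environ["VLLM_SKIP_WARMUP"] = "true"

sys.exit(
    pytest.main([
        "-v",
        "tests/quantization/test_bitsandbytes_hpu.py::test_load_4bit_bnb_model",
    ]))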
157 changes: 157 additions & 0 deletions tests/quantization/test_bitsandbytes_hpu.py
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests whether bitsandbytes computation is enabled correctly.

Run `pytest tests/quantization/test_bitsandbytes_hpu.py`.
"""

import gc

import pytest
import torch
from transformers import BitsAndBytesConfig

from tests.quantization.utils import is_quant_method_supported

from ..utils import compare_two_settings, create_new_process_for_each_test

models_4bit_to_test = [
    (
        "mistralai/Mistral-7B-Instruct-v0.3",
        "quantize_inflight_model_with_both_HF_and_Mistral_format_weights",
    ),
    ("meta-llama/Llama-3.2-1B", "quantize_llama_model_inflight"),
]

models_pre_quant_4bit_to_test = [("unsloth/Llama-3.2-1B-bnb-4bit",
                                  "read_pre-quantized_4-bit_NF4_model")]


@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@create_new_process_for_each_test()
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ))
    validate_generated_texts(
        hf_runner,
        vllm_runner,
        example_prompts[:1],
        model_name,
        False,
        hf_model_kwargs,
    )


@pytest.mark.parametrize("model_name, description",
                         models_pre_quant_4bit_to_test)
@create_new_process_for_each_test()
def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                       model_name, description) -> None:
    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             model_name, True)


@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@create_new_process_for_each_test()
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                model_name, description) -> None:
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ))
    validate_generated_texts(
        hf_runner,
        vllm_runner,
        example_prompts[:1],
        model_name,
        False,
        hf_model_kwargs,
        vllm_tp_size=2,
    )


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@create_new_process_for_each_test()
def test_load_pp_4bit_bnb_model(model_name, description) -> None:
    common_args = [
        "--disable-log-stats",
        "--disable-log-requests",
        "--dtype",
        "bfloat16",
        "--enable-prefix-caching",
        "--quantization",
        "bitsandbytes",
        "--gpu-memory-utilization",
        "0.7",
    ]
    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        "2",
    ]
    compare_two_settings(model_name, common_args, pp_args)


def log_generated_texts(prompts, outputs, runner_name):
    logged_texts = []
    for i, (_, generated_text) in enumerate(outputs):
        log_entry = {
            "prompt": prompts[i],
            "runner_name": runner_name,
            "generated_text": generated_text,
        }
        logged_texts.append(log_entry)
    return logged_texts


def validate_generated_texts(
    hf_runner,
    vllm_runner,
    prompts,
    model_name,
    pre_quant=False,
    hf_model_kwargs=None,
    vllm_tp_size=1,
):
    # NOTE: run vLLM first, as it requires a clean process
    # when using distributed inference
    with vllm_runner(
            model_name,
            quantization=None if pre_quant else "bitsandbytes",
            tensor_parallel_size=vllm_tp_size,
            enforce_eager=False,
    ) as llm:
        vllm_outputs = llm.generate_greedy(prompts, 8)
        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

    # Clean up the GPU memory for the next test
    gc.collect()

    if hf_model_kwargs is None:
        hf_model_kwargs = {}

    # Run with HF runner
    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
        hf_outputs = llm.generate_greedy(prompts, 8)
        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

    # Clean up the GPU memory for the next test
    gc.collect()

    # Compare the generated strings
    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]

        assert hf_str == vllm_str, (f"Model: {model_name}\n"
                                    f"Mismatch between HF and vLLM outputs:\n"
                                    f"Prompt: {prompt}\n"
                                    f"HF Output: '{hf_str}'\n"
                                    f"vLLM Output: '{vllm_str}'")
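
Reviewer note: reduced to a minimal offline-inference sketch, this is what the in-flight tests above exercise. It assumes an HPU host with this branch installed; the prompt and the SamplingParams values are illustrative, while the model name, quantization flag, and environment variables mirror the tests and the CI config.

# Minimal sketch (not part of this diff): in-flight NF4 quantization on Gaudi.
import os

os.environ.setdefault("PT_HPU_LAZY_MODE", "0")
os.environ.setdefault("VLLM_SKIP_WARMUP", "true")

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B",  # float checkpoint, quantized at load time
    quantization="bitsandbytes",      # accepted on HPU with this change
    enforce_eager=False,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)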
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -10,6 +10,7 @@
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


@@ -385,12 +386,11 @@ def _apply_bnb_4bit_fake(


try:
    direct_register_custom_op(
        op_name="apply_bnb_4bit",
        op_func=_apply_bnb_4bit,
        mutates_args=["out"],
        fake_impl=_apply_bnb_4bit_fake,
    )
    direct_register_custom_op(op_name="apply_bnb_4bit",
                              op_func=_apply_bnb_4bit,
                              mutates_args=["out"],
                              fake_impl=_apply_bnb_4bit_fake,
                              dispatch_key=current_platform.dispatch_key)
    apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit

except AttributeError as error:
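
Reviewer note: the functional change in this file is registering the custom op under the active platform's dispatch key. Upstream direct_register_custom_op defaults to the CUDA dispatch key, so without this argument the op would not be selected for HPU tensors. Below is a minimal sketch of the same registration pattern; the toy op, its name, and its arguments are hypothetical, and it assumes vllm.utils.direct_register_custom_op keeps the signature used above.

# Hypothetical illustration of the pattern (not part of this diff).
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def _toy_scale(x: torch.Tensor, factor: float, out: torch.Tensor) -> None:
    # Real implementation: runs on whatever device `x` lives on (HPU here).
    out.copy_(x * factor)


def _toy_scale_fake(x: torch.Tensor, factor: float, out: torch.Tensor) -> None:
    # Fake (meta) implementation used during tracing/compilation; no compute.
    return


direct_register_custom_op(
    op_name="toy_scale",  # hypothetical op name
    op_func=_toy_scale,
    mutates_args=["out"],
    fake_impl=_toy_scale_fake,
    dispatch_key=current_platform.dispatch_key,  # e.g. "HPU" on Gaudi
)

# The registered op is then reachable through torch.ops, as apply_bnb_4bit is above.
toy_scale = torch.ops.vllm.toy_scale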
3 changes: 2 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -1108,7 +1108,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
...]

# bitsandbytes requires data in GPU
if weight_sub_tensor.is_cuda:
if (weight_sub_tensor.is_cuda
        or weight_sub_tensor.device.type == "hpu"):
    loaded_weight = weight_sub_tensor
else:
    loaded_weight = weight_sub_tensor.cuda()
3 changes: 2 additions & 1 deletion vllm/platforms/hpu.py
@@ -30,7 +30,8 @@ class HpuPlatform(Platform):
    simple_compile_backend: str = "hpu_backend" if not is_fake_hpu(
    ) else "inductor"
    supported_quantization: list[str] = [
        "compressed-tensors", "fp8", "inc", "awq_hpu", "gptq_hpu"
        "compressed-tensors", "fp8", "inc", "awq_hpu", "gptq_hpu",
        "bitsandbytes"
    ]

@classmethod
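
Reviewer note: with "bitsandbytes" in supported_quantization, pre-quantized NF4 checkpoints also load on HPU without an explicit quantization argument, as test_load_pre_quant_4bit_bnb_model does; the method is picked up from the checkpoint's Hugging Face quantization_config. A hedged sketch under the same assumptions as the earlier example:

# Minimal sketch (not part of this diff): pre-quantized NF4 checkpoint on Gaudi.
import os

os.environ.setdefault("PT_HPU_LAZY_MODE", "0")
os.environ.setdefault("VLLM_SKIP_WARMUP", "true")

from vllm import LLM, SamplingParams

llm = LLM(model="unsloth/Llama-3.2-1B-bnb-4bit", enforce_eager=False)
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=8))
print(out[0].outputs[0].text)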