
Commit 6781af5

[Quantization] Pool model support bitsandbytes (vllm-project#18087)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 1b15df2 commit 6781af5
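
In user terms, this change lets vLLM apply in-flight bitsandbytes quantization to pooling (embedding) models. A minimal sketch of the resulting usage, mirroring the arguments of the new test_4bit_bnb_embedding_model test below; the script itself is illustrative and not part of the commit:

from vllm import LLM

# Load an embedding model with in-flight 4-bit bitsandbytes quantization.
# Model name and arguments mirror the new test; adjust for your setup.
llm = LLM(model="intfloat/e5-mistral-7b-instruct",
          task="embed",
          dtype="half",
          quantization="bitsandbytes")

# Each returned item carries the pooled embedding for its prompt.
outputs = llm.encode(["Write a short story about a robot."])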

File tree

2 files changed: +79 −3 lines changed

  tests/quantization/test_bitsandbytes.py
  vllm/model_executor/model_loader/bitsandbytes_loader.py


tests/quantization/test_bitsandbytes.py

Lines changed: 64 additions & 2 deletions
@@ -8,9 +8,11 @@
 
 import pytest
 import torch
+from transformers import BitsAndBytesConfig
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ..models.utils import check_embeddings_close
 from ..utils import compare_two_settings, create_new_process_for_each_test
 
 models_4bit_to_test = [
@@ -19,6 +21,10 @@
      "quantize inflight model with both HF and Mistral format weights")
 ]
 
+models_4bit_to_embedding_test = [
+    ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
+]
+
 models_pre_qaunt_4bit_to_test = [
     ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
      'read pre-quantized 4-bit FP4 model'),
@@ -31,6 +37,12 @@
     ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
 ]
 
+models_pre_quant_8bit_to_test = [
+    ('meta-llama/Llama-Guard-3-8B-INT8',
+     'read pre-quantized llama 8-bit model'),
+    ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
+]
+
 
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
@@ -39,7 +51,8 @@
 def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
 
-    hf_model_kwargs = {"load_in_4bit": True}
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True))
     validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                              model_name, False, hf_model_kwargs)
 
@@ -77,7 +90,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
 def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                 model_name, description) -> None:
 
-    hf_model_kwargs = {"load_in_4bit": True}
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True))
     validate_generated_texts(hf_runner,
                              vllm_runner,
                              example_prompts[:1],
@@ -113,6 +127,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
     compare_two_settings(model_name, common_args, pp_args)
 
 
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description",
+                         models_4bit_to_embedding_test)
+@pytest.mark.parametrize("dtype", ["half"])
+@create_new_process_for_each_test()
+def test_4bit_bnb_embedding_model(
+    model_name,
+    description,
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    dtype: str,
+) -> None:
+
+    # The example_prompts end with "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers strips the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids differ between hf_model and vllm_model,
+    # so strip the input texts here as well to avoid spurious failures.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
+    # In-flight 4-bit quantization
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True))
+    with hf_runner(
+            model_name,
+            dtype=dtype,
+            model_kwargs=hf_model_kwargs,
+            is_sentence_transformer=True,
+    ) as hf_model:
+        hf_outputs = hf_model.encode(example_prompts)
+
+    with vllm_runner(model_name,
+                     task="embed",
+                     dtype=dtype,
+                     quantization="bitsandbytes") as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=5e-2,
+    )
+
+
 def log_generated_texts(prompts, outputs, runner_name):
     logged_texts = []
     for i, (_, generated_text) in enumerate(outputs):
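
The comparison above relies on check_embeddings_close from tests/models/utils.py. As a rough sketch of the kind of check it performs, assuming a pairwise cosine-similarity comparison (the helper's actual implementation may differ in detail):

import torch

def check_embeddings_close_sketch(embeddings_0_lst, embeddings_1_lst,
                                  name_0, name_1, tol=1e-3):
    # Simplified stand-in for tests/models/utils.py:check_embeddings_close,
    # shown for illustration only: compare paired embeddings by cosine
    # similarity and fail when any pair diverges beyond the tolerance.
    assert len(embeddings_0_lst) == len(embeddings_1_lst)
    for idx, (e0, e1) in enumerate(zip(embeddings_0_lst, embeddings_1_lst)):
        sim = torch.nn.functional.cosine_similarity(
            torch.tensor(e0), torch.tensor(e1), dim=0)
        assert sim >= 1 - tol, (
            f"prompt {idx}: {name_0} vs {name_1} cosine similarity "
            f"{sim:.4f} below 1 - {tol}")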

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 15 additions & 1 deletion
@@ -35,6 +35,7 @@
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     pt_weights_iterator, safetensors_weights_iterator)
+from vllm.model_executor.models import is_pooling_model
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
@@ -133,6 +134,16 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+        def _maybe_pool_model(module_name: str):
+            # For pooling models, prepend the `model.` prefix to the
+            # weight name when the target modules expect it.
+            if (self.is_pool_model
+                    and self.target_modules[0].startswith("model.")
+                    and not module_name.startswith("model.")):
+                return "model." + module_name
+
+            return module_name
+
         if use_safetensors:
             iterator = safetensors_weights_iterator(
                 hf_weights_files,
@@ -148,6 +159,9 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
             # mapping weight names from transformers to vllm while preserving
             # original names.
             mapped_name = self.weight_mapper(org_name)
+            mapped_name = _maybe_pool_model(mapped_name)
+
+
             yield org_name, mapped_name, param
 
     def _get_quantized_weights_iterator(
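
The renaming performed by _maybe_pool_model is easiest to see standalone. A toy rendition with made-up module names (the helper itself is the inner function added above; this copy exists only for illustration):

def maybe_pool_model(module_name, is_pool_model, target_modules):
    # Standalone copy of the helper's logic: for pooling models whose
    # target modules carry the `model.` prefix, prepend it to names
    # that lack it so checkpoint weights line up with the module tree.
    if (is_pool_model and target_modules[0].startswith("model.")
            and not module_name.startswith("model.")):
        return "model." + module_name
    return module_name

# Hypothetical names for illustration:
targets = ["model.layers.0.self_attn.qkv_proj"]
print(maybe_pool_model("layers.0.mlp.gate_proj", True, targets))
# -> "model.layers.0.mlp.gate_proj"
print(maybe_pool_model("model.embed_tokens", True, targets))
# -> "model.embed_tokens" (already prefixed, unchanged)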
@@ -405,7 +419,7 @@ def _load_weights(self, model_config: ModelConfig,
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet. No 'packed_modules_mapping' found.")
-
+        self.is_pool_model = is_pooling_model(model)
         self.modules_mapping = ParamMapping(
             copy.deepcopy(model.packed_modules_mapping))
 