Commit 8020e98

[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 762be26 commit 8020e98
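
This change lets vLLM take an unquantized MoE checkpoint and quantize it to 4-bit bitsandbytes on the fly at load time. Below is a minimal usage sketch of the offline API that the new test exercises; the model name comes from the test added in this commit, while the prompt and sampling settings are illustrative only.

from vllm import LLM, SamplingParams

# Inflight (load-time) 4-bit BNB quantization of an MoE checkpoint; the model
# name mirrors the entry added to models_4bit_to_moe_test below.
llm = LLM(model="allenai/OLMoE-1B-7B-0125-Instruct",
          quantization="bitsandbytes")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)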

File tree: 8 files changed, +561 -88 lines
tests/models/quantization/test_bitsandbytes.py

Lines changed: 39 additions & 6 deletions
@@ -14,7 +14,7 @@
 from tests.quantization.utils import is_quant_method_supported
 
 from ...utils import compare_two_settings, multi_gpu_test
-from ..utils import check_embeddings_close
+from ..utils import check_embeddings_close, check_logprobs_close
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -26,6 +26,10 @@
     ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
 ]
 
+models_4bit_to_moe_test = [
+    ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
+]
+
 models_pre_qaunt_4bit_to_test = [
     ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
      'read pre-quantized 4-bit FP4 model'),
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
     compare_two_settings(model_name, common_args, pp_args)
 
 
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
+def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
+                            model_name, description) -> None:
+
+    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+    ))
+    with vllm_runner(model_name,
+                     quantization='bitsandbytes',
+                     enforce_eager=False) as llm:
+        vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
+                                                    max_tokens=32,
+                                                    num_logprobs=5)
+
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        transformers_outputs = llm.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens=32, num_logprobs=5)
+    check_logprobs_close(
+        outputs_0_lst=transformers_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="transformers",
+        name_1="vllm",
+    )
+
+
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
@@ -182,15 +215,17 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              pre_quant=False,
                              hf_model_kwargs=None,
-                             vllm_tp_size=1):
+                             vllm_tp_size=1,
+                             max_tokens=8):
 
     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
     with vllm_runner(model_name,
                      quantization=None if pre_quant else 'bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
-        vllm_outputs = llm.generate_greedy(prompts, 8)
+
+        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
 
     # Clean up the GPU memory for the next test
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,
 
     # Run with HF runner
     with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_outputs = llm.generate_greedy(prompts, max_tokens)
         hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
 
     # Clean up the GPU memory for the next test
     gc.collect()
     torch.cuda.empty_cache()
-
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]
-
        assert hf_str == vllm_str, (f"Model: {model_name}"
                                    f"Mismatch between HF and vLLM outputs:\n"
                                    f"Prompt: {prompt}\n"

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 32 additions & 4 deletions
@@ -883,14 +883,21 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                            expert_data=expert_data,
                            tp_rank=tp_rank)
 
-    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
-                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w13(self,
+                  expert_data: torch.Tensor,
+                  shard_dim: int,
+                  shard_id: str,
+                  loaded_weight: torch.Tensor,
+                  tp_rank: int,
+                  load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -998,6 +1005,27 @@ def weight_loader(self,
             param.data.copy_(loaded_weight)
             return True if return_success else None
 
+        # Case for BitsAndBytes
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        if use_bitsandbytes_4bit:
+            shard_dim = 0
+
+            expert_data = param.data[expert_id]
+            if shard_id == "w2":
+                expert_data.copy_(loaded_weight)
+            elif shard_id in ("w1", "w3"):
+                # BNB inflight quantization has already sharded the weights
+                full_load = True
+                self._load_w13(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=self.tp_rank,
+                    load_full=full_load,
+                )
+            return True if return_success else None
+
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
         # should be whatever dimension intermediate_size_per_partition is
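
The behavioral change is concentrated in _load_w13: for an ordinary checkpoint each tensor-parallel rank narrows the full w1/w3 weight down to its own slice, whereas weights produced by BNB inflight quantization arrive already sharded, so load_full=True copies them as-is. The following standalone torch sketch illustrates that decision; the shapes and variable names are made up for illustration and are not taken from vLLM.

import torch

tp_size, tp_rank = 2, 0
hidden_size, intermediate_size = 4, 4
inter_per_rank = intermediate_size // tp_size  # 2

# w13 for one expert on this rank: w1 and w3 stacked along dim 0.
expert_data = torch.zeros(2 * inter_per_rank, hidden_size)  # shape (4, 4)
shard_dim = 0
shard_size = expert_data.shape[shard_dim] // 2  # 2, == inter_per_rank

# Ordinary path: the checkpoint holds the full, unsharded w1, so this rank
# narrows out its own slice before copying (what _load_w13 does by default).
full_w1 = torch.randn(intermediate_size, hidden_size)                   # (4, 4)
w1_shard = full_w1.narrow(shard_dim, shard_size * tp_rank, shard_size)  # (2, 4)
expert_data.narrow(shard_dim, 0, shard_size).copy_(w1_shard)            # w1 -> first half of w13

# BNB inflight path: quantization already produced this rank's shard, so the
# narrow above is skipped (load_full=True) and the weight is copied directly.
presharded_w1 = torch.randn(inter_per_rank, hidden_size)                # (2, 4)
expert_data.narrow(shard_dim, 0, shard_size).copy_(presharded_w1)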
