Commit 016b8d1

Enabled BnB NF4 inference on Gaudi (#20172)

Signed-off-by: Ruheena Suhani Shaik <rsshaik@habana.ai>

Parent: 80305c1

2 files changed: 18 additions, 8 deletions

  vllm/model_executor/layers/quantization/bitsandbytes.py
  vllm/model_executor/model_loader/bitsandbytes_loader.py

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 6 additions & 6 deletions

@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op


@@ -390,12 +391,11 @@ def _apply_bnb_4bit_fake(
 try:
-    direct_register_custom_op(
-        op_name="apply_bnb_4bit",
-        op_func=_apply_bnb_4bit,
-        mutates_args=["out"],
-        fake_impl=_apply_bnb_4bit_fake,
-    )
+    direct_register_custom_op(op_name="apply_bnb_4bit",
+                              op_func=_apply_bnb_4bit,
+                              mutates_args=["out"],
+                              fake_impl=_apply_bnb_4bit_fake,
+                              dispatch_key=current_platform.dispatch_key)
     apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit

 except AttributeError as error:
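
A minimal sketch (not vLLM code) of what passing a dispatch key to a custom-op registration does, using torch.library directly. The "my_ext" namespace and "scale" op are illustrative; vLLM's direct_register_custom_op wraps this pattern and, with this change, forwards current_platform.dispatch_key (e.g. "CUDA" or "HPU") so the op is registered for the active backend rather than implicitly for CUDA.

import torch

my_lib = torch.library.Library("my_ext", "FRAGMENT")
my_lib.define("scale(Tensor x, float factor) -> Tensor")

def _scale_impl(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Stand-in kernel; the real op dequantizes NF4 weights and applies them.
    return x * factor

# Register the implementation under an explicit dispatch key so the op only
# resolves to this kernel on that backend ("CPU" here for a portable demo).
my_lib.impl("scale", _scale_impl, "CPU")

print(torch.ops.my_ext.scale(torch.ones(3), 2.0))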

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 12 additions & 2 deletions

@@ -199,6 +199,10 @@ def _get_quantized_weights_iterator(

         if self.pre_quant:
             if self.load_8bit:
+                if current_platform.is_hpu():
+                    raise ValueError(
+                        "currently hpu supports 4bit quantization only")
+
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
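
A hedged sketch of the loader-side guard above; _reject_8bit_on_hpu is an illustrative helper name, not part of vLLM. Prequantized 8-bit BnB checkpoints are rejected on Gaudi, leaving only the 4-bit (NF4) path.

from vllm.platforms import current_platform

def _reject_8bit_on_hpu(load_8bit: bool) -> None:
    # Gaudi (HPU) support added by this commit covers 4-bit NF4 only.
    if load_8bit and current_platform.is_hpu():
        raise ValueError("currently hpu supports 4bit quantization only")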
@@ -302,6 +306,10 @@ def _parse_quant_state(param_name: str,
                     in temp_state_dict):
                 quant_state = _parse_quant_state(mapped_weight_name,
                                                  temp_state_dict)
+                if current_platform.is_hpu():
+                    assert quant_state.quant_type == "nf4", (
+                        "currently hpu supports nf4 quant_type only")
+
                 quant_state_dict[mapped_weight_name] = quant_state
                 yield org_weight_name, weight_tensor
             else:
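
For context, a minimal sketch of where quant_type comes from, assuming bitsandbytes' quantize_4bit API (which needs a CUDA/accelerator-visible tensor); the shapes and dtype here are arbitrary. Each 4-bit quantized parameter carries a QuantState whose quant_type records "nf4" or "fp4", and the new assert restricts Gaudi to "nf4".

import torch
from bitsandbytes.functional import quantize_4bit

weight = torch.randn(256, 256, dtype=torch.float16, device="cuda")
packed, quant_state = quantize_4bit(weight, quant_type="nf4")
assert quant_state.quant_type == "nf4", (
    "currently hpu supports nf4 quant_type only")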
@@ -372,10 +380,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                          ...]

             # bitsandbytes requires data in GPU
-            if weight_sub_tensor.is_cuda:
+            if (weight_sub_tensor.is_cuda
+                    or weight_sub_tensor.device.type == "hpu"):
                 loaded_weight = weight_sub_tensor
             else:
-                loaded_weight = weight_sub_tensor.cuda()
+                loaded_weight = weight_sub_tensor.to(
+                    device=current_platform.device_type)

             # remove the following after the issue is fixed:
             # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
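
A condensed sketch of the new placement logic, assuming vLLM's current_platform singleton; _to_bnb_device is an illustrative helper name, not part of the patch.

import torch
from vllm.platforms import current_platform

def _to_bnb_device(weight_sub_tensor: torch.Tensor) -> torch.Tensor:
    # bitsandbytes needs the data on an accelerator: keep tensors that already
    # live on CUDA or HPU, otherwise move them to the current platform's device.
    if (weight_sub_tensor.is_cuda
            or weight_sub_tensor.device.type == "hpu"):
        return weight_sub_tensor
    return weight_sub_tensor.to(device=current_platform.device_type)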
