
Commit 3fd8c85

jeejeelee authored and jimpang committed
[Quantization] Modify the logic of BNB double quantization (vllm-project#19742)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 028dde7 commit 3fd8c85

File tree: 1 file changed, +25 −3 lines


vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 25 additions & 3 deletions
@@ -492,8 +492,6 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
             raise ValueError("Following weights were not initialized from "
                              f"checkpoint: {weights_not_loaded}")
 
-        torch.cuda.empty_cache()
-
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
@@ -545,6 +543,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
+                # Dequantize double quantized values during weight loading.
+                dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
 
                 pack_ratio = getattr(param, "pack_factor", -1)
@@ -565,6 +565,28 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 if load_8bit:
                     set_weight_attrs(
                         param, {"matmul_state": [None] * len(quant_states)})
-
+        torch.cuda.empty_cache()
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+
+def dequantize_dq(quant_states: dict) -> None:
+    """
+    When BNB employs Double Quantization, we perform the dequantization of
+    these constants during weight loading rather than at inference time,
+    thereby avoiding this computational overhead during inference. This comes
+    at the cost of increased memory usage.
+    """
+    from bitsandbytes.functional import dequantize_blockwise
+    for _, quant_state in quant_states.items():
+        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
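
The new dequantize_dq helper flattens nested (double-quantized) absmax statistics once at load time instead of on every forward pass. As a rough, unofficial sketch of the same idea outside vLLM, the snippet below builds a double-quantized 4-bit weight with bitsandbytes and flattens its QuantState in place; the weight shape, quant_type, and the availability of a CUDA device are illustrative assumptions, not part of the commit.

# Illustrative sketch (not from the commit): flatten a nested QuantState the
# same way dequantize_dq() does. Assumes bitsandbytes is installed and a CUDA
# device is available; the weight shape and quant_type are arbitrary choices.
import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_4bit

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# compress_statistics=True enables double quantization: the per-block absmax
# values are themselves blockwise-quantized, so quant_state.nested is True.
_, quant_state = quantize_4bit(weight, compress_statistics=True,
                               quant_type="nf4")

if quant_state.nested:
    # Recover the full-precision per-block scales once, at load time.
    absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
    absmax += quant_state.offset
    quant_state.absmax = absmax.float()
    quant_state.nested = False
    quant_state.offset = None
    quant_state.state2 = None

Doing this once per weight trades extra memory (float32 per-block scales instead of the compressed 8-bit ones) for skipping the blockwise dequantization of absmax at inference time, which is the trade-off the docstring above describes.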
