@@ -492,8 +492,6 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
             raise ValueError("Following weights were not initialized from "
                              f"checkpoint: {weights_not_loaded}")
 
-        torch.cuda.empty_cache()
-
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
@@ -545,6 +543,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
+                # Dequantize double quantized values during weight loading.
+                dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
 
                 pack_ratio = getattr(param, "pack_factor", -1)
@@ -565,6 +565,28 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 if load_8bit:
                     set_weight_attrs(
                         param, {"matmul_state": [None] * len(quant_states)})
-
+        torch.cuda.empty_cache()
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+
+def dequantize_dq(quant_states: dict) -> None:
+    """
+    When BNB employs Double Quantization, we perform the dequantization of
+    these constants during weight loading rather than at inference time,
+    thereby avoiding this computational overhead during inference. This comes
+    at the cost of increased memory usage.
+    """
+    from bitsandbytes.functional import dequantize_blockwise
+    for _, quant_state in quant_states.items():
+        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
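For context, here is a minimal sketch (not part of the diff) of what the new dequantize_dq helper does to a double-quantized QuantState produced by bitsandbytes. It assumes bitsandbytes >= 0.45, a CUDA device, and that dequantize_dq from this patch is in scope (the import path depends on the vLLM version); quantize_4bit with compress_statistics=True is what creates the nested state that the helper flattens.

# Hypothetical usage sketch; assumes a CUDA device, bitsandbytes >= 0.45,
# and that dequantize_dq from this patch is importable or defined above.
import torch
from bitsandbytes.functional import quantize_4bit

# compress_statistics=True enables Double Quantization: the per-block absmax
# values are themselves quantized, and the metadata for that second pass is
# kept in quant_state.state2 / quant_state.offset (quant_state.nested is True).
weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
_, quant_state = quantize_4bit(weight,
                               compress_statistics=True,
                               quant_type="nf4")
assert quant_state.nested and quant_state.state2 is not None

# dequantize_dq flattens the state in place: absmax becomes a plain float32
# tensor and the second-level metadata is dropped, so inference no longer has
# to run dequantize_blockwise on every forward pass.
dequantize_dq({0: quant_state})
assert quant_state.nested is False
assert quant_state.absmax.dtype == torch.float32
assert quant_state.offset is None and quant_state.state2 is None

The trade-off matches the docstring: the fully materialized float32 absmax uses more GPU memory than the compressed form, in exchange for skipping the nested dequantization at inference time.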