@@ -642,7 +642,9 @@ def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0):
        if len(output.shape) <= 2:
            max_value = torch.max(torch.abs(output))
        else:
-            max_value = torch.max(torch.abs(output.reshape(output.shape[0], -1)), dim=-1).values
+            output = output.reshape(output.shape[0], -1)
+            output_q = output_q.reshape(output_q.shape[0], -1)
+            max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
        max_value = torch.clip(max_value, 1e-5)
        output = output / max_value  ##FIXME need copy not replace
        output_q = output_q / max_value
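A minimal sketch of the per-sample normalization this hunk introduces, assuming an abs-style reduction at the end (the helper name and the final torch.sum are illustrative, not the exact body of _get_auto_loss):

    import torch

    def normalized_abs_loss(output: torch.Tensor, output_q: torch.Tensor) -> torch.Tensor:
        # For >2-D activations, flatten everything past the batch dimension and take
        # one max per sample; unsqueeze keeps it as a column so the division broadcasts row-wise.
        if len(output.shape) <= 2:
            max_value = torch.max(torch.abs(output))
        else:
            output = output.reshape(output.shape[0], -1)
            output_q = output_q.reshape(output_q.shape[0], -1)
            max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
        max_value = torch.clip(max_value, 1e-5)  # guard against all-zero samples
        return torch.sum(torch.abs(output / max_value - output_q / max_value))

The added unsqueeze(-1) lets the per-sample maxima broadcast against the flattened (batch, features) tensors; the old one-liner returned a 1-D vector of maxima that broadcast against the last dimension of the unflattened output instead of per sample.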
@@ -712,7 +714,7 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales):
                weight_scale = self._reshape_scale_for_weight(layer, weight_scale)
                layer.update_scale(input_scale, weight_scale)  ##FIXME

-    def _get_one_sample_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
+    def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
        self._change_qdq_for_auto(enable=False)

        forward_wrapper(self.model, input, self.device)  ##disable quant and get fp32 output
@@ -793,15 +795,15 @@ def dict_to_list(dic):
        return best_alpha

    def _auto_tune_alpha_new(
-        self, input_maxes, auto_calib_iter=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
+        self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
    ):
        """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.

        This function takes quantization of the former layers into consideration when qdq one layer
        Also, it reduces the memory usage at the cost of increasing tuning time
        TODO may have compatibility issue when setting folding=True
        :param input_maxes:
-        :param auto_calib_iter:
+        :param calib_sample_num:
        :param alpha_min:
        :param alpha_max:
        :param alpha_step:
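For context on the renamed arguments: calib_sample_num now counts calibration samples rather than tuning iterations, while alpha_min, alpha_max, and alpha_step define the per-layer candidate grid. A hedged sketch of how such a grid can be built from the defaults above (the actual construction of alpha_space sits outside this hunk):

    alpha_min, alpha_max, alpha_step = 0.3, 0.7, 0.05
    n_steps = round((alpha_max - alpha_min) / alpha_step)
    # candidate alphas swept per layer: 0.3, 0.35, ..., 0.7
    alpha_space = [round(alpha_min + i * alpha_step, 2) for i in range(n_steps + 1)]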
@@ -828,88 +830,82 @@ def _auto_tune_alpha_new(
            self.absorb_to_layer, input_maxes, default_alpha, tuning=True
        )
        self._update_scales_for_auto(absorb_input_scales, weight_scales)
-        loss_alphas = {}
        cnt = 0
-        multiply_factor = auto_calib_iter // 4 if auto_calib_iter >= 4 else auto_calib_iter
+        alpha_update_iter = 0
+        # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
+        multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num

        best_alphas = default_alpha
        if not self.dataloader:
            self._qdq_model_unwrapper_for_auto()
            return best_alphas
        try:
            for input, label in self.dataloader:
+                loss_alphas = {}
                best_alphas_per_module = best_alphas
                if isinstance(best_alphas, dict):
                    for key in self.absorb_to_layer.keys():
                        layer_names = self.absorb_to_layer[key]
                        for layer_name in layer_names:
                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]

-                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
                if loss_alphas == {}:
                    loss_alphas = loss_tmp
                else:
                    for key in loss_alphas.keys():
                        cur_loss = loss_alphas[key]
                        for alpha_key in cur_loss.keys():
                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                if isinstance(input, list):
-                    input = move_input_to_device(input, self.device)
-                    for inp in input:
-                        cnt += inp.shape[0]
-                else:
-                    cnt += input.shape[0]
-
-                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                cnt += self.dataloader.batch_size
+                if cnt // multiply_factor >= 1:
+                    alpha_update_iter += 1
+                    cnt = 0
                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                    for key in best_alphas.keys():
-                        logger.info(f"{cnt // multiply_factor}, {key}:{best_alphas[key]}")
+                        logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                    absorb_input_scales, weight_scales = self._cal_scales(
                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                    )
                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                    loss_alphas = {}  ##TODO check need to remove this one
-                if cnt >= auto_calib_iter:
+                if cnt >= calib_sample_num:
                    break
        except:
            for input in self.dataloader:
+                loss_alphas = {}
                best_alphas_per_module = best_alphas
                if isinstance(best_alphas, dict):
                    for key in self.absorb_to_layer.keys():
                        layer_names = self.absorb_to_layer[key]
                        for layer_name in layer_names:
                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]

-                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
                if loss_alphas == {}:
                    loss_alphas = loss_tmp
                else:
                    for key in loss_alphas.keys():
                        cur_loss = loss_alphas[key]
                        for alpha_key in cur_loss.keys():
                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                if isinstance(input, list):
-                    input = move_input_to_device(input, self.device)
-                    for inp in input:
-                        cnt += inp.shape[0]
-                else:
-                    cnt += input.shape[0]
+                cnt += self.dataloader.batch_size
+                if cnt // multiply_factor >= 1:
+                    alpha_update_iter += 1
+                    cnt = 0

-                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                    for key in best_alphas.keys():
-                        logger.info(f"{cnt // multiply_factor}, {key}:{best_alphas[key]}")
+                        logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                    absorb_input_scales, weight_scales = self._cal_scales(
                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                    )
                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                    loss_alphas = {}  ##TODO check need to remove this one
-                if cnt >= auto_calib_iter:
+                if cnt >= calib_sample_num:
                    break

        best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
        for key in best_alphas.keys():
-            logger.info(f"final {key}:{best_alphas[key]}")
+            logger.info(f"Final alpha {key}:{best_alphas[key]}")
        self._qdq_model_unwrapper_for_auto()
        logger.info("auto tuning done")
        return best_alphas
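The reworked bookkeeping counts samples by dataloader batch size and refreshes the per-layer alphas once every multiply_factor samples, roughly four updates across calib_sample_num. A standalone sketch of just that counting logic, with the loss accumulation and scale refresh elided as comments (batch_size=4 and the range-driven loop are stand-ins for the real dataloader):

    calib_sample_num = 32
    batch_size = 4  # stand-in for self.dataloader.batch_size
    multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num

    cnt, alpha_update_iter = 0, 0
    for batch_idx in range(calib_sample_num // batch_size):
        # accumulate loss_alphas for every candidate alpha on this batch here
        cnt += batch_size
        if cnt // multiply_factor >= 1:  # an update fires every multiply_factor samples
            alpha_update_iter += 1
            cnt = 0
            # summarize the best alphas from the accumulated losses and refresh the scales here
            print(f"Auto alpha update iter: {alpha_update_iter} after batch {batch_idx}")

With calib_sample_num=32 and batch_size=4, updates land after batches 1, 3, 5, and 7, i.e. once every 8 samples.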
@@ -999,7 +995,7 @@ def transform(

        if alpha == "auto":
            self.alpha_per_layer = self._auto_tune_alpha_new(
-                input_maxes_abs, auto_calib_iter=32, **auto_alpha_args
+                input_maxes_abs, calib_sample_num=32, **auto_alpha_args
            )  ##save the alpha

        if alpha == "auto":