@@ -794,20 +794,20 @@ def dict_to_list(dic):
                 raise NotImplementedError
         return best_alpha
 
-    def _auto_tune_alpha_new(
+    def _auto_tune_alpha(
         self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
     ):
         """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.
 
         This function takes quantization of the former layers into consideration when qdq one layer
         Also, it reduces the memory usage at the cost of increasing tuning time
-        TODO may have compatibility issue when setting folding=True
-        :param input_maxes:
-        :param calib_sample_num:
-        :param alpha_min:
-        :param alpha_max:
-        :param alpha_step:
-        :param shared_criterion:
+        TODO may have compatibility issues when setting folding=True; check whether there are issues when bs != 1
+        :param input_maxes: calibration data, input max values
+        :param calib_sample_num: sample count used for auto-tuning alpha
+        :param alpha_min: the minimum value of alpha
+        :param alpha_max: the maximum value of alpha
+        :param alpha_step: the alpha step size in the search space
+        :param shared_criterion: the criterion for choosing alpha when multiple layers must share the same alpha
         :return:
         """
         logger.info("start sq auto tuning")
@@ -830,13 +830,16 @@ def _auto_tune_alpha_new(
             self.absorb_to_layer, input_maxes, default_alpha, tuning=True
         )
         self._update_scales_for_auto(absorb_input_scales, weight_scales)
-        cnt = 0
+        total_cnt = 0
+        tmp_cnt = 0
         alpha_update_iter = 0
         # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
-        multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num
+        tune_cnt = 4
+        multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num
 
         best_alphas = default_alpha
         if not self.dataloader:
+            logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
         try:
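The counters introduced above work as follows: total_cnt tracks how many calibration samples have been consumed overall, tmp_cnt tracks samples since the last alpha refresh, and multiply_factor (calib_sample_num // tune_cnt) sets how often the best alpha is re-estimated. A rough standalone sketch of just this bookkeeping, assuming batch_size=1 (not the actual implementation):

# Counter bookkeeping only; names mirror the patch, logic is simplified.
calib_sample_num, tune_cnt, batch_size = 32, 4, 1
multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num  # 8
total_cnt = tmp_cnt = alpha_update_iter = 0
while True:
    total_cnt += batch_size
    tmp_cnt += batch_size
    if tmp_cnt // multiply_factor >= 1:  # every ~8 samples, re-pick the best alpha
        alpha_update_iter += 1
        tmp_cnt = 0
    if total_cnt >= calib_sample_num:  # stop after 32 samples -> 4 alpha updates in total
        break
assert alpha_update_iter == tune_cnt  # 4 updates with these defaults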
@@ -857,18 +860,19 @@ def _auto_tune_alpha_new(
                         cur_loss = loss_alphas[key]
                         for alpha_key in cur_loss.keys():
                             cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                cnt += self.dataloader.batch_size
-                if cnt // multiply_factor >= 1:
+                total_cnt += self.dataloader.batch_size
+                tmp_cnt += self.dataloader.batch_size
+                if tmp_cnt // multiply_factor >= 1:
                     alpha_update_iter += 1
-                    cnt = 0
+                    tmp_cnt = 0
                     best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                     for key in best_alphas.keys():
                         logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                     absorb_input_scales, weight_scales = self._cal_scales(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                if cnt >= calib_sample_num:
+                if total_cnt >= calib_sample_num:
                     break
         except:
             for input in self.dataloader:
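For reference, loss_alphas in the loop above is a nested dict mapping each layer to its accumulated per-alpha loss, and the per-batch loss_tmp has the same shape and is summed into it. A small illustration with made-up layer and alpha keys (the real key format is not shown in this hunk):

# Illustrative shapes only; the layer name and key types are hypothetical.
loss_alphas = {"decoder.layers.0.fc1": {"0.3": 0.0, "0.5": 0.0, "0.7": 0.0}}
loss_tmp = {"decoder.layers.0.fc1": {"0.3": 1.2, "0.5": 0.9, "0.7": 1.5}}  # one batch's losses
for key in loss_alphas.keys():
    cur_loss = loss_alphas[key]
    for alpha_key in cur_loss.keys():
        cur_loss[alpha_key] += loss_tmp[key][alpha_key]
# _get_best_alpha then picks the lowest-loss alpha per layer; shared_criterion ("min" by default)
# resolves the choice when several layers absorbed into the same parent must share one alpha.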
@@ -888,10 +892,11 @@ def _auto_tune_alpha_new(
                         cur_loss = loss_alphas[key]
                         for alpha_key in cur_loss.keys():
                             cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                cnt += self.dataloader.batch_size
-                if cnt // multiply_factor >= 1:
+                total_cnt += self.dataloader.batch_size
+                tmp_cnt += self.dataloader.batch_size
+                if tmp_cnt // multiply_factor >= 1:
                     alpha_update_iter += 1
-                    cnt = 0
+                    tmp_cnt = 0
 
                     best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                     for key in best_alphas.keys():
@@ -900,7 +905,7 @@ def _auto_tune_alpha_new(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                if cnt >= calib_sample_num:
+                if total_cnt >= calib_sample_num:
                     break
 
         best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
@@ -934,7 +939,6 @@ def transform(
             logger.warning("smooth quant is ignored since the model is not a torch module")
             return self.model
 
-        logger.info("call new sq")  ##TODO need to remove later
         if folding:
             self.insert_mul, self.allow_absorb = False, True
         else:
@@ -994,7 +998,7 @@ def transform(
                 del self.absorb_to_layer[d]
 
         if alpha == "auto":
-            self.alpha_per_layer = self._auto_tune_alpha_new(
+            self.alpha_per_layer = self._auto_tune_alpha(
                 input_maxes_abs, calib_sample_num=32, **auto_alpha_args
             )  ##save the alpha
 
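Finally, a hedged sketch of how a caller reaches the renamed _auto_tune_alpha through transform(alpha="auto"). The TorchSmoothQuant class name, import path, and constructor signature are assumptions about the surrounding file; only alpha, folding, and auto_alpha_args are confirmed by the diff above:

# Assumed import path and class; treat this as a usage sketch, not the library's documented API.
import torch
from torch.utils.data import DataLoader, TensorDataset
from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant  # assumed

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
dataloader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=1)

sq = TorchSmoothQuant(model, dataloader)  # assumed constructor
model = sq.transform(
    alpha="auto",   # routes to self._auto_tune_alpha(input_maxes_abs, calib_sample_num=32, **auto_alpha_args)
    folding=False,  # folding=True is flagged as a compatibility TODO in the new docstring
    auto_alpha_args={"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "shared_criterion": "min"},
)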