
Commit e9c14a5

xin3he authored and chensuyue committed
fix bug in smoothquant for auto alpha (#1287)
(cherry picked from commit 496bd60)
1 parent 424cf3a commit e9c14a5

File tree: 2 files changed, +45 -45 lines

neural_compressor/adaptor/torch_utils/smooth_quant.py

Lines changed: 27 additions & 31 deletions
@@ -642,7 +642,9 @@ def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0):
         if len(output.shape) <= 2:
             max_value = torch.max(torch.abs(output))
         else:
-            max_value = torch.max(torch.abs(output.reshape(output.shape[0], -1)), dim=-1).values
+            output = output.reshape(output.shape[0], -1)
+            output_q = output_q.reshape(output_q.shape[0], -1)
+            max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
         max_value = torch.clip(max_value, 1e-5)
         output = output / max_value ##FIXME need copy not replace
         output_q = output_q / max_value
@@ -712,7 +714,7 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales):
                 weight_scale = self._reshape_scale_for_weight(layer, weight_scale)
                 layer.update_scale(input_scale, weight_scale) ##FIXME
 
-    def _get_one_sample_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
+    def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
         self._change_qdq_for_auto(enable=False)
 
         forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output
@@ -793,15 +795,15 @@ def dict_to_list(dic):
         return best_alpha
 
     def _auto_tune_alpha_new(
-        self, input_maxes, auto_calib_iter=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
+        self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
     ):
         """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.
 
         This function takes quantization of the former layers into consideration when qdq one layer
         Also, it reduces the memory usage at the cost of increasing tuning time
         TODO may have compatibility issue when setting folding=True
         :param input_maxes:
-        :param auto_calib_iter:
+        :param calib_sample_num:
         :param alpha_min:
         :param alpha_max:
         :param alpha_step:
@@ -828,88 +830,82 @@ def _auto_tune_alpha_new(
             self.absorb_to_layer, input_maxes, default_alpha, tuning=True
         )
         self._update_scales_for_auto(absorb_input_scales, weight_scales)
-        loss_alphas = {}
         cnt = 0
-        multiply_factor = auto_calib_iter // 4 if auto_calib_iter >= 4 else auto_calib_iter
+        alpha_update_iter = 0
+        # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
+        multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num
 
         best_alphas = default_alpha
         if not self.dataloader:
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
         try:
             for input, label in self.dataloader:
+                loss_alphas = {}
                 best_alphas_per_module = best_alphas
                 if isinstance(best_alphas, dict):
                     for key in self.absorb_to_layer.keys():
                         layer_names = self.absorb_to_layer[key]
                         for layer_name in layer_names:
                             best_alphas_per_module[layer_name] = best_alphas_per_module[key]
 
-                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
                 if loss_alphas == {}:
                     loss_alphas = loss_tmp
                 else:
                     for key in loss_alphas.keys():
                         cur_loss = loss_alphas[key]
                         for alpha_key in cur_loss.keys():
                             cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                if isinstance(input, list):
-                    input = move_input_to_device(input, self.device)
-                    for inp in input:
-                        cnt += inp.shape[0]
-                else:
-                    cnt += input.shape[0]
-
-                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                cnt += self.dataloader.batch_size
+                if cnt // multiply_factor >= 1:
+                    alpha_update_iter += 1
+                    cnt = 0
                     best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                     for key in best_alphas.keys():
-                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                        logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                     absorb_input_scales, weight_scales = self._cal_scales(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                    loss_alphas = {} ##TODO check need to remove this one
-                if cnt >= auto_calib_iter:
+                if cnt >= calib_sample_num:
                     break
         except:
             for input in self.dataloader:
+                loss_alphas = {}
                 best_alphas_per_module = best_alphas
                 if isinstance(best_alphas, dict):
                     for key in self.absorb_to_layer.keys():
                         layer_names = self.absorb_to_layer[key]
                         for layer_name in layer_names:
                             best_alphas_per_module[layer_name] = best_alphas_per_module[key]
 
-                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
                 if loss_alphas == {}:
                     loss_alphas = loss_tmp
                 else:
                     for key in loss_alphas.keys():
                         cur_loss = loss_alphas[key]
                         for alpha_key in cur_loss.keys():
                             cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                if isinstance(input, list):
-                    input = move_input_to_device(input, self.device)
-                    for inp in input:
-                        cnt += inp.shape[0]
-                else:
-                    cnt += input.shape[0]
+                cnt += self.dataloader.batch_size
+                if cnt // multiply_factor >= 1:
+                    alpha_update_iter += 1
+                    cnt = 0
 
-                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
                     best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                     for key in best_alphas.keys():
-                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                        logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                     absorb_input_scales, weight_scales = self._cal_scales(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                    loss_alphas = {} ##TODO check need to remove this one
-                if cnt >= auto_calib_iter:
+                if cnt >= calib_sample_num:
                     break
 
         best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
         for key in best_alphas.keys():
-            logger.info(f"final {key}:{best_alphas[key]}")
+            logger.info(f"Final alpha {key}:{best_alphas[key]}")
         self._qdq_model_unwrapper_for_auto()
         logger.info("auto tuning done")
         return best_alphas
@@ -999,7 +995,7 @@ def transform(
 
         if alpha == "auto":
             self.alpha_per_layer = self._auto_tune_alpha_new(
-                input_maxes_abs, auto_calib_iter=32, **auto_alpha_args
+                input_maxes_abs, calib_sample_num=32, **auto_alpha_args
             ) ##save the alpha
 
         if alpha == "auto":
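
For readers who want the effect of the _get_auto_loss change in isolation, here is a small standalone sketch in the spirit of the fixed code: both the FP32 output and its quantized counterpart are flattened per sample and divided by the same per-sample maximum before the error is accumulated, so the two tensors stay comparably scaled. The reduction to a summed absolute error at the end is an assumption for illustration; the library supports other loss types and exponents.

import torch

def per_sample_abs_loss(output: torch.Tensor, output_q: torch.Tensor) -> torch.Tensor:
    """Sketch of the corrected normalization in _get_auto_loss (simplified)."""
    if output.dim() <= 2:
        max_value = torch.max(torch.abs(output))
    else:
        # flatten each sample, then take a per-sample max kept as a column
        # so it broadcasts over every element of that sample
        output = output.reshape(output.shape[0], -1)
        output_q = output_q.reshape(output_q.shape[0], -1)
        max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
    max_value = torch.clip(max_value, 1e-5)
    output = output / max_value
    output_q = output_q / max_value
    # simplified "abs" loss over the normalized tensors
    return torch.sum(torch.abs(output - output_q))

The same commit also changes how calibration progress is counted: instead of inspecting input shapes, _auto_tune_alpha_new now advances cnt by self.dataloader.batch_size and refreshes the per-layer best alphas once every multiply_factor samples, which is why the test dataloaders below all expose a batch_size attribute.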

test/algorithm/test_smooth_quant.py

Lines changed: 18 additions & 14 deletions
@@ -53,18 +53,19 @@ def __iter__(self):
 
 class LLMCalibDataloader:
     def __init__(self):
-        self.batch_size = 1
+        self.batch_size = 3
 
     def __iter__(self):
-        yield torch.ones([1, 3], dtype=torch.long)
+        for i in range(4):
+            yield torch.ones([3, 3], dtype=torch.long)
 
 
 class TestSqDepthwiseConv(unittest.TestCase):
     @classmethod
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 3, 1, 1))
@@ -141,7 +142,7 @@ class TestSqConvOpFuseAuto(unittest.TestCase):
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 3, 1, 1))
@@ -181,7 +182,7 @@ class TestSqConvOpFuse(unittest.TestCase):
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 3, 1, 1))
@@ -386,21 +387,21 @@ class TestSqListInput(unittest.TestCase):
     def setUpClass(self):
         class ListDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield [torch.rand((1, 3))]
 
         class TupleDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield (torch.rand((1, 3)))
 
         class ListTupleDataLoader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 input1 = torch.rand((1, 3))
@@ -499,7 +500,7 @@ class TestAlphaAutoLinear(unittest.TestCase):
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 3))
@@ -535,7 +536,7 @@ class TestSqLinearOpFuse(unittest.TestCase):
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 3))
@@ -736,6 +737,8 @@ def test_sq_qkv(self):
         sq.transform(alpha=0.5, calib_iter=-1, folding=False)
         assert isinstance(sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper)
 
+
+class TestExample(unittest.TestCase):
     def test_sq_quant(self):
         from neural_compressor import PostTrainingQuantConfig, quantization
 
@@ -763,10 +766,11 @@ def forward(self, x):
 
         class CalibDataloader:
             def __init__(self):
-                self.batch_size = 1
+                self.batch_size = 3
 
             def __iter__(self):
-                yield input_ids
+                for i in range(4):
+                    yield input_ids
 
         def calib_func(model):
             for i in range(10):
@@ -935,7 +939,7 @@ class TestSqSkipOp(unittest.TestCase):
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 4))
@@ -992,7 +996,7 @@ class TestSqSkipOp_attn(unittest.TestCase):
     def setUpClass(self):
         class RandDataloader:
             def __init__(self):
-                pass
+                self.batch_size = 1
 
             def __iter__(self):
                 yield torch.rand((1, 4))
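
As a usage reference, the updated calibration dataloaders in this test file follow the pattern below: an iterable that exposes a batch_size attribute matching the shape of the batches it yields, since the auto-alpha loop now counts calibration samples via dataloader.batch_size. This is a minimal sketch mirroring the LLMCalibDataloader above (the class name here is illustrative), not an additional test.

import torch

class ToyCalibDataloader:
    """Minimal calibration dataloader: batch_size must match the first
    dimension of each yielded batch so sample counting stays correct."""

    def __init__(self):
        self.batch_size = 3  # three samples per batch

    def __iter__(self):
        for _ in range(4):  # 4 batches -> 12 calibration samples in total
            yield torch.ones([3, 3], dtype=torch.long)

In the tests this kind of dataloader is handed to the smooth-quant calibration path (for example via quantization.fit with a PostTrainingQuantConfig), which is where batch_size is read during auto alpha tuning.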
