@@ -295,7 +295,7 @@ def __torch_function__(
295
295
SQNR (DQ , DQ_from_qtensor ),
296
296
)
297
297
298
- qparams2 = cls .get_qparams_func (W )
298
+ qparams2 = cls .get_qparams_func (W , W . dtype )
299
299
Q2 = cls .quantize_func (W , qparams2 )
300
300
DQ2 = cls .dequantize_func (Q2 , qparams2 ).to (W .dtype )
301
301
old_q_out = (
@@ -444,7 +444,9 @@ def faster_quant(cls, H, W, device):
444
444
group_end = min (group_start + group_size , columns )
445
445
if group_start % group_size == 0 :
446
446
# needed for when group_size == columns so only calculate qparams once
447
- cur_qparams = cls .get_qparams_func (W [:, group_start :group_end ])
447
+ cur_qparams = cls .get_qparams_func (
448
+ W [:, group_start :group_end ], orig_dtype
449
+ )
448
450
all_qparams .append (cur_qparams )
449
451
450
452
for index in range (group_start , group_end ): # within each group
@@ -679,10 +681,11 @@ def __init__(
679
681
else :
680
682
self .zero_point_domain = ZeroPointDomain .FLOAT
681
683
682
- self .get_qparams_func = lambda w : get_groupwise_affine_qparams (
684
+ self .get_qparams_func = lambda w , precision : get_groupwise_affine_qparams (
683
685
w ,
684
686
n_bit ,
685
687
group_size ,
688
+ dtype = precision ,
686
689
zero_point_domain = self .zero_point_domain ,
687
690
)
688
691
self .quantize_func = (
0 commit comments