56
56
import tempfile
57
57
import gc
58
58
from torch .testing ._internal .common_utils import TestCase
59
+ from torch .testing ._internal import common_utils
59
60
60
61
61
62
def dynamic_quant (model , example_inputs ):
@@ -500,12 +501,13 @@ def test_eval_wrapper_llama3(self):
500
501
501
502
# TODO: move to a separate test file
502
503
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "Test only enabled for 2.4+" )
503
- def test_quantized_tensor_subclass_8da4w (self ):
504
+ @common_utils .parametrize ("mapping_type" , [MappingType .SYMMETRIC , MappingType .SYMMETRIC_NO_CLIPPING_ERR ])
505
+ def test_quantized_tensor_subclass_8da4w (self , mapping_type ):
504
506
group_size = 32
505
507
m = ToyLinearModel ().eval ()
506
508
m_copy = copy .deepcopy (m )
507
509
example_inputs = m .example_inputs ()
508
- quantize_ (m , int8_dynamic_activation_int4_weight (group_size = group_size ))
510
+ quantize_ (m , int8_dynamic_activation_int4_weight (group_size = group_size , mapping_type = mapping_type ))
509
511
510
512
assert isinstance (m .linear1 .weight , LinearActivationQuantizedTensor )
511
513
assert isinstance (m .linear2 .weight , LinearActivationQuantizedTensor )
@@ -516,7 +518,7 @@ def test_quantized_tensor_subclass_8da4w(self):
516
518
from torchao .quantization .quant_api import Int8DynActInt4WeightQuantizer
517
519
from torchao .quantization .GPTQ import Int8DynActInt4WeightLinear
518
520
519
- quantizer = Int8DynActInt4WeightQuantizer (groupsize = group_size )
521
+ quantizer = Int8DynActInt4WeightQuantizer (groupsize = group_size , mapping_type = mapping_type )
520
522
m_copy = quantizer .quantize (m_copy )
521
523
assert isinstance (m_copy .linear1 , Int8DynActInt4WeightLinear )
522
524
assert isinstance (m_copy .linear2 , Int8DynActInt4WeightLinear )
@@ -704,6 +706,8 @@ def reset_memory():
704
706
assert param .is_cuda
705
707
self .assertLess (memory_streaming , memory_baseline )
706
708
709
+ common_utils .instantiate_parametrized_tests (TestQuantFlow )
710
+
707
711
708
712
if __name__ == "__main__" :
709
713
unittest .main ()
0 commit comments