1616)
1717
1818from  torchao .quantization  import  (
19-     float8_dynamic_activation_float8_weight ,
20-     float8_weight_only ,
21-     int4_weight_only ,
22-     int8_dynamic_activation_int8_weight ,
23-     int8_weight_only ,
19+     Float8DynamicActivationFloat8WeightConfig ,
20+     Float8WeightOnlyConfig ,
21+     Int4WeightOnlyConfig ,
22+     Int8DynamicActivationInt8WeightConfig ,
23+     Int8WeightOnlyConfig ,
2424)
2525from  torchao .quantization .observer  import  PerRow , PerTensor 
2626from  torchao .quantization .quant_api  import  quantize_ 
4242class  TestAffineQuantizedTensorParallel (DTensorTestBase ):
4343    """Basic test case for tensor subclasses""" 
4444
45-     QUANT_METHOD_FN  =  staticmethod (int8_weight_only )
45+     QUANT_METHOD_FN  =  staticmethod (Int8WeightOnlyConfig )
4646    QUANT_METHOD_KWARGS  =  {}
4747
4848    @staticmethod  
@@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
133133
134134
135135class  TestInt8woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
136-     QUANT_METHOD_FN  =  staticmethod (int8_weight_only )
136+     QUANT_METHOD_FN  =  staticmethod (Int8WeightOnlyConfig )
137137    COMMON_DTYPES  =  [torch .bfloat16 , torch .float16 , torch .float32 ]
138138
139139    @common_utils .parametrize ("dtype" , COMMON_DTYPES ) 
@@ -144,7 +144,7 @@ def test_tp(self, dtype):
144144
145145
146146class  TestInt4woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
147-     QUANT_METHOD_FN  =  staticmethod (int4_weight_only )
147+     QUANT_METHOD_FN  =  staticmethod (Int4WeightOnlyConfig )
148148    QUANT_METHOD_KWARGS  =  {"version" : 1 }
149149    COMMON_DTYPES  =  [torch .bfloat16 ]
150150
@@ -167,20 +167,20 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
167167    @unittest .skipIf (not  torch .cuda .is_available (), "Need CUDA available" ) 
168168    @unittest .skipIf (not  has_gemlite , "gemlite not available" ) 
169169    def  test_tp_gemlite (self , dtype ):
170-         from  torchao .quantization  import  gemlite_uintx_weight_only 
170+         from  torchao .quantization  import  GemliteUIntXWeightOnlyConfig 
171171
172172        for  packing_bitwidth  in  [32 , 8 ]:
173173            for  bit_width  in  [4 , 8 ]:
174174                for  group_size  in  [64 , 32 , None ] if  bit_width  ==  4  else  [None ]:
175-                     api  =  lambda : gemlite_uintx_weight_only (
175+                     api  =  lambda : GemliteUIntXWeightOnlyConfig (
176176                        group_size , bit_width , packing_bitwidth 
177177                    )
178178                    self .QUANT_METHOD_FN  =  staticmethod (api )
179179                    return  self ._test_tp (dtype )
180180
181181
182182class  TestInt8dqAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
183-     QUANT_METHOD_FN  =  staticmethod (int8_dynamic_activation_int8_weight )
183+     QUANT_METHOD_FN  =  staticmethod (Int8DynamicActivationInt8WeightConfig )
184184    COMMON_DTYPES  =  [torch .bfloat16 ]
185185
186186    @common_utils .parametrize ("dtype" , COMMON_DTYPES ) 
@@ -199,7 +199,7 @@ def test_tp(self, dtype):
199199if  torch .cuda .is_available () and  torch .cuda .get_device_capability () >=  (9 , 0 ):
200200
201201    class  TestFloat8woAffineQuantizedTensorParallel (TestAffineQuantizedTensorParallel ):
202-         QUANT_METHOD_FN  =  staticmethod (float8_weight_only )
202+         QUANT_METHOD_FN  =  staticmethod (Float8WeightOnlyConfig )
203203        COMMON_DTYPES  =  [torch .bfloat16 , torch .float16 , torch .float32 ]
204204
205205        @common_utils .parametrize ("dtype" , COMMON_DTYPES ) 
@@ -211,7 +211,7 @@ def test_tp(self, dtype):
211211    class  TestFloat8dqTensorAffineQuantizedTensorParallel (
212212        TestAffineQuantizedTensorParallel 
213213    ):
214-         QUANT_METHOD_FN  =  staticmethod (float8_dynamic_activation_float8_weight )
214+         QUANT_METHOD_FN  =  staticmethod (Float8DynamicActivationFloat8WeightConfig )
215215        QUANT_METHOD_KWARGS  =  {"granularity" : PerTensor ()}
216216        COMMON_DTYPES  =  [torch .bfloat16 , torch .float16 , torch .float32 ]
217217
@@ -224,7 +224,7 @@ def test_tp(self, dtype):
224224    class  TestFloat8dqRowAffineQuantizedTensorParallel (
225225        TestAffineQuantizedTensorParallel 
226226    ):
227-         QUANT_METHOD_FN  =  staticmethod (float8_dynamic_activation_float8_weight )
227+         QUANT_METHOD_FN  =  staticmethod (Float8DynamicActivationFloat8WeightConfig )
228228        QUANT_METHOD_KWARGS  =  {"granularity" : PerRow ()}
229229        COMMON_DTYPES  =  [torch .bfloat16 ]
230230
0 commit comments