@@ -145,13 +145,15 @@ def _int8da_int8w_api(
145
145
change_linear_weights_to_int8_dqtensors (mod )
146
146
147
147
148
- def _int4wo_api (mod ):
148
+ def _int4wo_api (mod , use_hqq = False ):
149
149
if (
150
150
is_device (next (mod .parameters ()).device .type , "cpu" )
151
151
and TORCH_VERSION_AT_LEAST_2_6
152
152
):
153
153
quantize_ (
154
- mod , int4_weight_only (layout = Int4CPULayout ()), set_inductor_config = False
154
+ mod ,
155
+ int4_weight_only (layout = Int4CPULayout (), use_hqq = use_hqq ),
156
+ set_inductor_config = False ,
155
157
)
156
158
unwrap_tensor_subclass (mod )
157
159
elif TORCH_VERSION_AT_LEAST_2_4 :
@@ -1049,8 +1051,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
1049
1051
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
1050
1052
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
1051
1053
def test_int4_weight_only_quant_subclass_api (self , device , dtype ):
1052
- if device == "cpu" :
1053
- self .skipTest (f"Temporarily skipping for { device } " )
1054
1054
if dtype != torch .bfloat16 :
1055
1055
self .skipTest (f"Fails for { dtype } " )
1056
1056
for test_shape in [(16 , 1024 , 16 )] + (
@@ -1060,6 +1060,20 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
1060
1060
_int4wo_api , device , 15 , test_shape = test_shape , test_dtype = dtype
1061
1061
)
1062
1062
1063
+ @parameterized .expand (COMMON_DEVICE_DTYPE )
1064
+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "int4 hqq requires torch nightly." )
1065
+ def test_int4_weight_only_hqq_quant_subclass_api (self , device , dtype ):
1066
+ if dtype != torch .bfloat16 :
1067
+ self .skipTest (f"Fails for { dtype } " )
1068
+ for test_shape in [(16 , 1024 , 16 ), (1 , 1024 , 256 )]:
1069
+ api = partial (
1070
+ _int4wo_api ,
1071
+ use_hqq = True ,
1072
+ )
1073
+ self ._test_lin_weight_subclass_api_impl (
1074
+ api , device , 15 , test_shape = test_shape , test_dtype = dtype
1075
+ )
1076
+
1063
1077
@parameterized .expand (COMMON_DEVICE_DTYPE )
1064
1078
@unittest .skipIf (
1065
1079
not TORCH_VERSION_AT_LEAST_2_5 , "gemlite tests needs torch 2.5 or greater"
@@ -1111,8 +1125,6 @@ def test_gemlite_layout(self, device, dtype):
1111
1125
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
1112
1126
@skip_if_rocm ("ROCm enablement in progress" )
1113
1127
def test_int4_weight_only_quant_subclass_api_grouped (self , device , dtype ):
1114
- if device == "cpu" :
1115
- self .skipTest (f"Temporarily skipping for { device } " )
1116
1128
if dtype != torch .bfloat16 :
1117
1129
self .skipTest (f"Fails for { dtype } " )
1118
1130
layout_list = []
0 commit comments