Commit 2cf8fda
Enable the CPU int4 with HQQ quant (#1824)

* Enable the CPU int4 with HQQ quant
* ruff check
* format code

1 parent b6db962 · commit 2cf8fda
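
For context, a minimal sketch of the user-facing call this commit enables, mirroring the updated tests below. It assumes PyTorch >= 2.6 and the torchao APIs that appear in this diff; the model and shapes are illustrative only.

import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

# Toy model; int4 weight-only quantization applies to the Linear weights.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval().to(torch.bfloat16)

# use_hqq=True enables calibration-free HQQ (Half-Quadratic Quantization) for
# the int4 weight-only config; Int4CPULayout selects the CPU int4 packing.
quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout(), use_hqq=True))

with torch.no_grad():
    out = model(torch.randn(1, 1024, dtype=torch.bfloat16))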

File tree (3 files changed: +28 −9 lines)

test/integration/test_integration.py
test/quantization/test_quant_api.py
torchao/dtypes/affine_quantized_tensor.py

test/integration/test_integration.py

Lines changed: 18 additions & 6 deletions

@@ -145,13 +145,15 @@ def _int8da_int8w_api(
         change_linear_weights_to_int8_dqtensors(mod)
 
 
-def _int4wo_api(mod):
+def _int4wo_api(mod, use_hqq=False):
     if (
         is_device(next(mod.parameters()).device.type, "cpu")
         and TORCH_VERSION_AT_LEAST_2_6
     ):
         quantize_(
-            mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False
+            mod,
+            int4_weight_only(layout=Int4CPULayout(), use_hqq=use_hqq),
+            set_inductor_config=False,
         )
         unwrap_tensor_subclass(mod)
     elif TORCH_VERSION_AT_LEAST_2_4:

@@ -1049,8 +1051,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api(self, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         for test_shape in [(16, 1024, 16)] + (

@@ -1060,6 +1060,20 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
                 _int4wo_api, device, 15, test_shape=test_shape, test_dtype=dtype
             )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.")
+    def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype):
+        if dtype != torch.bfloat16:
+            self.skipTest(f"Fails for {dtype}")
+        for test_shape in [(16, 1024, 16), (1, 1024, 256)]:
+            api = partial(
+                _int4wo_api,
+                use_hqq=True,
+            )
+            self._test_lin_weight_subclass_api_impl(
+                api, device, 15, test_shape=test_shape, test_dtype=dtype
+            )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater"

@@ -1111,8 +1125,6 @@ def test_gemlite_layout(self, device, dtype):
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
     @skip_if_rocm("ROCm enablement in progress")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         layout_list = []
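
One detail worth noting in the new test: functools.partial binds the flag so the shared harness can keep calling the API with a single module argument. A standalone illustration (the _int4wo_api body is elided; the signature is as in the diff above):

from functools import partial

def _int4wo_api(mod, use_hqq=False):
    ...  # body elided; see the diff above

api = partial(_int4wo_api, use_hqq=True)
# api(mod) is now equivalent to _int4wo_api(mod, use_hqq=True), i.e. the
# one-argument callable that _test_lin_weight_subclass_api_impl expects.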

test/quantization/test_quant_api.py

Lines changed: 8 additions & 2 deletions

@@ -782,7 +782,8 @@ def reset_memory():
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
     @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
     @common_utils.parametrize("x_dim", [2, 3])
-    def test_int4wo_cpu(self, dtype, x_dim):
+    @common_utils.parametrize("use_hqq", [True, False])
+    def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
         from torchao.dtypes import Int4CPULayout
 
         device = "cpu"

@@ -792,7 +793,12 @@ def test_int4wo_cpu(self, dtype, x_dim):
             example_inputs = (example_inputs[0].unsqueeze(0),)
 
         with torch.no_grad():
-            quantize_(m, int4_weight_only(group_size=32, layout=Int4CPULayout()))
+            quantize_(
+                m,
+                int4_weight_only(
+                    group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq
+                ),
+            )
             # ensure the expected op is in the code
             _, code = torch._inductor.utils.run_and_get_code(
                 torch.compile(m, fullgraph=True, dynamic=True),
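
The unchanged tail of this test captures the Inductor output to check that the CPU int4 kernel is actually hit. A sketch of that pattern, reusing m and example_inputs from the test above; the op-name string is an assumption for illustration, not quoted from this diff:

import torch
from torch._inductor.utils import run_and_get_code

compiled = torch.compile(m, fullgraph=True, dynamic=True)
# run_and_get_code returns the call's result plus the generated code strings.
_, code = run_and_get_code(compiled, *example_inputs)
assert "_weight_int4pack_mm_for_cpu" in code[0]  # hypothetical op-name check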

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 2 additions & 1 deletion

@@ -224,6 +224,7 @@ def from_hp_to_intx(
             else input_float.dtype
         )
         device = input_float.device
+        from torchao.dtypes import Int4CPULayout
         from torchao.dtypes.uintx import TensorCoreTiledLayout
 
         data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(

@@ -235,7 +236,7 @@ def from_hp_to_intx(
             device=device,
             verbose=False,
             raw_output=not isinstance(
-                _layout, (TensorCoreTiledLayout, PlainLayout)
+                _layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout)
             ),
             # raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
             # note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
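
To restate the dispatch this last hunk changes: the raw_output flag passed to HQQ is computed from the layout, and before this commit Int4CPULayout was not in the tuple, so the CPU int4 path received HQQ's raw qparams rather than the converted zero_point (per the inline comment above, raw_output=False adds scale * midpoint to it). A sketch of the predicate in isolation, using the same imports as the diff:

from torchao.dtypes import Int4CPULayout, PlainLayout
from torchao.dtypes.uintx import TensorCoreTiledLayout

def wants_raw_hqq_output(_layout) -> bool:
    # Layouts in this tuple get the converted (non-raw) zero_point from HQQ;
    # this commit adds Int4CPULayout so the CPU int4 path gets it too.
    return not isinstance(_layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout))

assert wants_raw_hqq_output(Int4CPULayout()) is False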
