
Commit f24f37b

Xia-Weiwen authored and pytorchmergebot committed
[CPU] Add concat-linear fusion pass for da8w4 (#2476)
**Summary** This PR adds a concat-linear fusion pass for da8w4 on CPU. The pass fuses the following pattern ``` da8w4_linear_cpu(x, ..., w1, ...) -- y1 / x --da8w4_linear_cpu(x, ..., w2, ...) -- y2 \... da8w4_linear_cpu(x, ..., wN, ...) -- yN ``` to ``` x -- da8w4_linear_cpu(x, ..., w_concat, ...) -- y_concat -- split -- (y1, y2, yN) ``` The fusion pass is registered as a custom post_grad pass in Inductor. The pass takes effect only when `torch._inductor.config.cpp.enable_concat_linear` is true. Benchmarks show that total CPU time of linear is reduced by >5% with concat linear when running Llama3.1-8B with 32 cores on a 6th gen of Intel(R) Xeon(R). **Test plan** ``` pytest test/quantization/test_da8w4_cpu.py -k test_8da4w_concat_linear_cpu ``` Pull Request resolved: #2476 Approved by: https://github.com/leslie-fang-intel, https://github.com/CaoE, https://github.com/jerryzh168
1 parent a45b1f7 commit f24f37b
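For context, here is a minimal sketch of how the fusion can be exercised, modeled on the new `test_8da4w_concat_linear_cpu` test below. The module, shapes, and dtype are illustrative placeholders; the config knobs `freezing` and `cpp.enable_concat_linear` are taken from the test, and the torchao CPU kernels plus PyTorch 2.8+ are assumed to be available.

```python
# Illustrative sketch based on test_8da4w_concat_linear_cpu; not part of this PR.
import torch

from torchao import quantize_
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.quant_primitives import MappingType


class ThreeLinears(torch.nn.Module):
    # Three linears consuming the same input: the pattern the pass fuses.
    def __init__(self, k=128, n=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(k, n, bias=False)
        self.linear2 = torch.nn.Linear(k, n, bias=False)
        self.linear3 = torch.nn.Linear(k, n, bias=False)

    def forward(self, x):
        return self.linear1(x) + self.linear2(x) + self.linear3(x)


m = ThreeLinears().eval().to(torch.bfloat16)
x = torch.rand(2, 128, dtype=torch.bfloat16)

with torch.no_grad():
    # Quantize to da8w4 with the CPU layout so da8w4_linear_cpu is used.
    quantize_(
        m,
        Int8DynamicActivationInt4WeightConfig(
            group_size=32,
            layout=Int8DynamicActInt4WeightCPULayout(),
            act_mapping_type=MappingType.SYMMETRIC,
        ),
    )
    # Freezing is needed for the pattern to appear in the post-grad graph;
    # cpp.enable_concat_linear turns the concat-linear fusion on.
    with torch._inductor.config.patch(
        {"freezing": True, "cpp.enable_concat_linear": True}
    ):
        y = torch.compile(m, fullgraph=True, dynamic=True)(x)
```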

File tree

6 files changed: +413 -70 lines changed

test/quantization/test_da8w4_cpu.py

Lines changed: 182 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao import quantize_
from torchao.dtypes import (
    Int8DynamicActInt4WeightCPULayout,
    PlainLayout,
)
from torchao.quantization.quant_api import (
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_7,
    TORCH_VERSION_AT_LEAST_2_8,
)


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (
            torch.randn(
                batch_size, self.linear1.in_features, dtype=dtype, device=device
            ),
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestDa8w4Cpu(TestCase):
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize("sym_quant_a", [True, False])
    def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
        if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8:
            # not supported until PT 2.8
            return
        device = "cpu"
        m = ToyLinearModel(bias=bias).eval().to(dtype).to(device)
        m2 = copy.deepcopy(m)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)

        with torch.no_grad():
            # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout
            # is that the former packs two int4 weights into one int8, while the latter does not.
            quantize_(
                m,
                Int8DynamicActivationInt4WeightConfig(
                    group_size=32,
                    layout=Int8DynamicActInt4WeightCPULayout(),
                    act_mapping_type=MappingType.SYMMETRIC
                    if sym_quant_a
                    else MappingType.ASYMMETRIC,
                ),
            )
            y, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]
            quantize_(
                m2,
                Int8DynamicActivationInt4WeightConfig(
                    group_size=32,
                    layout=PlainLayout(),
                    act_mapping_type=MappingType.SYMMETRIC
                    if sym_quant_a
                    else MappingType.ASYMMETRIC,
                ),
            )
            torch._dynamo.reset()  # may segfault without this
            y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
            atol, rtol = 4e-7, 1e-5
            if dtype == torch.bfloat16:
                atol, rtol = 1e-2, 3e-3
            elif dtype == torch.half:
                atol, rtol = 6e-3, 2e-3
            assert torch.allclose(y, y2, atol=atol, rtol=rtol)
            # Test get_plain by dequantize()
            dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
            dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
            dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
            dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
            assert torch.allclose(dqw1, dqw1_ref)
            assert torch.allclose(dqw2, dqw2_ref)

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Test only enabled for 2.8+")
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    def test_8da4w_concat_linear_cpu(self, x_dim, bias):
        N, K = 64, 128

        class Mod(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear1 = torch.nn.Linear(K, N, bias=bias)
                self.linear2 = torch.nn.Linear(K, N, bias=bias)
                self.linear3 = torch.nn.Linear(K, N, bias=bias)

            def forward(self, x):
                a = self.linear1(x)
                b = self.linear2(x)
                c = self.linear3(x)
                return a + b + c

        dtype = torch.bfloat16
        device = "cpu"
        m = Mod(bias).eval().to(dtype).to(device)
        x_shape = [2] * x_dim
        x_shape[-1] = K
        x = torch.rand(x_shape, dtype=dtype, device=device)
        with torch.no_grad():
            quantize_(
                m,
                Int8DynamicActivationInt4WeightConfig(
                    group_size=32,
                    layout=Int8DynamicActInt4WeightCPULayout(),
                    act_mapping_type=MappingType.SYMMETRIC,
                ),
            )
            # Need to turn on freezing to get the pattern
            # set enable_concat_linear to true to enable the fusion
            with torch._inductor.config.patch(
                {"freezing": True, "cpp.enable_concat_linear": True}
            ):
                y, code = torch._inductor.utils.run_and_get_code(
                    torch.compile(m, fullgraph=True, dynamic=True),
                    x,
                )
            # ensure the expected op occurs only once in the code after fusion
            # The trailing "(" is to avoid matching the op in the comment
            assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
            with torch._inductor.config.patch(
                {"freezing": True, "cpp.enable_concat_linear": False}
            ):
                y_ref, code = torch._inductor.utils.run_and_get_code(
                    torch.compile(m, fullgraph=True, dynamic=True),
                    x,
                )
            assert torch.allclose(y, y_ref)


common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)


if __name__ == "__main__":
    run_tests()
```

test/quantization/test_quant_api.py

Lines changed: 0 additions & 68 deletions
```diff
@@ -29,7 +29,6 @@
     AffineQuantizedTensor,
     Int4CPULayout,
     Int4XPULayout,
-    Int8DynamicActInt4WeightCPULayout,
     PlainLayout,
     QDQLayout,
     TensorCoreTiledLayout,
@@ -71,7 +70,6 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    TORCH_VERSION_AT_LEAST_2_7,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -699,72 +697,6 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]
 
-    @unittest.skipIf(
-        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
-        reason="cpp kernels not built",
-    )
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+")
-    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
-    @common_utils.parametrize("x_dim", [2, 3])
-    @common_utils.parametrize("bias", [True, False])
-    @common_utils.parametrize("bs", [1, 160])
-    @common_utils.parametrize("sym_quant_a", [True, False])
-    def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
-        if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8:
-            # not supported until PT 2.8
-            return
-        device = "cpu"
-        m = ToyLinearModel(bias=bias).eval().to(dtype).to(device)
-        m2 = copy.deepcopy(m)
-        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
-        if x_dim == 3:
-            example_inputs = (example_inputs[0].unsqueeze(0),)
-
-        with torch.no_grad():
-            # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout
-            # is that the former packs two int4 weights into one int8, while the latter does not.
-            quantize_(
-                m,
-                Int8DynamicActivationInt4WeightConfig(
-                    group_size=32,
-                    layout=Int8DynamicActInt4WeightCPULayout(),
-                    act_mapping_type=MappingType.SYMMETRIC
-                    if sym_quant_a
-                    else MappingType.ASYMMETRIC,
-                ),
-            )
-            y, code = torch._inductor.utils.run_and_get_code(
-                torch.compile(m, fullgraph=True, dynamic=True),
-                *example_inputs,
-            )
-            # ensure the expected op is in the code
-            assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]
-            quantize_(
-                m2,
-                int8_dynamic_activation_int4_weight(
-                    group_size=32,
-                    layout=PlainLayout(),
-                    act_mapping_type=MappingType.SYMMETRIC
-                    if sym_quant_a
-                    else MappingType.ASYMMETRIC,
-                ),
-            )
-            torch._dynamo.reset()  # may segfault without this
-            y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
-            atol, rtol = 4e-7, 1e-5
-            if dtype == torch.bfloat16:
-                atol, rtol = 1e-2, 3e-3
-            elif dtype == torch.half:
-                atol, rtol = 6e-3, 2e-3
-            assert torch.allclose(y, y2, atol=atol, rtol=rtol)
-            # Test get_plain by dequantize()
-            dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
-            dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
-            dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
-            dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
-            assert torch.allclose(dqw1, dqw1_ref)
-            assert torch.allclose(dqw2, dqw2_ref)
-
     # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
```
torchao/csrc/cpu/da8w4_linear.cpp

Lines changed: 3 additions & 2 deletions
```diff
@@ -65,6 +65,7 @@ da8w4_linear_prepack_impl(
   at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
   at::Tensor blocked_qzeros = new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
   // Compensation = Σ(k)(W[k][n] - ZP[n]) for each block.
+  // Reorder compensation to [N/block_n, K/block_k, block_n]
   auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) - new_qzeros.view({Nc, block_n, G, -1});
   weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, block_k});
   at::Tensor compensation = weight_sub_qzero.sum(-1);
@@ -622,9 +623,9 @@ void _da8w4_linear_impl(
     } else if (M < 64) {
       return 32;
     } else if (M < 96) {
-      return 48;
-    } else {
       return 64;
+    } else {
+      return 128;
     }
   }();
   int64_t Mc = (M + block_m - 1) / block_m;
```
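For readers less familiar with the C++, the compensation term in the first hunk is the per-block sum over K of (weight minus zero point). A rough PyTorch rendering of the same computation, following the views in the code above; sizes are illustrative and the final reorder follows the comment added in this PR:

```python
import torch

# Illustrative sizes; the kernel chooses block_n/block_k internally.
N, K, block_n, block_k, group_size = 64, 128, 32, 64, 32
Nc, Kc, G = N // block_n, K // block_k, K // group_size

weight = torch.randint(0, 16, (N, K), dtype=torch.int8)  # unpacked int4 values
new_qzeros = torch.randint(0, 16, (N, G), dtype=torch.int8)  # per-group zero points

# Compensation = sum over k of (W[k][n] - ZP[n]) within each [block_k] slice.
weight_sub_qzero = weight.view(Nc, block_n, G, -1).to(torch.int) - new_qzeros.view(
    Nc, block_n, G, -1
)
weight_sub_qzero = weight_sub_qzero.view(Nc, block_n, Kc, block_k)
compensation = weight_sub_qzero.sum(-1)  # [Nc, block_n, Kc]
# Reorder to [N / block_n, K / block_k, block_n] per the comment added above.
compensation = compensation.permute(0, 2, 1).contiguous()
```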

torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -124,6 +124,10 @@ def from_plain(
         if zero_point.dim() == 1:
             zero_point.unsqueeze_(-1)
 
+        # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n].
+        # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available.
+        # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n].
+        # Compensation shape = [N / block_n, K / block_k, block_n].
         weight_int4, scales, qzeros, compensation = (
             torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point)
         )
@@ -310,3 +314,9 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
     y = y.reshape(*orig_act_size[:-1], orig_out_features)
 
     return y.to(orig_dtype)
+
+
+# Register the concat linear fusion pass
+from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+
+register_da8w4_concat_linear_cpu_pass()
```
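The blocked layouts described in the new comments can be visualized with plain reshapes and permutes. A sketch under illustrative sizes; it ignores the VNNI repacking of the inner [block_k, block_n] tiles (applied only when AMX is available) and the packing of two int4 values per int8:

```python
import torch

# Illustrative sizes only; the actual blocking is chosen by the CPU kernel.
N, K, block_n, block_k, num_groups = 64, 128, 32, 64, 4

weight = torch.randint(-8, 8, (N, K), dtype=torch.int8)  # int4 values in int8 storage
scales = torch.rand(N, num_groups)

# Weight: [N, K] -> [N / block_n, K / block_k, block_k, block_n]
blocked_weight = (
    weight.view(N // block_n, block_n, K // block_k, block_k)
    .permute(0, 2, 3, 1)
    .contiguous()
)

# Scales (and qzeros): [N, num_groups] -> [N / block_n, num_groups, block_n]
blocked_scales = (
    scales.view(N // block_n, block_n, num_groups).permute(0, 2, 1).contiguous()
)
```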
torchao/prototype/inductor/fx_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
+from .da8w4_concat_linear_fusion_cpu import register_da8w4_concat_linear_cpu_pass
 from .int8_sdpa_fusion import _int8_sdpa_init
 
 __all__ = [
     "_int8_sdpa_init",
+    "register_da8w4_concat_linear_cpu_pass",
 ]
```
