Commit 1bbeed1

Re-land the PR of "Add INT8 SDPA path for CPU" (#2215)
* enable int8 sdpa cpu
1 parent 96aec6a commit 1bbeed1

File tree

9 files changed: +2819 -1 lines changed

setup.py

Lines changed: 25 additions & 0 deletions
@@ -55,6 +55,10 @@ def read_version(file_path="version.txt"):
     and platform.system() == "Darwin"
 )
 
+use_cpp_kernels = os.getenv("USE_CPP_KERNELS", "0") == "1"
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+
 version_prefix = read_version()
 # Version is version.dev year month date if using nightlies and version if not
 version = (
@@ -307,6 +311,21 @@ def get_extensions():
         ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"]
     )
 
+    if (
+        use_cpp_kernels
+        and platform.system() == "Linux"
+        and TORCH_VERSION_AT_LEAST_2_7
+    ):
+        if torch._C._cpu._is_avx512_supported():
+            extra_compile_args["cxx"].extend(
+                [
+                    "-DCPU_CAPABILITY_AVX512",
+                    "-march=native",
+                    "-mfma",
+                    "-fopenmp",
+                ]
+            )
+
     if debug_mode:
         extra_compile_args["cxx"].append("-g")
         if "nvcc" in extra_compile_args:
@@ -328,6 +347,12 @@ def get_extensions():
 
     # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
+    if not use_cpp_kernels or platform.system() != "Linux":
+        # Remove csrc/cpu/*.cpp
+        excluded_sources = list(
+            glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
+        )
+        sources = [s for s in sources if s not in excluded_sources]
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(
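
With these build changes, the new cpu/*.cpp sources are compiled only when the build runs with USE_CPP_KERNELS=1 on Linux, against PyTorch 2.7+, on an AVX512-capable machine; otherwise they are excluded from the extension and no CPU kernel is registered for the op. Below is a minimal sketch (not part of this commit) of checking at runtime whether the kernel made it into the build, using the same dispatcher probe the new test uses for its skip condition:

import torch

import torchao  # noqa: F401  # importing torchao registers its custom ops

# "CPU" appears in the dispatch table for torchao::qscaled_dot_product only if
# the C++ kernel was compiled in (USE_CPP_KERNELS=1 on a supported setup);
# the new test relies on this same check to decide whether to skip.
has_int8_sdpa_cpu = "CPU" in torch._C._dispatch_dump("torchao::qscaled_dot_product")
print("int8 SDPA CPU kernel available:", has_int8_sdpa_cpu)
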
Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
import itertools
import unittest

import torch
import torch.utils.checkpoint
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase, run_tests
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU

import torchao
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import (
    _int8_sdpa_init,
    custom_pass,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7


class SelfAttnLikeModule(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        has_mask,
        num_attention_heads=None,
        attention_head_size=None,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
        assert num_attention_heads is not None
        assert attention_head_size is not None
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
        self.dropout = torch.nn.Dropout(0)
        self.has_mask = has_mask

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute([0, 2, 1, 3])

    def forward(self, x, mask):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
        if self.has_mask and mask.dtype != scores.dtype:
            scores = scores + mask
        attention = self.softmax(scores)
        attention = self.dropout(attention)
        context_layer = torch.matmul(attention, v)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(
            context_layer.size()[:-2] + (self.all_head_size,)
        )
        return self.dense(context_layer)


class TestSDPAPatternRewriterTemplate(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return [clone(x) for x in inputs]

    def _check_common(
        self,
        dot_prod_attention,
        args1=None,
        contains=True,
        atol=1e-5,
        has_fuse_pattern=True,
        has_dropout=False,
        check_train=True,
        override_check_equal=False,
        dtype=torch.float,
        rtol=1.3e-6,
    ):
        if args1 is None:
            tensor_shape = (4, 2, 16, 32)
            args1 = [
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
            ]
        else:
            args1 = list(args1)
        args2 = self._clone_inputs(args1)

        for training in [False, True] if check_train else [False]:
            for x in itertools.chain(args1[:], args2[:]):
                if isinstance(x, torch.Tensor) and x.is_floating_point():
                    x.requires_grad = training

            dropout_arg = [training] if has_dropout else []
            torch.manual_seed(1234)
            result1 = dot_prod_attention(*(args1 + dropout_arg))

            counters.clear()
            torch.manual_seed(1234)
            compiled_model = torch.compile(dot_prod_attention, fullgraph=True)
            result2, source_code = run_and_get_code(
                compiled_model,
                *(args2 + dropout_arg),
            )
            source_code = "\n".join(source_code)
            if has_fuse_pattern:
                self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
            if contains:
                # many of the patterns get re-expanded in dispatcher
                self.assertIn(
                    "torchao.qscaled_dot_product",
                    source_code,
                )

            # some tests configured with very low dropout where we still want to check equality
            if not has_dropout or override_check_equal:
                self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)

            if training:
                result1.sum().backward()
                result2.sum().backward()
                for arg1, arg2 in zip(args1, args2):
                    if (
                        isinstance(arg1, torch.Tensor)
                        and arg1.is_floating_point()
                        and (not has_dropout or override_check_equal)
                    ):
                        self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

    @skipIfRocm
    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
    )
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
        reason="cpp kernels not built",
    )
    @config.patch({"freezing": True})
    def _test_sdpa_int8_rewriter(self):
        from torch.export import export_for_training

        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
        from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
            X86InductorQuantizer,
        )

        # pattern is different for bs=1
        torch.manual_seed(1234)
        for dtype, has_mask, bs in itertools.product(
            [torch.float32, torch.bfloat16], [True, False], [56, 1]
        ):
            seqlen, numhead, headsize = 197, 16, 64
            mod = SelfAttnLikeModule(
                input_dim=headsize * numhead,
                has_mask=has_mask,
                num_attention_heads=numhead,
                attention_head_size=headsize,
            ).eval()
            inputs = (
                torch.randn(
                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
                ),
                torch.randn((bs, 1, 1, seqlen), device=self.device)
                if has_mask
                else None,
            )
            enable_autocast = dtype == torch.bfloat16
            with (
                torch.no_grad(),
                torch.amp.autocast(
                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
                ),
                config.patch(post_grad_custom_pre_pass=custom_pass),
            ):
                _int8_sdpa_init()
                quantizer = X86InductorQuantizer()
                quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
                quantizer.set_function_type_qconfig(
                    torch.matmul, quantizer.get_global_quantization_config()
                )
                export_model = export_for_training(
                    mod,
                    inputs,
                    strict=True,
                ).module()
                prepare_model = prepare_pt2e(export_model, quantizer)
                prepare_model(*inputs)
                convert_model = convert_pt2e(prepare_model)
                torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
                self._check_common(
                    convert_model, args1=inputs, check_train=False, atol=1.0
                )


if HAS_CPU:

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
        test_sdpa_int8_rewriter_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
        )


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()
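
For reference, here is a rough usage sketch distilled from the test above (not an official recipe from this commit): quantize an attention module with PT2E and the X86 Inductor quantizer, register the int8 SDPA fusion pass, and compile with freezing so Inductor can rewrite the quantized attention pattern into torchao.qscaled_dot_product. `model` and `example_inputs` are placeholders for a user-provided eval-mode attention module and a tuple of example inputs; the flow assumes PyTorch 2.7+ and a torchao build with the CPU kernels enabled.

import torch
from torch._inductor import config
from torch.export import export_for_training

import torchao
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import (
    _int8_sdpa_init,
    custom_pass,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)


def run_int8_sdpa(model, example_inputs):
    # PT2E quantization with the X86 Inductor quantizer; quantizing torch.matmul
    # is what makes the int8 attention pattern appear in the exported graph.
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    quantizer.set_function_type_qconfig(
        torch.matmul, quantizer.get_global_quantization_config()
    )

    exported = export_for_training(model, example_inputs, strict=True).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass
    converted = convert_pt2e(prepared)
    torchao.quantization.pt2e.move_exported_model_to_eval(converted)

    with (
        torch.no_grad(),
        # freezing plus the custom post-grad pre-pass let Inductor fuse the
        # quantized attention pattern into torchao.qscaled_dot_product
        config.patch({"freezing": True}),
        config.patch(post_grad_custom_pre_pass=custom_pass),
    ):
        _int8_sdpa_init()  # register the int8 SDPA fusion patterns
        compiled = torch.compile(converted, fullgraph=True)
        return compiled(*example_inputs)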
