Skip to content

Commit b0cfeec

Browse files
authored
Support INT8 SDPA template for CPU (#2148)
* support int8 sdpa template
1 parent 1017c7e commit b0cfeec

File tree

6 files changed: 1998 additions and 8 deletions.

test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,14 @@ def _check_common(
         if has_fuse_pattern:
             self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
         if contains:
-            # many of the patterns get re-expanded in dispatcher
-            self.assertIn(
-                "torchao.qscaled_dot_product",
-                source_code,
+            self.assertTrue(
+                any(
+                    op_name in source_code
+                    for op_name in [
+                        "qscaled_dot_product",
+                        "cpp_fused_quantize_per_tensor",
+                    ]
+                )
             )
 
         # some tests configured with very low dropout where we still want to check equality
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
+from .cpp_int8_sdpa_template import CppInt8SdpaTemplate
+
+__all__ = [
+    "CppInt8SdpaTemplate",
+]

0 commit comments

Comments
 (0)