Commit fc5d2c8

Support int_scaled_mm on CPU (#121)
1 parent 8713b7d · commit fc5d2c8

File tree: 2 files changed (+11, -3 lines)

test/kernel/test_autotuner.py

Lines changed: 5 additions & 3 deletions

@@ -52,13 +52,15 @@ def test_int_mm(self, device, dtype):
     @parameterized.expand(
         [
             ("cuda", torch.bfloat16),
-            # TODO: ("cpu", torch.bfloat16),
+            ("cpu", torch.bfloat16),
             ("cuda", torch.float16),
-            # TODO: ("cpu", torch.float16),
+            ("cpu", torch.float16),
         ]
     )
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int_scaled_mm(self, device, dtype):
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest(f"{device} not available")
+
         from torchao.kernel import intmm

         dtype = torch.bfloat16
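
The test change swaps the blanket @unittest.skipIf decorator, which skipped the whole test whenever CUDA was absent and so would have hidden the newly enabled CPU cases, for a per-case skip inside the test body. A minimal standalone sketch of that pattern (SkipPatternTest and its placeholder body are made up for illustration; the real test exercises intmm.int_scaled_matmul):

import unittest

import torch
from parameterized import parameterized


class SkipPatternTest(unittest.TestCase):
    # Parameterize over devices and skip per case, so a missing CUDA
    # runtime no longer prevents the CPU cases from running.
    @parameterized.expand([("cuda",), ("cpu",)])
    def test_device(self, device):
        if device == "cuda" and not torch.cuda.is_available():
            self.skipTest(f"{device} not available")
        x = torch.ones(2, 3, device=device)
        self.assertEqual(x.sum().item(), 6.0)


if __name__ == "__main__":
    unittest.main()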

torchao/kernel/intmm_triton.py

Lines changed: 6 additions & 0 deletions

@@ -356,3 +356,9 @@ def int_scaled_matmul_cuda(a, b, scales1):
         int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs
     )
     return int_scaled_matmul_kernel(a, b, scales1, c, best_config)
+
+
+@torch.library.impl(lib, "int_scaled_matmul", "CPU")
+def int_scaled_matmul_cpu(a, b, scales1):
+    c = torch._int_mm(a, b)
+    return c.to(scales1.dtype) * scales1
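
The new CPU implementation is registered against the same torch.library op as the Triton path, so callers go through the intmm module unchanged: torch._int_mm does the int8 matmul with int32 accumulation, and the result is cast to the scale dtype and multiplied by the scales. A hedged usage sketch follows; the shapes and the per-row (M, 1) scale layout are illustrative assumptions, and it requires a PyTorch build whose torch._int_mm has a CPU kernel:

import torch
from torchao.kernel import intmm

# Illustrative shapes (assumed): a is (M, K) int8, b is (K, N) int8,
# scales1 is assumed per-row, broadcasting over the output columns.
a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 127, (64, 16), dtype=torch.int8)
scales1 = torch.rand(32, 1, dtype=torch.bfloat16)

# On CPU tensors this dispatches to int_scaled_matmul_cpu above:
# torch._int_mm(a, b) accumulates in int32, then the result is cast
# to scales1.dtype and scaled.
c = intmm.int_scaled_matmul(a, b, scales1)
print(c.shape, c.dtype)  # torch.Size([32, 16]) torch.bfloat16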
