
Commit 56fb1b7

Additional regression tests for cpu-only (#84)
1 parent 62c7871 commit 56fb1b7

6 files changed (+137, -77 lines)

6 files changed

+137
-77
lines changed

.github/workflows/regression_test.yml

Lines changed: 52 additions & 0 deletions
@@ -53,6 +53,58 @@ jobs:
           pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
 
 
+      - name: Install package
+        run: |
+          pip install .
+
+      - name: Run tests
+        run: |
+          pytest test
+
+  test-cpu:
+    runs-on: 32-core-ubuntu
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r dev-requirements.txt
+          pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+
+      - name: Install package
+        run: |
+          pip install .
+
+      - name: Run tests
+        run: |
+          pytest test
+
+  test-nightly-cpu:
+    runs-on: 32-core-ubuntu
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r dev-requirements.txt
+          pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+
+
       - name: Install package
         run: |
           pip install .

test/dtypes/test_nf4.py

Lines changed: 2 additions & 0 deletions
@@ -184,6 +184,7 @@ def test_to_bfloat16(self):
         assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
         assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_smoketest_linear(self):
         a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
         a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
@@ -192,6 +193,7 @@ def test_smoketest_linear(self):
         out2 = torch.nn.functional.linear(inp, a_nf4)
 
     @unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_smoketest_linear_compile(self):
         a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
         a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)

test/kernel/test_autotuner.py

Lines changed: 8 additions & 10 deletions
@@ -27,41 +27,39 @@ def tearDown(self):
 
     @parameterized.expand(
         [
-            ("cuda", torch.bfloat16),
             ("cuda", torch.bfloat16),
             # TODO: ("cpu", torch.bfloat16),
             ("cuda", torch.float16),
-            ("cuda", torch.float16),
             # TODO: ("cpu", torch.float16),
         ]
     )
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int_mm(self, device, dtype):
-        from torchao.kernel import intmm_triton
+        from torchao.kernel import intmm
 
         dtype = torch.bfloat16
         m, k, n = (128, 64, 16)
         x = torch.randn(m, k, dtype=dtype, device=device)
         w = torch.randn(n, k, dtype=dtype, device=device).t()
         x_int = x.to(dtype=torch.int8)
         w_int = w.to(dtype=torch.int8)
-        out32_1 = intmm_triton.safe_int_mm(x_int, w_int)
+        out32_1 = intmm.safe_int_mm(x_int, w_int)
         assert out32_1.dtype == torch.int32
-        out32_2 = intmm_triton.int_matmul(x_int, w_int)
+        out32_2 = intmm.int_matmul(x_int, w_int)
         assert out32_2.dtype == out32_1.dtype
         torch.testing.assert_allclose(out32_1, out32_2)
 
     @parameterized.expand(
         [
-            ("cuda", torch.bfloat16),
             ("cuda", torch.bfloat16),
             # TODO: ("cpu", torch.bfloat16),
             ("cuda", torch.float16),
-            ("cuda", torch.float16),
             # TODO: ("cpu", torch.float16),
         ]
     )
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int_scaled_mm(self, device, dtype):
-        from torchao.kernel import intmm_triton
+        from torchao.kernel import intmm
 
         dtype = torch.bfloat16
         m, k, n = (128, 64, 16)
@@ -70,9 +68,9 @@ def test_int_scaled_mm(self, device, dtype):
         w = torch.randn(n, k, dtype=dtype, device=device).t()
         x_int = x.to(dtype=torch.int8)
         w_int = w.to(dtype=torch.int8)
-        out32_1 = intmm_triton.safe_int_mm(x_int, w_int) * scales
+        out32_1 = intmm.safe_int_mm(x_int, w_int) * scales
         assert out32_1.dtype == torch.bfloat16
-        out32_2 = intmm_triton.int_scaled_matmul(x_int, w_int, scales)
+        out32_2 = intmm.int_scaled_matmul(x_int, w_int, scales)
         assert out32_2.dtype == out32_1.dtype
         torch.testing.assert_allclose(out32_1, out32_2)
 
torchao/kernel/intmm.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import itertools
+import os
+import torch
+
+from torch._dynamo import is_compiling as dynamo_is_compiling
+from torch._higher_order_ops.out_dtype import out_dtype
+
+try:
+    from torchao.kernel import intmm_triton
+except ImportError:
+    intmm_triton = None
+
+AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0)))
+
+def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
+    # torch.compile path
+    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
+        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+
+    # error checking for cublas path
+    assert (
+        mat2.device == input.device
+    ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
+    device_cpu = "cpu" in [mat2.device.type, input.device.type]
+    # with input.shape = [i,j] and mat2.shape = [j,k]
+    i_is_strictly_greater_than_16 = input.shape[0] > 16
+    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
+    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
+    bad_dimensions_for_cublas = not (
+        i_is_strictly_greater_than_16
+        and j_is_nonzero_multiple_of_8
+        and k_is_nonzero_multiple_of_8
+    )
+
+    if device_cpu or bad_dimensions_for_cublas:
+        # fallback path
+        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
+            input.device.type
+        )
+
+    # cublas paths
+    if not mat2.is_contiguous():  # silently gives incorrect result without this
+        mat2 = mat2.contiguous()
+    if (not input.is_contiguous()) and (
+        input.shape[0] % 8 != 0
+    ):  # gives cryptic error without this
+        input = (
+            input.contiguous()
+        )  # (it seems the transpose makes cublas check the above j constraint on i)
+    return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+
+
+def int_matmul(a, b):
+    if intmm_triton is not None and AUTOTUNER_ENABLE:
+        return torch.ops.torchao.int_matmul(a, b)
+    return safe_int_mm(a, b)
+
+
+def int_scaled_matmul(a, b, scales1):
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    assert b.transpose(0, 1).is_contiguous(), "Matrix B must be transpose contiguous"
+    M, K = a.shape
+    K, N = b.shape
+    assert M == scales1.size(0)
+    assert 1 == scales1.size(1)
+    assert scales1.is_contiguous()
+    assert scales1.dtype == torch.bfloat16
+    scales1 = scales1.expand((M, N))
+    assert scales1.dim() == 2
+    if intmm_triton is not None and AUTOTUNER_ENABLE:
+        return torch.ops.torchao.int_scaled_matmul(a, b, scales1)
+
+    c = safe_int_mm(a, b)
+    return c * scales1
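The new module dispatches to the Triton ops only when they imported successfully and TORCHAO_AUTOTUNER_ENABLE is set; otherwise everything funnels through safe_int_mm, which has an explicit CPU fallback. A hedged usage sketch of int_scaled_matmul on a CPU-only box, assuming the autotuner flag is unset (shapes and the per-row bfloat16 scale column are illustrative, chosen to satisfy the asserts above):

import torch
from torchao.kernel import intmm

a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)      # contiguous int8 activations
b = torch.randint(-128, 127, (16, 64), dtype=torch.int8).t()  # (64, 16), transpose-contiguous weights
scales = torch.rand(32, 1).to(torch.bfloat16)                 # one bfloat16 scale per output row

out = intmm.int_scaled_matmul(a, b, scales)  # int32 matmul via the CPU fallback, then scaled
print(out.shape, out.dtype)                  # torch.Size([32, 16]) torch.bfloat16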

torchao/kernel/intmm_triton.py

Lines changed: 0 additions & 66 deletions
@@ -5,13 +5,9 @@
 
 import triton
 import triton.language as tl
-from torch._dynamo import is_compiling as dynamo_is_compiling
-from torch._higher_order_ops.out_dtype import out_dtype
 
 from torchao.kernel.autotuner import get_best_config_fn
 
-AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0)))
-
 int8_powers_of_two = [32, 64, 128, 256]
 int8_mm_kernel_configs = sum(
     [
@@ -338,50 +334,6 @@ def int_matmul_cuda(a, b):
     return int_matmul_kernel(a, b, c, best_config)
 
 
-def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
-    # torch.compile path
-    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-    # error checking for cublas path
-    assert (
-        mat2.device == input.device
-    ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
-    device_cpu = "cpu" in [mat2.device.type, input.device.type]
-    # with input.shape = [i,j] and mat2.shape = [j,k]
-    i_is_strictly_greater_than_16 = input.shape[0] > 16
-    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
-    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
-    bad_dimensions_for_cublas = not (
-        i_is_strictly_greater_than_16
-        and j_is_nonzero_multiple_of_8
-        and k_is_nonzero_multiple_of_8
-    )
-
-    if device_cpu or bad_dimensions_for_cublas:
-        # fallback path
-        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-            input.device.type
-        )
-
-    # cublas paths
-    if not mat2.is_contiguous():  # silently gives incorrect result without this
-        mat2 = mat2.contiguous()
-    if (not input.is_contiguous()) and (
-        input.shape[0] % 8 != 0
-    ):  # gives cryptic error without this
-        input = (
-            input.contiguous()
-        )  # (it seems the transpose makes cublas check the above j constraint on i)
-    return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-
-def int_matmul(a, b):
-    if AUTOTUNER_ENABLE:
-        return torch.ops.torchao.int_matmul(a, b)
-    return safe_int_mm(a, b)
-
-
 @torch.library.impl(lib, "int_scaled_matmul", "Meta")
 def int_scaled_matmul_meta(a, b, scales1):
     M, K = a.shape
@@ -404,21 +356,3 @@ def int_scaled_matmul_cuda(a, b, scales1):
         int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs
     )
     return int_scaled_matmul_kernel(a, b, scales1, c, best_config)
-
-
-def int_scaled_matmul(a, b, scales1):
-    assert a.is_contiguous(), "Matrix A must be contiguous"
-    assert b.transpose(0, 1).is_contiguous(), "Matrix B must be transpose contiguous"
-    M, K = a.shape
-    K, N = b.shape
-    assert M == scales1.size(0)
-    assert 1 == scales1.size(1)
-    assert scales1.is_contiguous()
-    assert scales1.dtype == torch.bfloat16
-    scales1 = scales1.expand((M, N))
-    assert scales1.dim() == 2
-    if AUTOTUNER_ENABLE:
-        return torch.ops.torchao.int_scaled_matmul(a, b, scales1)
-
-    c = safe_int_mm(a, b)
-    return c * scales1

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.library import impl
 
-from torchao.kernel.intmm_triton import int_scaled_matmul
+from torchao.kernel.intmm import int_scaled_matmul
 from .utils import TORCH_VERSION_AFTER_2_4
 
