Commit 5420089 (1 parent: ec08d71)

Expand CI coverage to 2.2.2, 2.3rc and nightly (#96)

4 files changed: +82 additions, -99 deletions

.github/workflows/regression_test.yml
Lines changed: 20 additions & 5 deletions

@@ -9,7 +9,7 @@ on:
       - main
 
 jobs:
-  test:
+  test-cuda-2-2-2:
     runs-on: 4-core-ubuntu-gpu-t4
     steps:
       - uses: actions/checkout@v2
@@ -22,10 +22,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install torch
+          pip install torch==2.2.2
           pip install -r requirements.txt
           pip install -r dev-requirements.txt
-
 
       - name: Install package
         run: |
@@ -35,7 +34,24 @@
         run: |
           pytest test --verbose -s -x
 
-  test-nightly:
+  test-cuda-2-3-rc:
+    runs-on: 4-core-ubuntu-gpu-t4
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+          pip install -r requirements.txt
+          pip install -r dev-requirements.txt
+
+  test-cuda-nightly:
     runs-on: 4-core-ubuntu-gpu-t4
     steps:
       - uses: actions/checkout@v2
@@ -103,7 +119,6 @@
           pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
           pip install -r requirements.txt
           pip install -r dev-requirements.txt
-
 
       - name: Install package
         run: |
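With three torch builds in the matrix, version-dependent tests need to skip themselves on older builds. A minimal sketch of that pattern using the flags this commit defines in torchao/quantization/utils.py (the test class and method names here are hypothetical, not part of the commit):

    import unittest

    from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3


    class VersionGatedTest(unittest.TestCase):
        # Runs on the 2.3 RC and nightly jobs; the pinned 2.2.2 job skips it.
        @unittest.skipUnless(TORCH_VERSION_AFTER_2_3, "requires torch >= 2.3.0.dev")
        def test_torch_2_3_feature(self):
            self.assertTrue(TORCH_VERSION_AFTER_2_3)


    if __name__ == "__main__":
        unittest.main()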

torchao/kernel/intmm.py
Lines changed: 51 additions & 38 deletions

@@ -2,52 +2,65 @@
 import os
 import torch
 
-from torch._dynamo import is_compiling as dynamo_is_compiling
-from torch._higher_order_ops.out_dtype import out_dtype
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2
 
 try:
-    from torchao.kernel import intmm_triton
+    # Only works for torch 2.2 or newer.
+    if TORCH_VERSION_AFTER_2_2:
+        from torchao.kernel import intmm_triton
+    else:
+        intmm_triton = None
 except ImportError:
+    # On CPU-only builds it might not be available.
     intmm_triton = None
 
 AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0)))
 
-def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
-    # torch.compile path
-    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-    # error checking for cublas path
-    assert (
-        mat2.device == input.device
-    ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
-    device_cpu = "cpu" in [mat2.device.type, input.device.type]
-    # with input.shape = [i,j] and mat2.shape = [j,k]
-    i_is_strictly_greater_than_16 = input.shape[0] > 16
-    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
-    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
-    bad_dimensions_for_cublas = not (
-        i_is_strictly_greater_than_16
-        and j_is_nonzero_multiple_of_8
-        and k_is_nonzero_multiple_of_8
-    )
-
-    if device_cpu or bad_dimensions_for_cublas:
-        # fallback path
-        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-            input.device.type
+# torch._int_mm doesn't exist before 2.2.
+if TORCH_VERSION_AFTER_2_2:
+    from torch._dynamo import is_compiling as dynamo_is_compiling
+    from torch._higher_order_ops.out_dtype import out_dtype
+    def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
+        # torch.compile path
+        if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+
+        # error checking for cublas path
+        assert (
+            mat2.device == input.device
+        ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
+        device_cpu = "cpu" in [mat2.device.type, input.device.type]
+        # with input.shape = [i,j] and mat2.shape = [j,k]
+        i_is_strictly_greater_than_16 = input.shape[0] > 16
+        j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
+        k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
+        bad_dimensions_for_cublas = not (
+            i_is_strictly_greater_than_16
+            and j_is_nonzero_multiple_of_8
+            and k_is_nonzero_multiple_of_8
         )
-
-    # cublas paths
-    if not mat2.is_contiguous():  # silently gives incorrect result without this
-        mat2 = mat2.contiguous()
-    if (not input.is_contiguous()) and (
-        input.shape[0] % 8 != 0
-    ):  # gives cryptic error without this
-        input = (
-            input.contiguous()
-        )  # (it seems the transpose makes cublas check the above j constraint on i)
-    return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+
+        if device_cpu or bad_dimensions_for_cublas:
+            # fallback path
+            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
+                input.device.type
+            )
+
+        # cublas paths
+        if not mat2.is_contiguous():  # silently gives incorrect result without this
+            mat2 = mat2.contiguous()
+        if (not input.is_contiguous()) and (
+            input.shape[0] % 8 != 0
+        ):  # gives cryptic error without this
+            input = (
+                input.contiguous()
+            )  # (it seems the transpose makes cublas check the above j constraint on i)
+        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+else:
+    def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
+        # We can improve on this by writing Triton code that works for older
+        # versions of Triton that ship with torch 2.1 or 2.0.
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 
 
 def int_matmul(a, b):
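For reference, a short usage sketch (mine, not from the diff): the call site is identical on both sides of the version gate; only the backing implementation changes.

    import torch

    from torchao.kernel.intmm import safe_int_mm

    # int8 operands of shape [i, j] and [j, k] yield an int32 result of shape [i, k].
    a = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
    b = torch.randint(-128, 128, (64, 32), dtype=torch.int8)

    out = safe_int_mm(a, b)  # _int_mm/cuBLAS path on torch >= 2.2, fp32 emulation before
    assert out.dtype == torch.int32
    assert out.shape == (32, 32)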

torchao/quantization/quant_primitives.py
Lines changed: 2 additions & 56 deletions

@@ -11,6 +11,8 @@
 from torch.library import impl
 
 from torchao.kernel.intmm import int_scaled_matmul
+from .utils import TORCH_VERSION_AFTER_2_4
+from torchao.kernel.intmm import safe_int_mm
 from .utils import TORCH_VERSION_AFTER_2_3
 
 
@@ -40,64 +42,8 @@
     # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
-
-def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
-    r"""
-    This function wraps torch._int_mm and avoids several undesirable behaviors of the function for certain inputs while still
-    returning correct results and being torch.compiled in a performant way.
-
-    Assumes both tensors have dimension of 2.
-
-    Note: no error checking for torch.compiled path, if input.shape = [i, j] and j<=16 then the triton kernel
-    will error.
-
-    Args:
-        input (Tensor, int8): the first tensor to be multiplied
-        mat2 (Tensor, int8): the second tensor to be multiplied
-
-    Return:
-        out (Tensor, int32): the result of the matmul with device matching that of the inputs
-    """
-    # torch.compile path
-    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-    # error checking for cublas path
-    assert (
-        mat2.device == input.device
-    ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
-    device_cpu = "cpu" in [mat2.device.type, input.device.type]
-    # with input.shape = [i,j] and mat2.shape = [j,k]
-    i_is_strictly_greater_than_16 = input.shape[0] > 16
-    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
-    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
-    bad_dimensions_for_cublas = not (
-        i_is_strictly_greater_than_16
-        and j_is_nonzero_multiple_of_8
-        and k_is_nonzero_multiple_of_8
-    )
-
-    if device_cpu or bad_dimensions_for_cublas:
-        # fallback path
-        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-            input.device.type
-        )
-
-    # cublas paths
-    if not mat2.is_contiguous():  # silently gives incorrect result without this
-        mat2 = mat2.contiguous()
-    if (not input.is_contiguous()) and (
-        input.shape[0] % 8 != 0
-    ):  # gives cryptic error without this
-        input = (
-            input.contiguous()
-        )  # (it seems the transpose makes cublas check the above j constraint on i)
-    return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-
 # copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736
 
-
 def dynamically_quantize_per_tensor(
     x,
     quant_min,
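The practical upshot, as I read the diff: safe_int_mm now lives in torchao.kernel.intmm, and the re-import above keeps the old import path resolving, so both of these work:

    # New canonical location:
    from torchao.kernel.intmm import safe_int_mm

    # Old location, still valid via the re-import in quant_primitives:
    from torchao.quantization.quant_primitives import safe_int_mm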

torchao/quantization/utils.py
Lines changed: 9 additions & 0 deletions

@@ -95,8 +95,17 @@ def get_model_size_in_bytes(model):
         s += b.nelement() * b.element_size()
     return s
 
+if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
+    TORCH_VERSION_AFTER_2_4 = True
+else:
+    TORCH_VERSION_AFTER_2_4 = False
 
 if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
     TORCH_VERSION_AFTER_2_3 = True
 else:
     TORCH_VERSION_AFTER_2_3 = False
+
+if version.parse(torch.__version__) >= version.parse("2.2.0.dev"):
+    TORCH_VERSION_AFTER_2_2 = True
+else:
+    TORCH_VERSION_AFTER_2_2 = False
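Why the flags compare against "X.Y.0.dev": under PEP 440, a .dev release sorts before release candidates and final releases of the same version, so the pinned, RC, and nightly CI builds each enable exactly the flags they should. A quick check with illustrative version strings (not from the commit):

    from packaging import version

    # The pinned job (torch==2.2.2) enables the 2.2 flag but not the 2.3 flag.
    assert version.parse("2.2.2") >= version.parse("2.2.0.dev")
    assert not version.parse("2.2.2") >= version.parse("2.3.0.dev")

    # A nightly such as 2.4.0.dev20240301 enables all three flags.
    assert version.parse("2.4.0.dev20240301") >= version.parse("2.4.0.dev")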
