 import os

 import torch

-from torch._dynamo import is_compiling as dynamo_is_compiling
-from torch._higher_order_ops.out_dtype import out_dtype
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2

 try:
-    from torchao.kernel import intmm_triton
+    # Only works for torch 2.2 or newer.
+    if TORCH_VERSION_AFTER_2_2:
+        from torchao.kernel import intmm_triton
+    else:
+        intmm_triton = None
 except ImportError:
+    # On CPU-only builds intmm_triton might not be available.
     intmm_triton = None

 AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0)))

-def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
-    # torch.compile path
-    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-    # error checking for cublas path
-    assert (
-        mat2.device == input.device
-    ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
-    device_cpu = "cpu" in [mat2.device.type, input.device.type]
-    # with input.shape = [i,j] and mat2.shape = [j,k]
-    i_is_strictly_greater_than_16 = input.shape[0] > 16
-    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
-    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
-    bad_dimensions_for_cublas = not (
-        i_is_strictly_greater_than_16
-        and j_is_nonzero_multiple_of_8
-        and k_is_nonzero_multiple_of_8
-    )
-
-    if device_cpu or bad_dimensions_for_cublas:
-        # fallback path
-        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-            input.device.type
+# torch._int_mm doesn't exist before 2.2.
+if TORCH_VERSION_AFTER_2_2:
+    from torch._dynamo import is_compiling as dynamo_is_compiling
+    from torch._higher_order_ops.out_dtype import out_dtype
+
+    def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
+        # torch.compile path
+        if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+
+        # error checking for cublas path
+        assert (
+            mat2.device == input.device
+        ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
+        device_cpu = "cpu" in [mat2.device.type, input.device.type]
+        # with input.shape = [i,j] and mat2.shape = [j,k]
+        i_is_strictly_greater_than_16 = input.shape[0] > 16
+        j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
+        k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
+        bad_dimensions_for_cublas = not (
+            i_is_strictly_greater_than_16
+            and j_is_nonzero_multiple_of_8
+            and k_is_nonzero_multiple_of_8
         )
-
-    # cublas paths
-    if not mat2.is_contiguous():  # silently gives incorrect result without this
-        mat2 = mat2.contiguous()
-    if (not input.is_contiguous()) and (
-        input.shape[0] % 8 != 0
-    ):  # gives cryptic error without this
-        input = (
-            input.contiguous()
-        )  # (it seems the transpose makes cublas check the above j constraint on i)
-    return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+
+        if device_cpu or bad_dimensions_for_cublas:
+            # fallback path
+            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
+                input.device.type
+            )
+
+        # cublas paths
+        if not mat2.is_contiguous():  # silently gives incorrect result without this
+            mat2 = mat2.contiguous()
+        if (not input.is_contiguous()) and (
+            input.shape[0] % 8 != 0
+        ):  # gives cryptic error without this
+            input = (
+                input.contiguous()
+            )  # (it seems the transpose makes cublas check the above j constraint on i)
+        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+else:
+    def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
+        # We can improve on this by writing Triton code that works for older
+        # versions of Triton that ship with torch 2.1 or 2.0.
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)


 def int_matmul(a, b):
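
For reference, a minimal usage sketch of the two-path safe_int_mm after this change. It assumes the diff applies to torchao/kernel/intmm.py, so the import path below is an assumption rather than something shown in the diff; on every path the result is an int32 tensor.

import torch
from torchao.kernel.intmm import safe_int_mm  # assumed module path

a = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 128, (64, 32), dtype=torch.int8)

c = safe_int_mm(a, b)
assert c.dtype == torch.int32

# Exactness check: with j = 64 every dot product stays far below 2**24,
# so even the pre-2.2 float32 fallback loses no precision on int8 inputs.
ref = torch.matmul(a.to(torch.int32), b.to(torch.int32))
assert torch.equal(c, ref)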
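The cuBLAS eligibility check in the torch 2.2+ branch reduces to a small shape predicate for an [i, j] @ [j, k] multiply. The helper below is hypothetical (it is not part of the diff) and only restates those conditions:

def cublas_int_mm_shapes_ok(i: int, j: int, k: int) -> bool:
    # Mirrors the checks in safe_int_mm: i strictly greater than 16,
    # j and k nonzero multiples of 8.
    return i > 16 and j > 0 and j % 8 == 0 and k > 0 and k % 8 == 0

assert cublas_int_mm_shapes_ok(32, 64, 32)
assert not cublas_int_mm_shapes_ok(16, 64, 32)  # i == 16 is not strictly greater
assert not cublas_int_mm_shapes_ok(32, 60, 32)  # j is not a multiple of 8

Shapes that fail this predicate, like CPU tensors, take the int32-on-CPU fallback path instead of erroring.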