|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import copy |
| 8 | +import unittest |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch.testing._internal import common_utils |
| 12 | +from torch.testing._internal.common_utils import ( |
| 13 | + TestCase, |
| 14 | + run_tests, |
| 15 | +) |
| 16 | + |
| 17 | +from torchao import quantize_ |
| 18 | +from torchao.dtypes import ( |
| 19 | + Int8DynamicActInt4WeightCPULayout, |
| 20 | + PlainLayout, |
| 21 | +) |
| 22 | +from torchao.quantization.quant_api import ( |
| 23 | + Int8DynamicActivationInt4WeightConfig, |
| 24 | +) |
| 25 | +from torchao.quantization.quant_primitives import MappingType |
| 26 | +from torchao.utils import ( |
| 27 | + TORCH_VERSION_AT_LEAST_2_7, |
| 28 | + TORCH_VERSION_AT_LEAST_2_8, |
| 29 | +) |
| 30 | + |
| 31 | + |
| 32 | +class ToyLinearModel(torch.nn.Module): |
| 33 | + def __init__(self, m=64, n=32, k=64, bias=False): |
| 34 | + super().__init__() |
| 35 | + self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float) |
| 36 | + self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float) |
| 37 | + |
| 38 | + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): |
| 39 | + return ( |
| 40 | + torch.randn( |
| 41 | + batch_size, self.linear1.in_features, dtype=dtype, device=device |
| 42 | + ), |
| 43 | + ) |
| 44 | + |
| 45 | + def forward(self, x): |
| 46 | + x = self.linear1(x) |
| 47 | + x = self.linear2(x) |
| 48 | + return x |
| 49 | + |
| 50 | + |
| 51 | +class TestDa8w4Cpu(TestCase): |
| 52 | + @unittest.skipIf( |
| 53 | + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), |
| 54 | + reason="cpp kernels not built", |
| 55 | + ) |
| 56 | + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+") |
| 57 | + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) |
| 58 | + @common_utils.parametrize("x_dim", [2, 3]) |
| 59 | + @common_utils.parametrize("bias", [True, False]) |
| 60 | + @common_utils.parametrize("bs", [1, 160]) |
| 61 | + @common_utils.parametrize("sym_quant_a", [True, False]) |
| 62 | + def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a): |
| 63 | + if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8: |
| 64 | + # not supported until PT 2.8 |
| 65 | + return |
| 66 | + device = "cpu" |
| 67 | + m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) |
| 68 | + m2 = copy.deepcopy(m) |
| 69 | + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) |
| 70 | + if x_dim == 3: |
| 71 | + example_inputs = (example_inputs[0].unsqueeze(0),) |
| 72 | + |
| 73 | + with torch.no_grad(): |
| 74 | + # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout |
| 75 | + # is that the former packs two int4 weights into one int8, while the latter does not. |
| 76 | + quantize_( |
| 77 | + m, |
| 78 | + Int8DynamicActivationInt4WeightConfig( |
| 79 | + group_size=32, |
| 80 | + layout=Int8DynamicActInt4WeightCPULayout(), |
| 81 | + act_mapping_type=MappingType.SYMMETRIC |
| 82 | + if sym_quant_a |
| 83 | + else MappingType.ASYMMETRIC, |
| 84 | + ), |
| 85 | + ) |
| 86 | + y, code = torch._inductor.utils.run_and_get_code( |
| 87 | + torch.compile(m, fullgraph=True, dynamic=True), |
| 88 | + *example_inputs, |
| 89 | + ) |
| 90 | + # ensure the expected op is in the code |
| 91 | + assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] |
| 92 | + quantize_( |
| 93 | + m2, |
| 94 | + Int8DynamicActivationInt4WeightConfig( |
| 95 | + group_size=32, |
| 96 | + layout=PlainLayout(), |
| 97 | + act_mapping_type=MappingType.SYMMETRIC |
| 98 | + if sym_quant_a |
| 99 | + else MappingType.ASYMMETRIC, |
| 100 | + ), |
| 101 | + ) |
| 102 | + torch._dynamo.reset() # may segfault without this |
| 103 | + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) |
| 104 | + atol, rtol = 4e-7, 1e-5 |
| 105 | + if dtype == torch.bfloat16: |
| 106 | + atol, rtol = 1e-2, 3e-3 |
| 107 | + elif dtype == torch.half: |
| 108 | + atol, rtol = 6e-3, 2e-3 |
| 109 | + assert torch.allclose(y, y2, atol=atol, rtol=rtol) |
| 110 | + # Test get_plain by dequantize() |
| 111 | + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() |
| 112 | + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() |
| 113 | + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() |
| 114 | + dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() |
| 115 | + assert torch.allclose(dqw1, dqw1_ref) |
| 116 | + assert torch.allclose(dqw2, dqw2_ref) |
| 117 | + |
| 118 | + @unittest.skipIf( |
| 119 | + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), |
| 120 | + reason="cpp kernels not built", |
| 121 | + ) |
| 122 | + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Test only enabled for 2.8+") |
| 123 | + @common_utils.parametrize("x_dim", [2, 3]) |
| 124 | + @common_utils.parametrize("bias", [True, False]) |
| 125 | + def test_8da4w_concat_linear_cpu(self, x_dim, bias): |
| 126 | + N, K = 64, 128 |
| 127 | + |
| 128 | + class Mod(torch.nn.Module): |
| 129 | + def __init__(self, bias): |
| 130 | + super().__init__() |
| 131 | + self.linear1 = torch.nn.Linear(K, N, bias=bias) |
| 132 | + self.linear2 = torch.nn.Linear(K, N, bias=bias) |
| 133 | + self.linear3 = torch.nn.Linear(K, N, bias=bias) |
| 134 | + |
| 135 | + def forward(self, x): |
| 136 | + a = self.linear1(x) |
| 137 | + b = self.linear2(x) |
| 138 | + c = self.linear3(x) |
| 139 | + return a + b + c |
| 140 | + |
| 141 | + dtype = torch.bfloat16 |
| 142 | + device = "cpu" |
| 143 | + m = Mod(bias).eval().to(dtype).to(device) |
| 144 | + x_shape = [2] * x_dim |
| 145 | + x_shape[-1] = K |
| 146 | + x = torch.rand(x_shape, dtype=dtype, device=device) |
| 147 | + with torch.no_grad(): |
| 148 | + quantize_( |
| 149 | + m, |
| 150 | + Int8DynamicActivationInt4WeightConfig( |
| 151 | + group_size=32, |
| 152 | + layout=Int8DynamicActInt4WeightCPULayout(), |
| 153 | + act_mapping_type=MappingType.SYMMETRIC, |
| 154 | + ), |
| 155 | + ) |
| 156 | + # Need to turn on freezing to get the pattern |
| 157 | + # set enable_concat_linear to true to enable the fusion |
| 158 | + with torch._inductor.config.patch( |
| 159 | + {"freezing": True, "cpp.enable_concat_linear": True} |
| 160 | + ): |
| 161 | + y, code = torch._inductor.utils.run_and_get_code( |
| 162 | + torch.compile(m, fullgraph=True, dynamic=True), |
| 163 | + x, |
| 164 | + ) |
| 165 | + # ensure the expected op occurs only once in the code after fusion |
| 166 | + # The trailing "(" is to avoid matching the op in the comment |
| 167 | + assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1 |
| 168 | + with torch._inductor.config.patch( |
| 169 | + {"freezing": True, "cpp.enable_concat_linear": False} |
| 170 | + ): |
| 171 | + y_ref, code = torch._inductor.utils.run_and_get_code( |
| 172 | + torch.compile(m, fullgraph=True, dynamic=True), |
| 173 | + x, |
| 174 | + ) |
| 175 | + assert torch.allclose(y, y_ref) |
| 176 | + |
| 177 | + |
| 178 | +common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) |
| 179 | + |
| 180 | + |
| 181 | +if __name__ == "__main__": |
| 182 | + run_tests() |
0 commit comments