|
| 1 | +import torch |
| 2 | +import unittest |
| 3 | +from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable |
| 4 | +from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter, FakeExtraDimTensor |
| 5 | +from torchao.quantization.quant_api import ( |
| 6 | + Int8WeightOnlyConfig, |
| 7 | + Int8DynamicActivationInt8WeightConfig, |
| 8 | + Int4WeightOnlyConfig, |
| 9 | + Float8WeightOnlyConfig, |
| 10 | + Int8DynamicActivationIntxWeightConfig, |
| 11 | + PackedLinearInt8DynamicActivationIntxWeightLayout, |
| 12 | + quantize_, |
| 13 | + AffineQuantizedTensor, |
| 14 | + LinearActivationQuantizedTensor, |
| 15 | + Float8DynamicActivationFloat8WeightConfig, |
| 16 | +) |
| 17 | +from torchao.quantization.utils import compute_error |
| 18 | +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl |
| 19 | +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl |
| 20 | +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl |
| 21 | +from parameterized import param, parameterized |
| 22 | + |
| 23 | +class TestMoEQuantCompile(unittest.TestCase): |
| 24 | + DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k |
| 25 | + |
| 26 | + @torch.no_grad() |
| 27 | + def _test_impl_moe_quant(self, |
| 28 | + config, |
| 29 | + num_tokens=1, |
| 30 | + model_params=None, |
| 31 | + base_class=AffineQuantizedTensor, |
| 32 | + tensor_impl_class=None, |
| 33 | + dtype=torch.bfloat16, |
| 34 | + device="cuda", |
| 35 | + fullgraph=False |
| 36 | + ): |
| 37 | + """ |
| 38 | + Tests moe quant for techniques using fake extra dim |
| 39 | + """ |
| 40 | + if model_params is None: |
| 41 | + model_params=self.DEFAULT_PARAMS |
| 42 | + |
| 43 | + input_shape = (num_tokens, model_params[0]) |
| 44 | + model = MOEFeedForwardAOQuantizable(*model_params).to(dtype).to(device) |
| 45 | + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) |
| 46 | + |
| 47 | + out = model(input) |
| 48 | + |
| 49 | + import copy |
| 50 | + new_mod = copy.deepcopy(model) |
| 51 | + |
| 52 | + quantize_(model, config, cond_ffn_filter) |
| 53 | + |
| 54 | + if isinstance(config, MoEQuantConfig): |
| 55 | + self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) |
| 56 | + if base_class is not None: |
| 57 | + self.assertIsInstance(model.experts.w1.head_tensor, base_class) |
| 58 | + if tensor_impl_class is not None: |
| 59 | + self.assertIsInstance(model.experts.w1.head_tensor.tensor_impl, tensor_impl_class) |
| 60 | + else: |
| 61 | + if base_class is not None: |
| 62 | + self.assertIsInstance(model.experts.w1, base_class) |
| 63 | + if tensor_impl_class is not None: |
| 64 | + self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class) |
| 65 | + |
| 66 | + out_q = model(input) |
| 67 | + |
| 68 | + torch._dynamo.config.capture_scalar_outputs = True |
| 69 | + torch._dynamo.config.capture_dynamic_output_shape_ops = True |
| 70 | + model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph) |
| 71 | + |
| 72 | + model_c(input) |
| 73 | + model_c(input) |
| 74 | + out_qc = model_c(input).clone() |
| 75 | + |
| 76 | + for i in range(10): |
| 77 | + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) |
| 78 | + model_c(input) |
| 79 | + |
| 80 | + self.assertGreaterEqual(compute_error(out_q, out), 10) |
| 81 | + self.assertGreaterEqual(compute_error(out_qc, out), 10) |
| 82 | + print(compute_error(out_q, out), compute_error(out_qc, out)) |
| 83 | + |
| 84 | + |
| 85 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 86 | + @parameterized.expand([ |
| 87 | + ("single_token", 1, False), |
| 88 | + ("multiple_tokens", 8, False), |
| 89 | + ]) |
| 90 | + def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): |
| 91 | + config = MoEQuantConfig(Int4WeightOnlyConfig()) |
| 92 | + tensor_impl_class = TensorCoreTiledAQTTensorImpl |
| 93 | + |
| 94 | + self._test_impl_moe_quant( |
| 95 | + config=config, |
| 96 | + num_tokens=num_tokens, |
| 97 | + tensor_impl_class=tensor_impl_class, |
| 98 | + fullgraph=fullgraph |
| 99 | + ) |
| 100 | + |
| 101 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 102 | + @parameterized.expand([ |
| 103 | + ("single_token", 1, True), |
| 104 | + ("multiple_tokens", 8, False), |
| 105 | + ]) |
| 106 | + def test_int4wo_base(self, name, num_tokens, fullgraph): |
| 107 | + config = Int4WeightOnlyConfig() |
| 108 | + tensor_impl_class = TensorCoreTiledAQTTensorImpl |
| 109 | + |
| 110 | + self._test_impl_moe_quant( |
| 111 | + config=config, |
| 112 | + num_tokens=num_tokens, |
| 113 | + tensor_impl_class=tensor_impl_class, |
| 114 | + fullgraph=fullgraph |
| 115 | + ) |
| 116 | + |
| 117 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 118 | + @parameterized.expand([ |
| 119 | + ("single_token", 1, False), |
| 120 | + ("multiple_tokens", 8, False), |
| 121 | + ]) |
| 122 | + def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): |
| 123 | + config = MoEQuantConfig(Int8WeightOnlyConfig()) |
| 124 | + tensor_impl_class = PlainAQTTensorImpl |
| 125 | + |
| 126 | + self._test_impl_moe_quant( |
| 127 | + config=config, |
| 128 | + num_tokens=num_tokens, |
| 129 | + tensor_impl_class=tensor_impl_class, |
| 130 | + fullgraph=fullgraph |
| 131 | + ) |
| 132 | + |
| 133 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 134 | + @parameterized.expand([ |
| 135 | + ("single_token", 1, True), |
| 136 | + ("multiple_tokens", 8, False), |
| 137 | + ]) |
| 138 | + def test_int8wo_base(self, name, num_tokens, fullgraph): |
| 139 | + config = Int8WeightOnlyConfig() |
| 140 | + tensor_impl_class = PlainAQTTensorImpl |
| 141 | + |
| 142 | + self._test_impl_moe_quant( |
| 143 | + config=config, |
| 144 | + num_tokens=num_tokens, |
| 145 | + tensor_impl_class=tensor_impl_class, |
| 146 | + fullgraph=fullgraph |
| 147 | + ) |
| 148 | + |
| 149 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 150 | + @parameterized.expand([ |
| 151 | + ("multiple_tokens", 32, False), |
| 152 | + ]) |
| 153 | + def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): |
| 154 | + config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) |
| 155 | + base_class = LinearActivationQuantizedTensor |
| 156 | + |
| 157 | + self._test_impl_moe_quant( |
| 158 | + model_params=(512, 256, 2, 2), |
| 159 | + config=config, |
| 160 | + num_tokens=num_tokens, |
| 161 | + base_class=base_class, |
| 162 | + fullgraph=fullgraph |
| 163 | + ) |
| 164 | + |
| 165 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 166 | + @parameterized.expand([ |
| 167 | + ("multiple_tokens", 32, False), |
| 168 | + ]) |
| 169 | + def test_int8dq_base(self, name, num_tokens, fullgraph): |
| 170 | + config = Int8DynamicActivationInt8WeightConfig() |
| 171 | + base_class = LinearActivationQuantizedTensor |
| 172 | + |
| 173 | + self._test_impl_moe_quant( |
| 174 | + model_params=(512, 256, 2, 2), |
| 175 | + config=config, |
| 176 | + num_tokens=num_tokens, |
| 177 | + base_class=base_class, |
| 178 | + fullgraph=fullgraph |
| 179 | + ) |
| 180 | + |
| 181 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 182 | + @parameterized.expand([ |
| 183 | + ("single_token", 1, False), |
| 184 | + ("multiple_tokens", 8, False), |
| 185 | + |
| 186 | + ]) |
| 187 | + def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): |
| 188 | + config = MoEQuantConfig(Float8WeightOnlyConfig()) |
| 189 | + tensor_impl_class = Float8AQTTensorImpl |
| 190 | + |
| 191 | + self._test_impl_moe_quant( |
| 192 | + config=config, |
| 193 | + num_tokens=num_tokens, |
| 194 | + tensor_impl_class=tensor_impl_class, |
| 195 | + fullgraph=fullgraph |
| 196 | + ) |
| 197 | + |
| 198 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 199 | + @parameterized.expand([ |
| 200 | + ("single_token", 1, True), |
| 201 | + ("multiple_tokens", 8, False), |
| 202 | + ]) |
| 203 | + def test_fp8wo_base(self, name, num_tokens, fullgraph): |
| 204 | + config = Float8WeightOnlyConfig() |
| 205 | + tensor_impl_class = Float8AQTTensorImpl |
| 206 | + |
| 207 | + self._test_impl_moe_quant( |
| 208 | + config=config, |
| 209 | + num_tokens=num_tokens, |
| 210 | + tensor_impl_class=tensor_impl_class, |
| 211 | + fullgraph=fullgraph |
| 212 | + ) |
| 213 | + |
| 214 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 215 | + @parameterized.expand([ |
| 216 | + ("single_token", 1, False), |
| 217 | + ("multiple_tokens", 8, False), |
| 218 | + |
| 219 | + ]) |
| 220 | + def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): |
| 221 | + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) |
| 222 | + base_class = LinearActivationQuantizedTensor |
| 223 | + |
| 224 | + self._test_impl_moe_quant( |
| 225 | + config=config, |
| 226 | + num_tokens=num_tokens, |
| 227 | + base_class=base_class, |
| 228 | + fullgraph=fullgraph |
| 229 | + ) |
| 230 | + |
| 231 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 232 | + @parameterized.expand([ |
| 233 | + ("single_token", 1, True), |
| 234 | + ("multiple_tokens", 8, False), |
| 235 | + ]) |
| 236 | + def test_fp8dq_base(self, name, num_tokens, fullgraph): |
| 237 | + config = Float8DynamicActivationFloat8WeightConfig() |
| 238 | + base_class = LinearActivationQuantizedTensor |
| 239 | + |
| 240 | + self._test_impl_moe_quant( |
| 241 | + config=config, |
| 242 | + num_tokens=num_tokens, |
| 243 | + base_class=base_class, |
| 244 | + fullgraph=fullgraph |
| 245 | + ) |
| 246 | + |
| 247 | + |
| 248 | + |
| 249 | + |
| 250 | + |
| 251 | + |
| 252 | + |
| 253 | + |
| 254 | +if __name__ == "__main__": |
| 255 | + unittest.main() |
0 commit comments