
Commit 63bb0ae

Enabling MoE quantization using linear decomposition
Summary:

This PR is a first step toward optimizing MoE inference using torchao. The goal for this step is to let existing quantization kernels and workflows work for MoE quantization by decomposing the grouped GEMM into a sequence of unbalanced linear ops that can use the existing quantized kernels (sketched below). To enable this, we had to add support for quantizing 3D tensors, as well as slicing and indexing them. Two methods of achieving this were implemented:

1. For int8wo, int8dq, int4wo, fp8wo, and fp8dq, the underlying quantized tensor subclass was adapted to support 3D tensors, indexing, and slicing, along with an updated transformation function that can handle ConditionalFeedForwardAOQuantizable modules when the filter function in quantize_ is used to target that module.

2. For some complex kernels that use packed data and couldn't easily be made to work in 3D, we added FakeExtraDimTensor, which can give any quantized tensor subclass the slice and index operations needed for MoE quantization. This option is enabled by using MoEQuantConfig.

This can be applied to the Hugging Face Llama 4 model, for instance, as shown in the llama4_quant.py example. Since the HF MoE module is implemented in a way that's not conducive to quantization, it first requires a module swap to MOEFeedForwardAOQuantizable.

TODO: final benchmark numbers from run.sh; consolidate the 3x implementations of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable; verify hqq.

Test Plan:

python test/quantization/test_moe_quant.py
python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx"
sh torchao/_models/mixtral-moe/run.sh

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 8369268 commit 63bb0ae

21 files changed: +2,429 −130 lines
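
For context, here is a minimal sketch (not the PR's code) of the decomposition the summary describes: the expert weights are treated as a 3D stack, and the grouped GEMM becomes a sequence of ordinary 2D linear ops, one per expert, each over however many tokens were routed to that expert. The function and variable names are illustrative only.

    import torch
    import torch.nn.functional as F

    def moe_mm_as_linears(x, w_stack, expert_indices):
        # x: [num_tokens, hidden_dim] routed activations
        # w_stack: [num_experts, out_dim, hidden_dim] stacked expert weights
        # expert_indices: [num_tokens] expert id chosen for each token
        out = x.new_zeros(x.shape[0], w_stack.shape[1])
        for e in range(w_stack.shape[0]):
            mask = expert_indices == e
            if mask.any():
                # Each expert runs a plain 2D linear over an unbalanced
                # (data-dependent) number of tokens, so existing quantized
                # linear kernels apply once w_stack[e] supports indexing.
                out[mask] = F.linear(x[mask], w_stack[e])
        return out

With top_k > 1 the same loop runs once per routing slot and the results are combined with the router weights; the key point is that no 3D matmul kernel is needed, only indexing into the (quantized) weight stack.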

test/quantization/test_moe_quant.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
import unittest

import torch
from parameterized import parameterized

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import (
    FakeExtraDimTensor,
    MoEQuantConfig,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import (
    AffineQuantizedTensor,
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    LinearActivationQuantizedTensor,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import is_sm_at_least_90, TORCH_VERSION_AT_LEAST_2_5


class TestMoEQuantCompile(unittest.TestCase):
    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k

    @torch.no_grad()
    def _test_impl_moe_quant(
        self,
        config,
        num_tokens=1,
        model_params=None,
        base_class=AffineQuantizedTensor,
        tensor_impl_class=None,
        dtype=torch.bfloat16,
        device="cuda",
        fullgraph=False,
    ):
        """
        Tests moe quant for techniques using fake extra dim
        """
        if model_params is None:
            model_params = self.DEFAULT_PARAMS

        input_shape = (num_tokens, model_params[0])
        model = (
            MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
            .to(dtype)
            .to(device)
        )
        input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

        out = model(input)

        quantize_(model, config, cond_ffn_filter)

        if isinstance(config, MoEQuantConfig):
            self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
            if base_class is not None:
                self.assertIsInstance(model.experts.w1.head_tensor, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(
                    model.experts.w1.head_tensor.tensor_impl, tensor_impl_class
                )
        else:
            if base_class is not None:
                self.assertIsInstance(model.experts.w1, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class)

        out_q = model(input)

        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph)

        model_c(input)
        model_c(input)
        out_qc = model_c(input).clone()

        for i in range(10):
            input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
            model_c(input)

        self.assertGreaterEqual(compute_error(out_q, out), 10)
        self.assertGreaterEqual(compute_error(out_qc, out), 10)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int4WeightOnlyConfig())
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_base(self, name, num_tokens, fullgraph):
        config = Int4WeightOnlyConfig()
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8WeightOnlyConfig())
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_base(self, name, num_tokens, fullgraph):
        config = Int8WeightOnlyConfig()
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_base(self, name, num_tokens, fullgraph):
        config = Int8DynamicActivationInt8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8WeightOnlyConfig())
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_base(self, name, num_tokens, fullgraph):
        config = Float8WeightOnlyConfig()
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_base(self, name, num_tokens, fullgraph):
        config = Float8DynamicActivationFloat8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )


if __name__ == "__main__":
    unittest.main()
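
The fake-dim path exercised above depends on FakeExtraDimTensor exposing a head_tensor and behaving like a 3D tensor under indexing. As a rough illustration only (this is not torchao's implementation; the class and attribute layout below are assumptions for exposition), the idea can be pictured as a wrapper that stores one quantized 2D tensor per expert and fakes the leading dimension:

    class FakeExtraDimSketch:
        """Illustrative stand-in: stores N quantized 2D tensors and fakes a
        leading expert dimension via indexing/slicing, so kernels with packed
        2D layouts never need a real 3D format."""

        def __init__(self, tensors):
            self.tensors = list(tensors)  # one quantized 2D tensor per expert
            self.head_tensor = self.tensors[0]  # assumes at least one expert

        def __getitem__(self, index):
            if isinstance(index, int):
                # Integer indexing on the fake dim yields one expert's 2D weight.
                return self.tensors[index]
            # Slicing the fake dim yields a smaller wrapper of the same kind.
            return FakeExtraDimSketch(self.tensors[index])

The configs themselves are used exactly as the tests show: passing a base workflow config to quantize_ quantizes the 3D expert weights directly, while wrapping it in MoEQuantConfig routes it through the fake-dim path, e.g. quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), cond_ffn_filter).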
