
Commit edf7488

Enabling MOE Quantization using linear decomposition
Summary:
This PR is a first step toward optimizing MoE inference with torchao. The goal of this step is to let existing quantization kernels and workflows work for MoE quantization by decomposing the grouped GEMM into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors, as well as for slicing and indexing them.

Two methods of achieving this were implemented. For int8wo, int8dq, int4wo, fp8wo, and fp8dq, the underlying quantized tensor subclass was adapted to support 3D tensors, indexing, and slicing, together with an updated transformation function that can handle ConditionalFeedForwardAOQuantizable modules when the filter function in quantize_ is used to target that module.

For some complex kernels that use packed data and couldn't easily be made to work in 3D, we also added FakeExtraDimTensor, which can give any quantized tensor subclass the slice and index operations necessary for MoE quantization. This option is enabled via MoEQuantConfig. It can be applied to Hugging Face Llama 4, for instance, as shown in the llama4_quant.py example. Since the HF MoE module is implemented in a way that is not conducive to quantization, it first requires a module swap to MOEFeedForwardAOQuantizable.

TODO: final benchmark numbers from run.sh; consolidate the 3x implementations of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable; verify hqq.

Test Plan:
python test/quantization/test_moe_quant.py
python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx"
sh torchao/_models/mixtral-moe/run.sh

Reviewers:
Subscribers:
Tasks:
Tags:
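As a rough illustration of the decomposition described above (a sketch only: moe_ffn_decomposed, the exact shapes, and the routing layout are hypothetical, and per-token routing weights are omitted), the grouped GEMM over all experts can be replaced by one ordinary linear op per expert; because every slice w[e] is a plain 2D weight, each call can dispatch to the existing quantized linear kernels:

import torch
import torch.nn.functional as F

def moe_ffn_decomposed(x, w1, w2, w3, expert_indices):
    # Hypothetical shapes: x (num_tokens, hidden_dim); w1, w3
    # (num_experts, expert_dim, hidden_dim); w2 (num_experts, hidden_dim,
    # expert_dim); expert_indices (num_tokens, top_k). Routing weights omitted.
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        mask = (expert_indices == e).any(dim=-1)  # tokens routed to expert e
        if not mask.any():
            continue
        xe = x[mask]  # unbalanced: each expert sees a different token count
        # Each w[e] slice is a 2D weight, so existing quantized kernels apply.
        h = F.silu(F.linear(xe, w1[e])) * F.linear(xe, w3[e])
        out[mask] += F.linear(h, w2[e])
    return out

At the API level, either path is enabled through quantize_ with the cond_ffn_filter used in the tests below; here model is assumed to already contain the swapped-in MOEFeedForwardAOQuantizable modules:

from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
from torchao.quantization.quant_api import Int4WeightOnlyConfig, Int8WeightOnlyConfig, quantize_

# Direct path: the adapted tensor subclasses handle 3D weights natively.
quantize_(model, Int8WeightOnlyConfig(), cond_ffn_filter)

# Fallback path: wrapping the base config in MoEQuantConfig makes
# FakeExtraDimTensor supply the slice/index support instead.
quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter)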
1 parent d9fe8b6 · commit edf7488

21 files changed: +2431 −130 lines

test/quantization/test_moe_quant.py

Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@
import unittest

import torch
from parameterized import parameterized

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import (
    FakeExtraDimTensor,
    MoEQuantConfig,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import (
    AffineQuantizedTensor,
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    LinearActivationQuantizedTensor,
    quantize_,
)
from torchao.quantization.utils import compute_error

# Missing from the original diff but required by the skipIf decorators below.
from torchao.utils import is_sm_at_least_90


class TestMoEQuantCompile(unittest.TestCase):
    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k

    @torch.no_grad()
    def _test_impl_moe_quant(
        self,
        config,
        num_tokens=1,
        model_params=None,
        base_class=AffineQuantizedTensor,
        tensor_impl_class=None,
        dtype=torch.bfloat16,
        device="cuda",
        fullgraph=False,
    ):
        """
        Tests MoE quantization for a given config, covering both the
        FakeExtraDimTensor path (MoEQuantConfig) and the base 3D-tensor path.
        """
        if model_params is None:
            model_params = self.DEFAULT_PARAMS

        input_shape = (num_tokens, model_params[0])
        model = (
            MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
            .to(dtype)
            .to(device)
        )
        input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

        out = model(input)

        quantize_(model, config, cond_ffn_filter)

        # MoEQuantConfig wraps the expert weights in FakeExtraDimTensor;
        # otherwise the quantized subclass holds the 3D weight directly.
        if isinstance(config, MoEQuantConfig):
            self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
            if base_class is not None:
                self.assertIsInstance(model.experts.w1.head_tensor, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(
                    model.experts.w1.head_tensor.tensor_impl, tensor_impl_class
                )
        else:
            if base_class is not None:
                self.assertIsInstance(model.experts.w1, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class)

        out_q = model(input)

        # Token-to-expert routing produces data-dependent shapes, so dynamo
        # needs these flags to compile the model.
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph)

        # Warm up CUDA graphs before recording the compiled output.
        model_c(input)
        model_c(input)
        out_qc = model_c(input).clone()

        # Run fresh inputs through the compiled model to surface any
        # recompilation or graph-capture issues.
        for _ in range(10):
            input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
            model_c(input)

        # Eager and compiled quantized outputs should both reach at least
        # 10 dB SQNR relative to the bf16 baseline.
        self.assertGreaterEqual(compute_error(out_q, out), 10)
        self.assertGreaterEqual(compute_error(out_qc, out), 10)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int4WeightOnlyConfig())
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_base(self, name, num_tokens, fullgraph):
        config = Int4WeightOnlyConfig()
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8WeightOnlyConfig())
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_base(self, name, num_tokens, fullgraph):
        config = Int8WeightOnlyConfig()
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_base(self, name, num_tokens, fullgraph):
        config = Int8DynamicActivationInt8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8WeightOnlyConfig())
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_base(self, name, num_tokens, fullgraph):
        config = Float8WeightOnlyConfig()
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_base(self, name, num_tokens, fullgraph):
        config = Float8DynamicActivationFloat8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )


if __name__ == "__main__":
    unittest.main()
