Commit 90233e1

Enabling MOE Quantization using linear decomposition
Summary: This PR is a first step toward optimizing MoE inference using torchao. The goal for this step is to enable the existing quantization kernels and workflows to work for MoE quantization by decomposing the grouped gemm into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors as well as for slicing and indexing them.

Two methods of achieving this were implemented. For int8wo, int8dq, int4wo, fp8wo, and fp8dq, the underlying quantized tensor subclass was adapted to support 3D tensors, indexing, and slicing, together with an updated transformation function that can handle the ConditionalFeedForwardAOQuantizable modules when the filter function in quantize_ is used to target that module. For some complex kernels that use packed data and couldn't easily be made to work in 3D, we also added FakeExtraDimTensor, which can make any quantized tensor subclass support the slice and index operations needed for MoE quantization. This option is enabled by using MoEQuantConfig.

This can be applied to the Hugging Face Llama 4 model, for instance, as shown in the llama4_quant.py example. Since the HF MoE module is implemented in a way that is not conducive to quantization, it first requires a module swap to MOEFeedForwardAOQuantizable.

TODO: final benchmark numbers from run.sh; consolidate the 3x implementations of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable; verify hqq.

Test Plan:
python test/quantization/test_moe_quant.py
python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx"
sh torchao/_models/mixtral-moe/run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

testing

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent a81322e commit 90233e1
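
To make the flow concrete, here is a minimal usage sketch condensed from the new test file below. It assumes a CUDA device and the default test model sizes (hidden_dim=512, expert_dim=256, num_experts=8, top_k=2); it illustrates the two quantization paths described in the commit message and is not the exact llama4_quant.py flow.

import torch

from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_

# Standalone MoE block with the default test sizes:
# hidden_dim, expert_dim, num_experts, top_k
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16).to("cuda")

# Path 1: pass a base config directly; the quantized tensor subclass itself
# handles the 3D expert weights plus the slicing/indexing the MoE path needs.
quantize_(model, Int8WeightOnlyConfig(), cond_ffn_filter)

# Path 2 (alternative): wrap the config in MoEQuantConfig so the expert
# weights become FakeExtraDimTensor, which layers the needed slice/index
# support on top of any quantized tensor subclass.
# quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), cond_ffn_filter)

tokens = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
out = model(tokens)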

21 files changed: +2309, -134 lines

test/quantization/test_moe_quant.py (+255 lines)

@@ -0,0 +1,255 @@
import torch
import unittest
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable
from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter, FakeExtraDimTensor
from torchao.quantization.quant_api import (
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int4WeightOnlyConfig,
    Float8WeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    PackedLinearInt8DynamicActivationIntxWeightLayout,
    quantize_,
    AffineQuantizedTensor,
    LinearActivationQuantizedTensor,
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from parameterized import param, parameterized


class TestMoEQuantCompile(unittest.TestCase):
    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k

    @torch.no_grad()
    def _test_impl_moe_quant(self,
        config,
        num_tokens=1,
        model_params=None,
        base_class=AffineQuantizedTensor,
        tensor_impl_class=None,
        dtype=torch.bfloat16,
        device="cuda",
        fullgraph=False
    ):
        """
        Tests moe quant for techniques using fake extra dim
        """
        if model_params is None:
            model_params = self.DEFAULT_PARAMS

        input_shape = (num_tokens, model_params[0])
        model = MOEFeedForwardAOQuantizable(*model_params).to(dtype).to(device)
        input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

        out = model(input)

        import copy
        new_mod = copy.deepcopy(model)

        quantize_(model, config, cond_ffn_filter)

        if isinstance(config, MoEQuantConfig):
            self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
            if base_class is not None:
                self.assertIsInstance(model.experts.w1.head_tensor, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(model.experts.w1.head_tensor.tensor_impl, tensor_impl_class)
        else:
            if base_class is not None:
                self.assertIsInstance(model.experts.w1, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class)

        out_q = model(input)

        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph)

        model_c(input)
        model_c(input)
        out_qc = model_c(input).clone()

        for i in range(10):
            input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
            model_c(input)

        self.assertGreaterEqual(compute_error(out_q, out), 10)
        self.assertGreaterEqual(compute_error(out_qc, out), 10)
        print(compute_error(out_q, out), compute_error(out_qc, out))

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, False),
        ("multiple_tokens", 8, False),
    ])
    def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int4WeightOnlyConfig())
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, True),
        ("multiple_tokens", 8, False),
    ])
    def test_int4wo_base(self, name, num_tokens, fullgraph):
        config = Int4WeightOnlyConfig()
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, False),
        ("multiple_tokens", 8, False),
    ])
    def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8WeightOnlyConfig())
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, True),
        ("multiple_tokens", 8, False),
    ])
    def test_int8wo_base(self, name, num_tokens, fullgraph):
        config = Int8WeightOnlyConfig()
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("multiple_tokens", 32, False),
    ])
    def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("multiple_tokens", 32, False),
    ])
    def test_int8dq_base(self, name, num_tokens, fullgraph):
        config = Int8DynamicActivationInt8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, False),
        ("multiple_tokens", 8, False),
    ])
    def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8WeightOnlyConfig())
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, True),
        ("multiple_tokens", 8, False),
    ])
    def test_fp8wo_base(self, name, num_tokens, fullgraph):
        config = Float8WeightOnlyConfig()
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, False),
        ("multiple_tokens", 8, False),
    ])
    def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand([
        ("single_token", 1, True),
        ("multiple_tokens", 8, False),
    ])
    def test_fp8dq_base(self, name, num_tokens, fullgraph):
        config = Float8DynamicActivationFloat8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph
        )


if __name__ == "__main__":
    unittest.main()
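
As a conceptual aside, the *_fake_dim tests above check that the expert weights become FakeExtraDimTensor after quantization with MoEQuantConfig. The sketch below is a simplified, hypothetical stand-in for that idea, not the torchao class: a wrapper that stores one 2D weight per expert and serves indexing along a fake leading dimension, so the grouped gemm can be decomposed into per-expert linear ops.

from typing import List

import torch


class FakeExtraDim:
    # Simplified illustration only; the real FakeExtraDimTensor wraps quantized
    # tensor subclasses and implements the tensor ops MoE quantization needs.
    def __init__(self, expert_weights: List[torch.Tensor]):
        self.tensors = expert_weights  # one 2D weight per expert

    @property
    def shape(self):
        # Fake leading "expert" dimension on top of the 2D weight shape.
        return (len(self.tensors), *self.tensors[0].shape)

    def __getitem__(self, idx):
        # Integer indexing hands back one expert's 2D weight; slicing returns
        # another wrapper over the selected experts.
        if isinstance(idx, int):
            return self.tensors[idx]
        return FakeExtraDim(self.tensors[idx])


# Route a token to expert 3 and apply its weight as an ordinary linear op.
w1 = FakeExtraDim([torch.randn(256, 512) for _ in range(8)])
x = torch.randn(1, 512)
y = torch.nn.functional.linear(x, w1[3])  # output shape (1, 256)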
