Commit cde56ad
Enabling MOE Quantization using linear decomposition
Summary:
This PR is a first step toward optimizing MoE inference using torchao. The goal for this step is to enable existing quantization kernels and workflows to work for MoE quantization by decomposing the grouped gemm into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors, as well as slicing and indexing them.

Two methods of achieving this were implemented. For int8wo, int8dq, int4wo, fp8wo, and fp8dq, the underlying quantized tensor subclass was adapted to support 3D tensors, indexing, and slicing, along with an updated transformation function that can handle the ConditionalFeedForwardAOQuantizable modules when the filter function in quantize_ is used to target that module. For some complex kernels which use packed data and couldn't easily be made to work in 3D, we also added FakeExtraDimTensor, which can adapt any quantized tensor subclass to support the slice and index operations needed for MoE quantization. This option is enabled by using MoEQuantConfig.

This can be applied to Hugging Face Llama4, for instance, as shown in the llama4_quant.py example. Since the HF MoE module is implemented in a way that is not conducive to quantization, it first requires a module swap to MOEFeedForwardAOQuantizable.

TODO: final benchmark numbers from run.sh; consolidate the 3x implementation of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable; verify hqq.

Test Plan:
python test/quantization/test_moe_quant.py
python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx"
sh torchao/_models/mixtral-moe/run.sh

Reviewers:
Subscribers:
Tasks:
Tags:
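To illustrate the decomposition, here is a minimal sketch (a hypothetical helper with top-1 routing, not the code in this PR): a grouped expert gemm over a stacked 3D weight is replaced by one ordinary linear op per expert, each over the unbalanced group of tokens routed to that expert, so every call can reuse an existing 2D quantized linear kernel.

import torch
import torch.nn.functional as F

def grouped_gemm_as_linears(x, w, expert_ids):
    # x:          [num_tokens, hidden_dim] input tokens
    # w:          [num_experts, out_dim, hidden_dim] stacked expert weights
    # expert_ids: [num_tokens] expert chosen for each token (top-1 routing)
    out = x.new_empty(x.shape[0], w.shape[1])
    for e in range(w.shape[0]):
        mask = expert_ids == e  # token group for expert e; group sizes are unbalanced
        if mask.any():
            # w[e] is a 2D slice; with a quantized weight subclass, this index
            # is where 3D indexing support lets the existing kernel run.
            out[mask] = F.linear(x[mask], w[e])
    return out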
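Similarly, a minimal sketch of the FakeExtraDimTensor idea (an illustrative stand-in, not the actual subclass added in this PR): hold one 2D quantized tensor per expert and emulate a single 3D tensor whose dim-0 indexing and slicing simply select the underlying 2D tensors, so packed layouts never need real 3D support.

class FakeExtraDimSketch:
    def __init__(self, tensors):
        # tensors: list of per-expert 2D (possibly quantized) tensors
        self.tensors = list(tensors)

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.tensors[idx]  # indexing dim 0 returns a real 2D tensor
        return FakeExtraDimSketch(self.tensors[idx])  # slicing keeps the wrapper

    @property
    def shape(self):
        # advertise a 3D shape: (num_experts, *2D shape)
        return (len(self.tensors), *self.tensors[0].shape)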
1 parent d9fe8b6 commit cde56ad

21 files changed: +2062 −135 lines

test/quantization/test_moe_quant.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import unittest

import torch
from parameterized import parameterized

from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import (
    MOEFeedForwardAOQuantizable,
)
from torchao.quantization.prototype.moe_quant.utils import (
    FakeExtraDimTensor,
    MoEQuantConfig,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import (
    AffineQuantizedTensor,
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    LinearActivationQuantizedTensor,
    quantize_,
)
from torchao.quantization.utils import compute_error


class TestMoEQuantCompile(unittest.TestCase):
    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k

    @torch.no_grad()
    def _test_impl_moe_quant(
        self,
        config,
        num_tokens=1,
        model_params=None,
        base_class=AffineQuantizedTensor,
        tensor_impl_class=None,
        dtype=torch.bfloat16,
        device="cuda",
        fullgraph=False,
    ):
        """
        Tests MoE quantization for both base tensor-subclass configs and
        MoEQuantConfig (FakeExtraDimTensor) configs, in eager and compiled modes.
        """
        if model_params is None:
            model_params = self.DEFAULT_PARAMS

        input_shape = (num_tokens, model_params[0])
        model = (
            MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
            .to(dtype)
            .to(device)
        )
        input = torch.randn(input_shape, dtype=dtype, device=device)

        out = model(input)

        quantize_(model, config, cond_ffn_filter)

        # MoEQuantConfig wraps the expert weights in FakeExtraDimTensor;
        # base configs quantize the 3D weight tensors directly.
        if isinstance(config, MoEQuantConfig):
            self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
            if base_class is not None:
                self.assertIsInstance(model.experts.w1.head_tensor, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(
                    model.experts.w1.head_tensor.tensor_impl, tensor_impl_class
                )
        else:
            if base_class is not None:
                self.assertIsInstance(model.experts.w1, base_class)
            if tensor_impl_class is not None:
                self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class)

        out_q = model(input)

        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph)

        # Warm up the compiled model before capturing its output.
        model_c(input)
        model_c(input)
        out_qc = model_c(input).clone()

        # Exercise the compiled model on fresh inputs to catch recompile issues.
        for _ in range(10):
            input = torch.randn(input_shape, dtype=dtype, device=device)
            model_c(input)

        self.assertGreaterEqual(compute_error(out_q, out), 10)
        self.assertGreaterEqual(compute_error(out_qc, out), 10)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int4WeightOnlyConfig())
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int4wo_base(self, name, num_tokens, fullgraph):
        config = Int4WeightOnlyConfig()
        tensor_impl_class = TensorCoreTiledAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8WeightOnlyConfig())
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_int8wo_base(self, name, num_tokens, fullgraph):
        config = Int8WeightOnlyConfig()
        tensor_impl_class = PlainAQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("multiple_tokens", 32, False),
        ]
    )
    def test_int8dq_base(self, name, num_tokens, fullgraph):
        config = Int8DynamicActivationInt8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            model_params=(512, 256, 2, 2),
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8WeightOnlyConfig())
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8wo_base(self, name, num_tokens, fullgraph):
        config = Float8WeightOnlyConfig()
        tensor_impl_class = Float8AQTTensorImpl

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            tensor_impl_class=tensor_impl_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, False),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parameterized.expand(
        [
            ("single_token", 1, True),
            ("multiple_tokens", 8, False),
        ]
    )
    def test_fp8dq_base(self, name, num_tokens, fullgraph):
        config = Float8DynamicActivationFloat8WeightConfig()
        base_class = LinearActivationQuantizedTensor

        self._test_impl_moe_quant(
            config=config,
            num_tokens=num_tokens,
            base_class=base_class,
            fullgraph=fullgraph,
        )


if __name__ == "__main__":
    unittest.main()
