Commit c64118c

update API and remove branching on quant_api.py transform functions
1 parent 6e6f6eb commit c64118c

File tree: 9 files changed, +192 −139 lines changed

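The change across these files is mechanical: callers used to pass either a bare base config (the "base" path) or a `MoEQuantConfig` wrapper (the `FakeExtraDimTensor` path), and the transform functions in quant_api.py branched on the config type. After this commit every caller wraps the base config in `MoEQuantConfig`, and the wrapper behavior is requested explicitly through the `use_fake_extra_dim_tensor` flag. A sketch of the before/after, using only names that appear in the diffs below:

```python
from torchao.quantization.prototype.moe_quant.utils import (
    MoEQuantConfig,
    UseFakeExtraDimTensor,
)
from torchao.quantization.quant_api import Int8WeightOnlyConfig

# Before: bare base configs and MoEQuantConfig coexisted, and the quant_api.py
# transform functions branched on which one they received.
# config = Int8WeightOnlyConfig()                  # old "base" path
# config = MoEQuantConfig(Int8WeightOnlyConfig())  # old FakeExtraDimTensor path

# After: always a MoEQuantConfig; the wrapper behavior is an explicit flag.
config_base = MoEQuantConfig(Int8WeightOnlyConfig())  # flag defaults to AS_FALLBACK
config_fake = MoEQuantConfig(
    Int8WeightOnlyConfig(),
    use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
```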

test/quantization/test_moe_quant.py

Lines changed: 34 additions & 13 deletions
@@ -12,6 +12,7 @@
 from torchao.quantization.prototype.moe_quant.utils import (
     FakeExtraDimTensor,
     MoEQuantConfig,
+    UseFakeExtraDimTensor,
     cond_ffn_filter,
 )
 from torchao.quantization.quant_api import (
@@ -25,7 +26,11 @@
     quantize_,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_90, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_90,
+)


 class TestMoEQuantCompile(unittest.TestCase):
@@ -61,7 +66,10 @@ def _test_impl_moe_quant(

         quantize_(model, config, cond_ffn_filter)

-        if isinstance(config, MoEQuantConfig):
+        if (
+            isinstance(config, MoEQuantConfig)
+            and config.use_fake_extra_dim_tensor == UseFakeExtraDimTensor.TRUE
+        ):
             self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
             if base_class is not None:
                 self.assertIsInstance(model.experts.w1.head_tensor, base_class)
@@ -104,7 +112,9 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")

-        config = MoEQuantConfig(Int4WeightOnlyConfig())
+        config = MoEQuantConfig(
+            Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
+        )
         tensor_impl_class = TensorCoreTiledAQTTensorImpl

         self._test_impl_moe_quant(
@@ -128,7 +138,7 @@ def test_int4wo_base(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")

-        config = Int4WeightOnlyConfig()
+        config = MoEQuantConfig(Int4WeightOnlyConfig())
         tensor_impl_class = TensorCoreTiledAQTTensorImpl

         self._test_impl_moe_quant(
@@ -150,7 +160,9 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")

-        config = MoEQuantConfig(Int8WeightOnlyConfig())
+        config = MoEQuantConfig(
+            Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
+        )
         tensor_impl_class = PlainAQTTensorImpl

         self._test_impl_moe_quant(
@@ -172,7 +184,7 @@ def test_int8wo_base(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_6:
             self.skipTest("Test only enabled for 2.6+")

-        config = Int8WeightOnlyConfig()
+        config = MoEQuantConfig(Int8WeightOnlyConfig())
         tensor_impl_class = PlainAQTTensorImpl

         self._test_impl_moe_quant(
@@ -192,7 +204,7 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_6:
             self.skipTest("Test only enabled for 2.6+")

-        config = Int8WeightOnlyConfig()
+        config = MoEQuantConfig(Int8WeightOnlyConfig())
         tensor_impl_class = PlainAQTTensorImpl

         self._test_impl_moe_quant(
@@ -214,7 +226,10 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")

-        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+        config = MoEQuantConfig(
+            Int8DynamicActivationInt8WeightConfig(),
+            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+        )
         base_class = LinearActivationQuantizedTensor

         self._test_impl_moe_quant(
@@ -236,7 +251,7 @@ def test_int8dq_base(self, name, num_tokens, fullgraph):
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")

-        config = Int8DynamicActivationInt8WeightConfig()
+        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
         base_class = LinearActivationQuantizedTensor

         self._test_impl_moe_quant(
@@ -259,7 +274,10 @@ def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
         if not is_sm_at_least_90():
             self.skipTest("Requires CUDA capability >= 9.0")

-        config = MoEQuantConfig(Float8WeightOnlyConfig())
+        config = MoEQuantConfig(
+            Float8WeightOnlyConfig(),
+            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+        )
         tensor_impl_class = Float8AQTTensorImpl

         self._test_impl_moe_quant(
@@ -281,7 +299,7 @@ def test_fp8wo_base(self, name, num_tokens, fullgraph):
         if not is_sm_at_least_90():
             self.skipTest("Requires CUDA capability >= 9.0")

-        config = Float8WeightOnlyConfig()
+        config = MoEQuantConfig(Float8WeightOnlyConfig())
         tensor_impl_class = Float8AQTTensorImpl

         self._test_impl_moe_quant(
@@ -303,7 +321,10 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
         if not is_sm_at_least_90():
             self.skipTest("Requires CUDA capability >= 9.0")

-        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
+        config = MoEQuantConfig(
+            Float8DynamicActivationFloat8WeightConfig(),
+            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+        )
         base_class = LinearActivationQuantizedTensor

         self._test_impl_moe_quant(
@@ -325,7 +346,7 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph):
         if not is_sm_at_least_90():
             self.skipTest("Requires CUDA capability >= 9.0")

-        config = Float8DynamicActivationFloat8WeightConfig()
+        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
         base_class = LinearActivationQuantizedTensor

         self._test_impl_moe_quant(
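The tests select the wrapper path explicitly via `UseFakeExtraDimTensor.TRUE`. The enum itself is not shown in this commit; a plausible sketch, consistent with the values that do appear (TRUE in these tests, AS_FALLBACK as the documented default in the README change below), might look as follows. The FALSE member is an assumption:

```python
from enum import Enum


class UseFakeExtraDimTensor(Enum):
    """Hypothetical sketch; only TRUE and AS_FALLBACK appear in this commit."""

    FALSE = 0        # assumed: always use the native 3D tensor-subclass path
    TRUE = 1         # force the FakeExtraDimTensor wrapper
    AS_FALLBACK = 2  # documented default: try the base path, fall back if unsupported
```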

torchao/_models/mixtral-moe/generate.py

Lines changed: 17 additions & 12 deletions
@@ -239,6 +239,7 @@ def main(
     from torchao.quantization.prototype.moe_quant.utils import (
         MoEQuantConfig,
         cond_ffn_filter,
+        UseFakeExtraDimTensor,
     )
     from torchao.quantization.quant_api import (
         Float8DynamicActivationFloat8WeightConfig,
@@ -256,40 +257,44 @@ def main(
     torch._dynamo.config.capture_dynamic_output_shape_ops = True
     config = None
     if "int8wo-base" in moe_quant:
-        config = Int8WeightOnlyConfig()
+        config = MoEQuantConfig(Int8WeightOnlyConfig())

     elif "int8wo" in moe_quant:
-        config = MoEQuantConfig(Int8WeightOnlyConfig())
+        config = MoEQuantConfig(Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)

     elif "int8dq-base" in moe_quant:
-        config = Int8DynamicActivationInt8WeightConfig()
+        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())

     elif "int8dq" in moe_quant:
-        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)

     elif "int4wo-base" in moe_quant:
-        config = Int4WeightOnlyConfig()
+        config = MoEQuantConfig(Int4WeightOnlyConfig())

     elif "int4wo" in moe_quant:
-        config = MoEQuantConfig(Int4WeightOnlyConfig())
+        config = MoEQuantConfig(Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)

     elif "fp8wo-base" in moe_quant:
-        config = Float8WeightOnlyConfig()
+        config = MoEQuantConfig(Float8WeightOnlyConfig())

     elif "fp8wo" in moe_quant:
-        config = MoEQuantConfig(Float8WeightOnlyConfig())
+        config = MoEQuantConfig(Float8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)

     elif "fp8dq-base" in moe_quant:
-        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

     elif "fp8dq" in moe_quant:
         config = MoEQuantConfig(
-            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
+            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
         )

     elif "intxdq" in moe_quant:
-        config = Int8DynamicActivationIntxWeightConfig(
-            layout=PackedLinearInt8DynamicActivationIntxWeightLayout()
+        config = MoEQuantConfig(
+            Int8DynamicActivationIntxWeightConfig(
+                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+            ),
+            use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
         )
     else:
         assert config is not None, (
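Because the chain above dispatches on substring matches, each "-base" key must be tested before its fake-dim counterpart. For illustration only (not part of the commit), the same dispatch could be written table-driven, assuming the imports already present in generate.py:

```python
# Sketch: table-driven equivalent of the if/elif chain above. Dict insertion
# order is preserved (Python 3.7+), so the "-base" keys are checked first.
def make_moe_config(moe_quant: str) -> MoEQuantConfig:
    fake = {"use_fake_extra_dim_tensor": UseFakeExtraDimTensor.TRUE}
    table = {
        "int8wo-base": (lambda: Int8WeightOnlyConfig(), {}),
        "int8wo": (lambda: Int8WeightOnlyConfig(), fake),
        "int8dq-base": (lambda: Int8DynamicActivationInt8WeightConfig(), {}),
        "int8dq": (lambda: Int8DynamicActivationInt8WeightConfig(), fake),
        "int4wo-base": (lambda: Int4WeightOnlyConfig(), {}),
        "int4wo": (lambda: Int4WeightOnlyConfig(), fake),
        "fp8wo-base": (lambda: Float8WeightOnlyConfig(), {}),
        "fp8wo": (lambda: Float8WeightOnlyConfig(), fake),
        "fp8dq-base": (
            lambda: Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
            {},
        ),
        "fp8dq": (
            lambda: Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
            fake,
        ),
        "intxdq": (
            lambda: Int8DynamicActivationIntxWeightConfig(
                layout=PackedLinearInt8DynamicActivationIntxWeightLayout()
            ),
            fake,
        ),
    }
    for key, (factory, kwargs) in table.items():
        if key in moe_quant:
            return MoEQuantConfig(factory(), **kwargs)
    raise ValueError(f"unsupported moe_quant option: {moe_quant}")
```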

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 0 additions & 1 deletion
@@ -504,7 +504,6 @@ def _(func, types, args, kwargs):
     assert len(indices) == 1, (
         f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}"
     )
-
     new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices)
     shape = tuple([indices[0].numel(), *self.shape[1:]])
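For context, this is the `aten.index` override that MoE expert selection relies on: a single index tensor on dim 0 produces shape `(indices[0].numel(), *shape[1:])`, exactly as a plain tensor would. A small runnable illustration of that shape rule, using an ordinary tensor as a stand-in for the quantized weight:

```python
import torch

# Stand-in for a quantized (num_experts, N, K) expert weight.
w = torch.randn(8, 256, 128)
idx = torch.tensor([0, 2])
# Matches the shape computed in the hunk above.
assert w[idx].shape == (2, 256, 128)
```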

torchao/dtypes/floatx/float8_layout.py

Lines changed: 29 additions & 6 deletions
@@ -55,6 +55,7 @@ class Float8Layout(Layout):

     mm_config: Optional[Float8MMConfig] = None

+_fallback_warning_shown = False

 @register_layout(Float8Layout)
 class Float8AQTTensorImpl(AQTTensorImpl):
@@ -100,12 +101,34 @@ def __init__(

     def _apply_fn_to_data(self, fn):
         """Applies a fn to all tensor components stored on this class"""
-        return self.__class__(
-            fn(self.float8_data),
-            fn(self.scale),
-            self.transposed,
-            self._layout,
-        )
+        global _fallback_warning_shown
+
+        try:
+            return self.__class__(
+                fn(self.float8_data),
+                fn(self.scale),
+                self.transposed,
+                self._layout,
+            )
+        except RuntimeError as e:
+            if '"index_cuda" not implemented for ' in str(e):
+                if not _fallback_warning_shown:
+                    import warnings
+
+                    warnings.warn(
+                        f"When trying to index Float8AQTTensorImpl, got known error {e}, will use slower fallback but "
+                        + "note: You can torch.compile the model to avoid this problem.",
+                        UserWarning,
+                    )
+                    _fallback_warning_shown = True
+
+                return self.__class__(  # do indexing in bfloat16 then convert back
+                    fn(self.float8_data.to(torch.bfloat16)).to(self.float8_data.dtype),
+                    fn(self.scale),
+                    self.transposed,
+                    self._layout,
+                )
+            else:
+                raise e

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
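For context, the hunk above wraps the original body in a try/except because eager-mode indexing of float8 CUDA tensors can raise a '"index_cuda" not implemented' RuntimeError, so the data is round-tripped through bfloat16 just for the failing op. A minimal standalone sketch of that fallback (an illustration, not the commit's code):

```python
import torch


def index_float8_rows(t: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Sketch of the bfloat16 indexing fallback used in the hunk above."""
    try:
        return t[idx]
    except RuntimeError as e:
        if '"index_cuda" not implemented for ' not in str(e):
            raise
        # float8 (e4m3/e5m2) values are exactly representable in bfloat16,
        # so the round trip is lossless; it is just slower than a native kernel.
        return t.to(torch.bfloat16)[idx].to(t.dtype)
```

The module-level `_fallback_warning_shown` flag makes the warning fire once per process rather than once per indexed tensor, and the warning's torch.compile suggestion reflects that the compiled path avoids the unsupported eager kernel.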

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 2 additions & 1 deletion
@@ -637,6 +637,7 @@ def test_moe_quant_intx(self):
             FakeExtraDimTensor,
             MoEQuantConfig,
             cond_ffn_filter,
+            UseFakeExtraDimTensor,
         )
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
@@ -656,7 +657,7 @@ def test_moe_quant_intx(self):
         base_config = Int8DynamicActivationIntxWeightConfig(
             layout=PackedLinearInt8DynamicActivationIntxWeightLayout()
         )
-        moe_config = MoEQuantConfig(base_config)
+        moe_config = MoEQuantConfig(base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)

         quantize_(model, moe_config, cond_ffn_filter)
torchao/quantization/prototype/moe_quant/README.md

Lines changed: 12 additions & 6 deletions
@@ -10,10 +10,10 @@ The API for moe quantization is very similar to linear quantization, given a moe

 ```python

-from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter
+from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
 from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig

-quantize_(model, Int8WeightOnlyConfig(), filter_fn=cond_ffn_filter)
+quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), filter_fn=cond_ffn_filter)
 model=torch.compile(model, mode="reduce-overhead")
 # you can also use fullgraph=True for single token inference
 ```
@@ -23,20 +23,26 @@ This api is the same as for normal linear quantization but with a specific filte

 ## Alternative Quantization API

-To make the above api work, each tensor subclass had to be edited to work as 3D tensors. However the only ops we actually need to support are a few indexing and slicing ops on the 0th dimension, the majority of the work was removing hard coded assumptions about the tensor dimensionality. This means its possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until eventually just returning the singular desired 2D quantized tensor subclass. This can be achieved using the alternative api as follows:
+To make the above api work, each tensor subclass had to be edited to work as a 3D tensor. However, the only ops we actually need to support are a few indexing and slicing ops on the 0th dimension; the majority of the work was removing hard-coded assumptions about the tensor dimensionality. This means it's possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until eventually just returning the single desired 2D quantized tensor subclass. This can be achieved using the alternative api by changing the use_fake_extra_dim_tensor flag of the MoEQuantConfig:

 ```python

-from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig
+from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig, UseFakeExtraDimTensor
 from torchao.quantization.quant_api import quantize_, Int8DynamicActivationIntxWeightConfig

-config = MoEQuantConfig(Int8DynamicActivationIntxWeightConfig())
+config = MoEQuantConfig(
+    Int8DynamicActivationIntxWeightConfig(),
+    # this is the only difference from the above api
+    use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+)

 quantize_(model, config, filter_fn=cond_ffn_filter)
 model=torch.compile(model, mode="reduce-overhead")
 ```

-While this approach turns out to not be especially performant, it does allow for comparable memory characteristics, allowing models that wouldn't fit on a single node/gpu to actually run. It is flexible enough however to work with all of the existing linear quantization techniques that make use of quantized tensor subclasses without any changes being made to those classes. It is compilable though even single token inference doesn't work with fullgraph compilation.
+It should also be noted that the default value of use_fake_extra_dim_tensor is AS_FALLBACK, which means the base method is tried first and the more general but less performant fake_extra_dim_tensor method is used only when the base method isn't supported.
+
+While this approach turns out to not be especially performant, it does allow for slightly better memory characteristics, since all the tensors are held separately and aren't actually modified or indexed. It is flexible enough to work with all of the existing linear quantization techniques that make use of quantized tensor subclasses without any changes being made to those classes. It is compilable, though neither single-token nor multi-token inference works with fullgraph compilation.

4147
## Model API
4248

torchao/quantization/prototype/moe_quant/llama4_quant.py

Lines changed: 2 additions & 2 deletions
@@ -70,9 +70,9 @@ def convert_fn(module):
 model = model

 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter
+from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig

-quantize_(model, Int4WeightOnlyConfig(), cond_ffn_filter, device="cuda")
+quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda")

 model.cuda()