Commit c4a7ad4

Relax MOE constraints and add test for torch.mm computation (#2227)
* Relax some constraints to allow quantizing aten.mm

  Summary: Currently both float8 dynamic quant and int4 weight-only quant work
  only with F.linear, not aten.mm. This PR allows a fallback that dequantizes
  the tensors and runs the unquantized path, until real aten.mm support is in
  place.

  Test Plan:
  python test/dtypes/test_affine_quantized.py -k test_mm_int4wo
  python test/dtypes/test_affine_quantized_float.py -k test_mm_float8dq

* add skip if no cuda
* update tests
* update
1 parent 5153bd3 commit c4a7ad4
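
To make the change concrete, here is a minimal sketch (not part of the commit) of the behavior the commit message describes: calling aten.mm with a quantized weight, which now dequantizes and falls back to the plain matmul path instead of erroring. It assumes torchao with this commit, a CUDA GPU, and that the quantized layout supports transpose; import paths follow the test files below.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

l = torch.nn.Linear(512, 1024, device="cuda", dtype=torch.bfloat16)
quantize_(l, Int4WeightOnlyConfig())  # l.weight is now a quantized tensor subclass

x = torch.randn(1, 512, device="cuda", dtype=torch.bfloat16)
# Dispatches to aten.mm; previously unsupported for quantized weights,
# now handled by dequantizing and running the unquantized matmul.
y = torch.mm(x, l.weight.t())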

File tree: 3 files changed (+45, −1 lines changed)

test/dtypes/test_affine_quantized.py

Lines changed: 18 additions & 0 deletions
@@ -421,6 +421,24 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         # making sure param.data is updated
         assert param.data.dequantize()[0][0] != 0

+    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("dtype", [torch.bfloat16])
+    @skip_if_no_cuda()
+    @skip_if_rocm("ROCm enablement in progress")
+    def test_mm_int4wo(self, device, dtype):
+        weight = torch.randn(512, 1024).to(device).to(dtype)
+        weight = weight.t()
+
+        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
+        l.weight = torch.nn.Parameter(weight)
+        quantize_(l, Int4WeightOnlyConfig())
+        # weight shape: 1024 x 512
+        weight = l.weight
+
+        input = torch.randn(1, 512, device=device, dtype=dtype)
+        # make sure it runs
+        torch.nn.functional.linear(input, weight)
+

 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
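
The new test only checks that the call runs. A natural extension (hypothetical, not in the commit) is to sanity-check the fallback output against an explicit dequantize-then-matmul reference; the diff above already uses .dequantize() on param.data the same way, and the tolerance below is illustrative:

# Hypothetical extension of test_mm_int4wo, appended after the F.linear call.
ref = torch.nn.functional.linear(input, weight.data.dequantize())
out = torch.nn.functional.linear(input, weight)
# Both paths use the same quantized values, so they should agree closely.
assert torch.allclose(out, ref, atol=1e-1, rtol=1e-1)  # loose, illustrative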

test/dtypes/test_affine_quantized_float.py

Lines changed: 21 additions & 0 deletions
@@ -27,6 +27,7 @@

 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8WeightConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -292,6 +293,26 @@ def test_fp8_weight_dimension_warning(self):
             f"Expected warning message containing: {expected}",
         )

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_mm_float8dq(self):
+        device = "cuda"
+        dtype = torch.bfloat16
+        weight = torch.randn(512, 1024).to(device).to(dtype)
+        weight = weight.t()
+
+        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
+        l.weight = torch.nn.Parameter(weight)
+        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+        # weight shape: 1024 x 512
+        weight = l.weight
+
+        input = torch.randn(1, 512, device=device, dtype=dtype)
+        # make sure it runs
+        torch.nn.functional.linear(input, weight)
+

 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
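
For reference, the float8 path under test can be exercised outside the test harness with roughly the sketch below. Import paths are assumed from torchao's public API, and, matching the is_sm_at_least_89 skip condition above, it needs a CUDA GPU with compute capability >= 8.9.

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

l = torch.nn.Linear(512, 1024, device="cuda", dtype=torch.bfloat16)
quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(1, 512, device="cuda", dtype=torch.bfloat16)
y = torch.mm(x, l.weight.t())  # exercises the new aten.mm fallback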

torchao/prototype/moe_quant/utils.py

Lines changed: 6 additions & 1 deletion
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 import torch
 from torch.utils._python_dispatch import (
     return_and_correct_aliasing,
@@ -282,7 +288,6 @@ def moe_quant_fn(module, config: MoEQuantConfig):

     warnings.simplefilter("ignore", lineno=84)
     warnings.simplefilter("ignore", lineno=105)
-    assert "ConditionalFeedForwardAOQuantizable" in str(type(module))

     for weight_attr in ["w1", "w2", "w3"]:
         param = getattr(module, weight_attr)
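
Removing the assert means moe_quant_fn no longer requires the module's class name to contain "ConditionalFeedForwardAOQuantizable": any module exposing the w1/w2/w3 weight parameters iterated over above can now be quantized. A hypothetical custom expert module that would now qualify (names and shapes are illustrative only):

import torch

class CustomExperts(torch.nn.Module):
    # moe_quant_fn only touches the w1/w2/w3 parameters it iterates over,
    # not a specific class name, so a module like this is now accepted.
    def __init__(self, num_experts=8, dim=512, hidden=1024):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.randn(num_experts, hidden, dim))
        self.w2 = torch.nn.Parameter(torch.randn(num_experts, dim, hidden))
        self.w3 = torch.nn.Parameter(torch.randn(num_experts, hidden, dim))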
