Skip to content

Enabling MOE Quantization using linear decomposition [WIP] #2043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions test/quantization/test_moe_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
import torch
import unittest
from torchao.quantization.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable
from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter, FakeExtraDimTensor
from torchao.quantization.quant_api import (
Int8WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int4WeightOnlyConfig,
Float8WeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
PackedLinearInt8DynamicActivationIntxWeightLayout,
quantize_,
AffineQuantizedTensor,
LinearActivationQuantizedTensor,
Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from parameterized import param, parameterized

class TestMoEQuantCompile(unittest.TestCase):
DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k

@torch.no_grad()
def _test_impl_moe_quant(self,
config,
num_tokens=1,
model_params=None,
base_class=AffineQuantizedTensor,
tensor_impl_class=None,
dtype=torch.bfloat16,
device="cuda",
fullgraph=False
):
"""
Tests moe quant for techniques using fake extra dim
"""
if model_params is None:
model_params=self.DEFAULT_PARAMS

input_shape = (num_tokens, model_params[0])
model = MOEFeedForwardAOQuantizable(*model_params).to(dtype).to(device)
input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)

out = model(input)

import copy
new_mod = copy.deepcopy(model)

quantize_(model, config, cond_ffn_filter)

if isinstance(config, MoEQuantConfig):
self.assertIsInstance(model.experts.w1, FakeExtraDimTensor)
if base_class is not None:
self.assertIsInstance(model.experts.w1.head_tensor, base_class)
if tensor_impl_class is not None:
self.assertIsInstance(model.experts.w1.head_tensor.tensor_impl, tensor_impl_class)
else:
if base_class is not None:
self.assertIsInstance(model.experts.w1, base_class)
if tensor_impl_class is not None:
self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class)

out_q = model(input)

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
model_c = torch.compile(model, mode="reduce-overhead", fullgraph=fullgraph)

model_c(input)
model_c(input)
out_qc = model_c(input).clone()

for i in range(10):
input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
model_c(input)

self.assertGreaterEqual(compute_error(out_q, out), 10)
self.assertGreaterEqual(compute_error(out_qc, out), 10)
print(compute_error(out_q, out), compute_error(out_qc, out))


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, False),
("multiple_tokens", 8, False),
])
def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
config = MoEQuantConfig(Int4WeightOnlyConfig())
tensor_impl_class = TensorCoreTiledAQTTensorImpl

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
tensor_impl_class=tensor_impl_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, True),
("multiple_tokens", 8, False),
])
def test_int4wo_base(self, name, num_tokens, fullgraph):
config = Int4WeightOnlyConfig()
tensor_impl_class = TensorCoreTiledAQTTensorImpl

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
tensor_impl_class=tensor_impl_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, False),
("multiple_tokens", 8, False),
])
def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
config = MoEQuantConfig(Int8WeightOnlyConfig())
tensor_impl_class = PlainAQTTensorImpl

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
tensor_impl_class=tensor_impl_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, True),
("multiple_tokens", 8, False),
])
def test_int8wo_base(self, name, num_tokens, fullgraph):
config = Int8WeightOnlyConfig()
tensor_impl_class = PlainAQTTensorImpl

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
tensor_impl_class=tensor_impl_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("multiple_tokens", 32, False),
])
def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
base_class = LinearActivationQuantizedTensor

self._test_impl_moe_quant(
model_params=(512, 256, 2, 2),
config=config,
num_tokens=num_tokens,
base_class=base_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("multiple_tokens", 32, False),
])
def test_int8dq_base(self, name, num_tokens, fullgraph):
config = Int8DynamicActivationInt8WeightConfig()
base_class = LinearActivationQuantizedTensor

self._test_impl_moe_quant(
model_params=(512, 256, 2, 2),
config=config,
num_tokens=num_tokens,
base_class=base_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, False),
("multiple_tokens", 8, False),

])
def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
config = MoEQuantConfig(Float8WeightOnlyConfig())
tensor_impl_class = Float8AQTTensorImpl

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
tensor_impl_class=tensor_impl_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, True),
("multiple_tokens", 8, False),
])
def test_fp8wo_base(self, name, num_tokens, fullgraph):
config = Float8WeightOnlyConfig()
tensor_impl_class = Float8AQTTensorImpl

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
tensor_impl_class=tensor_impl_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, False),
("multiple_tokens", 8, False),

])
def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig())
base_class = LinearActivationQuantizedTensor

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
base_class=base_class,
fullgraph=fullgraph
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parameterized.expand([
("single_token", 1, True),
("multiple_tokens", 8, False),
])
def test_fp8dq_base(self, name, num_tokens, fullgraph):
config = Float8DynamicActivationFloat8WeightConfig()
base_class = LinearActivationQuantizedTensor

self._test_impl_moe_quant(
config=config,
num_tokens=num_tokens,
base_class=base_class,
fullgraph=fullgraph
)








if __name__ == "__main__":
unittest.main()
Loading