Skip to content

Add support for float8 activation for Int4PreshuffledTensor #2437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
quantize_,
Expand All @@ -23,42 +27,71 @@
is_sm_at_least_90,
)

if TORCH_VERSION_AT_LEAST_2_8:
BF16_ACT_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

BF16_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

FP8_ACT_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

FP8_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

else:
BF16_ACT_CONFIG = None
BF16_ACT_BMM_CONFIG = None
FP8_ACT_CONFIG = None
FP8_ACT_BMM_CONFIG = None


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
class TestInt4GroupwisePreshuffleTensor(TestCase):
class TestInt4PreshuffledTensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

def test_linear(self):
@parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
def test_linear(self, config):
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, self.config)
quantize_(linear, config)
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

def test_bmm(self):
# Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
# @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
def test_bmm(self, bmm_config):
class M(torch.nn.Module):
def __init__(self, weight):
super().__init__()
Expand All @@ -74,32 +107,46 @@ def forward(self, x):
m = M(weight).eval()
original = m(input)
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

def test_to_device(self):
@parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
def test_to_device(self, config):
for device in self.GPU_DEVICES:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
quantize_(linear, config)
linear.to(device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
quantize_(linear, config)
linear.to(device=device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
quantize_(linear, config)
linear.to(device)

def test_module_path(self):
@parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
def test_module_path(self, config):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
quantize_(linear, config)
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
"<class 'torchao.quantization.Int4PreshuffledTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4PreshuffledTensor'>",
)


instantiate_parametrized_tests(TestInt4PreshuffledTensor)


if __name__ == "__main__":
run_tests()
6 changes: 3 additions & 3 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@
dequantize_affine,
quantize_affine,
)
from .quantize_ import (
Int4GroupwisePreshuffleTensor,
from .quantize_.workflows import (
Int4PreshuffledTensor,
)
from .smoothquant import (
SmoothFakeDynamicallyQuantizedLinear,
Expand Down Expand Up @@ -153,7 +153,7 @@
"ModuleFqnToConfig",
"FbgemmConfig",
# tensor subclasses
"Int4GroupwisePreshuffleTensor",
"Int4PreshuffledTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
Expand Down
27 changes: 22 additions & 5 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@
LinearActivationWeightObservedTensor,
)
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
from torchao.quantization.quantize_ import (
Int4GroupwisePreshuffleTensor,
from torchao.quantization.quantize_.workflows import (
Int4PreshuffledTensor,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
Expand Down Expand Up @@ -2056,6 +2056,7 @@ class FbgemmConfig(AOBaseConfig):
weight_dtype (torch.dtype): weight dtype of the kernel
output_dtype (torch.dtype): output dtype of the kernel
group_size (int): The group size for weight
preshuffle (bool): whether preshuffle the weights or not
"""

input_dtype: torch.dtype
Expand All @@ -2073,7 +2074,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:

_SUPPORTED_DTYPES = {
(torch.bfloat16, torch.int4, torch.bfloat16),
(e4m3_dtype, e4m3_dtype, torch.bfloat16),
(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16),
}

if (
Expand All @@ -2082,8 +2083,10 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
and (config.output_dtype == torch.bfloat16)
):
if config.preshuffle:
weight = Int4GroupwisePreshuffleTensor.from_float(
module.weight, config.block_size
weight = Int4PreshuffledTensor.from_float(
module.weight,
config.block_size,
activation_dtype=torch.bfloat16,
)
else:
weight = to_fbgemm_int4(
Expand All @@ -2093,6 +2096,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
if (
(config.input_dtype == e4m3_dtype)
and (config.weight_dtype == torch.int4)
and (config.output_dtype == torch.bfloat16)
):
if config.preshuffle:
weight = Int4PreshuffledTensor.from_float(
module.weight,
config.block_size,
activation_dtype=torch.float8_e4m3fn,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
elif (
(config.input_dtype == e4m3_dtype)
and (config.weight_dtype == e4m3_dtype)
Expand Down
9 changes: 0 additions & 9 deletions torchao/quantization/quantize_/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
from .int4_groupwise_preshuffle_tensor import (
Int4GroupwisePreshuffleTensor,
)

Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization"

__all__ = [
"Int4GroupwisePreshuffleTensor",
]
7 changes: 7 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .int4.int4_preshuffled_tensor import (
Int4PreshuffledTensor,
)

__all__ = [
"Int4PreshuffledTensor",
]
Empty file.
Loading
Loading