
Commit 72bc113 (1 parent: 00417b8)

Add support for float8 activation for Int4GroupwisePreshuffleTensor

Summary:
Added basic op support (linear and bmm). Both float8 and bf16 activation are handled by the same Tensor subclass, since the weight dtype is the same; the only difference is whether the activation is quantized or not. There are some differences in implementation, though:

bf16 activation:
* group_scale
* group_zero

fp8 activation:
* group_scale
* row_scale

Test Plan:
python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
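As a rough usage sketch of what this commit enables (based on the test configs in the diff below; the variable names are illustrative, and running it needs a CUDA device with the fbgemm kernels available), the activation path is selected purely through FbgemmConfig.input_dtype:

# Sketch only: mirrors the configs used in the tests in this commit; exact
# import paths and flags are taken from this diff and may change later.
import torch
from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# bf16 activation: int4 groupwise preshuffled weight, activation stays bf16
bf16_act_config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
)

# fp8 activation: same weight layout, but the activation is quantized to e4m3
fp8_act_config = FbgemmConfig(
    input_dtype=e4m3_dtype,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
)

quantize_(linear, fp8_act_config)
out = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

In both cases the weight ends up as the same Int4GroupwisePreshuffleTensor; only the stored scales differ (group_scale/group_zero vs group_scale/row_scale), per the summary above.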

5 files changed: +246 additions, -77 deletions

test/quantization/quantize_/test_int4_groupwise_preshuffle.py renamed to test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py

Lines changed: 71 additions & 24 deletions
@@ -4,14 +4,18 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import tempfile
 import unittest
 
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
+from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -23,6 +27,45 @@
     is_sm_at_least_90,
 )
 
+if TORCH_VERSION_AT_LEAST_2_8:
+    BF16_ACT_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    BF16_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+else:
+    BF16_ACT_CONFIG = None
+    BF16_ACT_BMM_CONFIG = None
+    FP8_ACT_CONFIG = None
+    FP8_ACT_BMM_CONFIG = None
+
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -32,33 +75,23 @@
 )
 class TestInt4GroupwisePreshuffleTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +107,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
         )
 
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+
 
 if __name__ == "__main__":
     run_tests()

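The renamed test module still calls run_tests() under __main__, so an assumed invocation (requires CUDA with SM90+; the pytest -k filter is just one way to select a single parametrized case):

python test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
pytest test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py -k test_linear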
torchao/quantization/quant_api.py

Lines changed: 19 additions & 1 deletion
@@ -2040,6 +2040,8 @@ class FbgemmConfig(AOBaseConfig):
         weight_dtype (torch.dtype): weight dtype of the kernel
         output_dtype (torch.dtype): output dtype of the kernel
         group_size (int): The group size for weight
+        preshuffle (bool): whether preshuffle the weights or not
+        activation_dtype_for_int4 (str): the dtype for activation for int4 weight, either bf16 or fp8
     """
 
     input_dtype: torch.dtype
@@ -2067,7 +2069,9 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
     ):
         if config.preshuffle:
             weight = Int4GroupwisePreshuffleTensor.from_float(
-                module.weight, config.block_size
+                module.weight,
+                config.block_size,
+                activation_dtype="bf16",
             )
         else:
             weight = to_fbgemm_int4(
@@ -2077,6 +2081,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
+    if (
+        (config.input_dtype == e4m3_dtype)
+        and (config.weight_dtype == torch.int4)
+        and (config.output_dtype == torch.bfloat16)
+    ):
+        if config.preshuffle:
+            weight = Int4GroupwisePreshuffleTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype="fp8",
+            )
+            module.weight = torch.nn.Parameter(weight, requires_grad=False)
+            module.extra_repr = types.MethodType(_linear_extra_repr, module)
+            return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)

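The handler above ultimately just calls Int4GroupwisePreshuffleTensor.from_float on the module weight with an activation_dtype flag; the dispatch on config.input_dtype (torch.bfloat16 vs e4m3_dtype) decides whether "bf16" or "fp8" is passed. A minimal sketch of that call, assuming the (weight, block_size, activation_dtype) signature shown in this diff and a CUDA device:

# Sketch, assuming the from_float signature used by the handler above
# (weight, block_size, activation_dtype); not an independently documented API.
import torch
from torchao.quantization import Int4GroupwisePreshuffleTensor  # public path per the tests

weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

# bf16 activation path: stores group_scale / group_zero (per the commit summary)
w_bf16_act = Int4GroupwisePreshuffleTensor.from_float(
    weight, [1, 128], activation_dtype="bf16"
)

# fp8 activation path: stores group_scale / row_scale instead
w_fp8_act = Int4GroupwisePreshuffleTensor.from_float(
    weight, [1, 128], activation_dtype="fp8"
)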
torchao/quantization/quantize_/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .int4_groupwise_preshuffle_tensor import (
+from .int4 import (
     Int4GroupwisePreshuffleTensor,
 )

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .int4_groupwise_preshuffle_tensor import (
+    Int4GroupwisePreshuffleTensor,
+)
+
+__all__ = [
+    "Int4GroupwisePreshuffleTensor",
+]

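The last two hunks move the subclass under an int4 subpackage (the unnamed new file above is presumably that subpackage's __init__.py) while re-exporting the same name, so existing import paths should keep working. A quick check, assuming the re-exports shown above:

# Sketch: both paths should resolve to the same class after the re-exports above.
from torchao.quantization import Int4GroupwisePreshuffleTensor
from torchao.quantization.quantize_ import Int4GroupwisePreshuffleTensor as FromQuantizeSubpackage

assert Int4GroupwisePreshuffleTensor is FromQuantizeSubpackage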