Skip to content

Commit c13489e

Browse files
committed
Add support for float8 activation for Int4GroupwisePreshuffleTensor
Summary: Added basic op support like linear and bmm, we have both float8 and bf16 in the same Tensor because it's the same dtype, only difference is whether the activation is quantized or not. Although there is some differneces in implementation: bf16 activaton: * group_scale * group_zero fp8 activation * group_scale * row_scale Test Plan: python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2437, branch: jerryzh168/stack/4
1 parent ad1efd7 commit c13489e

File tree

5 files changed

+239
-77
lines changed

5 files changed

+239
-77
lines changed

test/quantization/quantize_/test_int4_groupwise_preshuffle.py renamed to test/quantization/quantize_/int4/test_int4_groupwise_preshuffle.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import tempfile
78
import unittest
89

910
import torch
1011
from torch.testing._internal.common_utils import (
1112
TestCase,
13+
instantiate_parametrized_tests,
14+
parametrize,
1215
run_tests,
1316
)
1417

18+
from torchao.float8.config import e4m3_dtype
1519
from torchao.quantization import (
1620
FbgemmConfig,
1721
quantize_,
@@ -23,6 +27,38 @@
2327
is_sm_at_least_90,
2428
)
2529

30+
BF16_ACT_CONFIG = FbgemmConfig(
31+
input_dtype=torch.bfloat16,
32+
weight_dtype=torch.int4,
33+
output_dtype=torch.bfloat16,
34+
block_size=[1, 128],
35+
preshuffle=True,
36+
)
37+
38+
BF16_ACT_BMM_CONFIG = FbgemmConfig(
39+
input_dtype=torch.bfloat16,
40+
weight_dtype=torch.int4,
41+
output_dtype=torch.bfloat16,
42+
block_size=[1, 1, 128],
43+
preshuffle=True,
44+
)
45+
46+
FP8_ACT_CONFIG = FbgemmConfig(
47+
input_dtype=e4m3_dtype,
48+
weight_dtype=torch.int4,
49+
output_dtype=torch.bfloat16,
50+
block_size=[1, 128],
51+
preshuffle=True,
52+
)
53+
54+
FP8_ACT_BMM_CONFIG = FbgemmConfig(
55+
input_dtype=e4m3_dtype,
56+
weight_dtype=torch.int4,
57+
output_dtype=torch.bfloat16,
58+
block_size=[1, 1, 128],
59+
preshuffle=True,
60+
)
61+
2662

2763
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
2864
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -32,33 +68,23 @@
3268
)
3369
class TestInt4GroupwisePreshuffleTensor(TestCase):
3470
def setUp(self):
35-
self.config = FbgemmConfig(
36-
input_dtype=torch.bfloat16,
37-
weight_dtype=torch.int4,
38-
output_dtype=torch.bfloat16,
39-
block_size=[1, 128],
40-
preshuffle=True,
41-
)
42-
self.bmm_config = FbgemmConfig(
43-
input_dtype=torch.bfloat16,
44-
weight_dtype=torch.int4,
45-
output_dtype=torch.bfloat16,
46-
block_size=[1, 1, 128],
47-
preshuffle=True,
48-
)
4971
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
5072

51-
def test_linear(self):
73+
@parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
74+
def test_linear(self, config):
5275
dtype = torch.bfloat16
5376
device = "cuda"
5477
input = torch.randn(1, 128, dtype=dtype, device=device)
5578
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
5679
original = linear(input)
57-
quantize_(linear, self.config)
80+
quantize_(linear, config)
5881
quantized = linear(input)
5982
self.assertTrue(compute_error(original, quantized) > 20)
6083

61-
def test_bmm(self):
84+
# Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
85+
# @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
86+
@parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
87+
def test_bmm(self, bmm_config):
6288
class M(torch.nn.Module):
6389
def __init__(self, weight):
6490
super().__init__()
@@ -74,32 +100,46 @@ def forward(self, x):
74100
m = M(weight).eval()
75101
original = m(input)
76102
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
77-
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
103+
quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
78104
quantized = m(input)
79105
self.assertTrue(compute_error(original, quantized) > 18)
80106

81-
def test_to_device(self):
107+
@parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
108+
def test_to_device(self, config):
82109
for device in self.GPU_DEVICES:
83110
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
84-
quantize_(linear, self.config)
111+
quantize_(linear, config)
85112
linear.to(device)
86113

87114
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
88-
quantize_(linear, self.config)
115+
quantize_(linear, config)
89116
linear.to(device=device)
90117

91118
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
92-
quantize_(linear, self.config)
119+
quantize_(linear, config)
93120
linear.to(device)
94121

95-
def test_module_path(self):
122+
@parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
123+
def test_module_path(self, config):
96124
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
97-
quantize_(linear, self.config)
125+
quantize_(linear, config)
98126
self.assertEqual(
99127
str(type(linear.weight)),
100128
"<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
101129
)
102130

131+
with tempfile.NamedTemporaryFile() as f:
132+
torch.save(linear.state_dict(), f)
133+
f.seek(0)
134+
state_dict = torch.load(f)
135+
self.assertEqual(
136+
str(type(state_dict["weight"])),
137+
"<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
138+
)
139+
140+
141+
instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
142+
103143

104144
if __name__ == "__main__":
105145
run_tests()

torchao/quantization/quant_api.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2040,6 +2040,8 @@ class FbgemmConfig(AOBaseConfig):
20402040
weight_dtype (torch.dtype): weight dtype of the kernel
20412041
output_dtype (torch.dtype): output dtype of the kernel
20422042
group_size (int): The group size for weight
2043+
preshuffle (bool): whether preshuffle the weights or not
2044+
activation_dtype_for_int4 (str): the dtype for activation for int4 weight, either bf16 or fp8
20432045
"""
20442046

20452047
input_dtype: torch.dtype
@@ -2067,7 +2069,9 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
20672069
):
20682070
if config.preshuffle:
20692071
weight = Int4GroupwisePreshuffleTensor.from_float(
2070-
module.weight, config.block_size
2072+
module.weight,
2073+
config.block_size,
2074+
activation_dtype="bf16",
20712075
)
20722076
else:
20732077
weight = to_fbgemm_int4(
@@ -2077,6 +2081,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
20772081
module.weight = torch.nn.Parameter(weight, requires_grad=False)
20782082
module.extra_repr = types.MethodType(_linear_extra_repr, module)
20792083
return module
2084+
if (
2085+
(config.input_dtype == e4m3_dtype)
2086+
and (config.weight_dtype == torch.int4)
2087+
and (config.output_dtype == torch.bfloat16)
2088+
):
2089+
if config.preshuffle:
2090+
weight = Int4GroupwisePreshuffleTensor.from_float(
2091+
module.weight,
2092+
config.block_size,
2093+
activation_dtype="fp8",
2094+
)
2095+
module.weight = torch.nn.Parameter(weight, requires_grad=False)
2096+
module.extra_repr = types.MethodType(_linear_extra_repr, module)
2097+
return module
20802098
elif (
20812099
(config.input_dtype == e4m3_dtype)
20822100
and (config.weight_dtype == e4m3_dtype)

torchao/quantization/quantize_/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .int4_groupwise_preshuffle_tensor import (
1+
from .int4 import (
22
Int4GroupwisePreshuffleTensor,
33
)
44

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .int4_groupwise_preshuffle_tensor import (
2+
Int4GroupwisePreshuffleTensor,
3+
)
4+
5+
__all__ = [
6+
"Int4GroupwisePreshuffleTensor",
7+
]

0 commit comments

Comments
 (0)