
Commit 889dca3

Add support for float8 activation for Int4PreshuffledTensor
Summary: Added basic op support (linear and bmm) for float8 activation. Both the float8- and bf16-activation variants live in the same Tensor subclass, since the weight dtype is the same; the only difference is whether the activation is quantized or not. The implementations differ in the quantization parameters they carry:

bf16 activation:
* group_scale
* group_zero

fp8 activation:
* group_scale
* row_scale

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
1 parent e5ca515 commit 889dca3
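For context, the fp8-activation path added here is used the same way as the existing bf16 path; only the config's input dtype changes. A minimal usage sketch based on the test configs in this commit (the layer sizes are illustrative, and it assumes a CUDA sm90+ device with fbgemm-gpu-genai available):

import torch

from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig, quantize_

# int4 preshuffled weight with float8 (e4m3) activation quantization;
# the bf16-activation variant only differs by input_dtype=torch.bfloat16.
fp8_act_config = FbgemmConfig(
    input_dtype=e4m3_dtype,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, fp8_act_config)

# linear.weight is now an Int4PreshuffledTensor; per the summary above, the
# fp8-activation variant carries group_scale + row_scale while the
# bf16-activation variant carries group_scale + group_zero.
output = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))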

File tree

7 files changed: +271 -124 lines changed


test/quantization/quantize_/test_int4_groupwise_preshuffle.py renamed to test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 73 additions & 26 deletions
@@ -4,14 +4,18 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import tempfile
 import unittest
 
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
+from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -23,42 +27,71 @@
     is_sm_at_least_90,
 )
 
+if TORCH_VERSION_AT_LEAST_2_8:
+    BF16_ACT_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    BF16_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+else:
+    BF16_ACT_CONFIG = None
+    BF16_ACT_BMM_CONFIG = None
+    FP8_ACT_CONFIG = None
+    FP8_ACT_BMM_CONFIG = None
+
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
 @unittest.skipIf(
     not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
 )
-class TestInt4GroupwisePreshuffleTensor(TestCase):
+class TestInt4PreshuffledTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +107,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            "<class 'torchao.quantization.Int4PreshuffledTensor'>",
         )
 
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PreshuffledTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4PreshuffledTensor)
+
 
 if __name__ == "__main__":
     run_tests()

torchao/quantization/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -87,8 +87,8 @@
     dequantize_affine,
     quantize_affine,
 )
-from .quantize_ import (
-    Int4GroupwisePreshuffleTensor,
+from .quantize_.workflows import (
+    Int4PreshuffledTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -153,7 +153,7 @@
     "ModuleFqnToConfig",
     "FbgemmConfig",
     # tensor subclasses
-    "Int4GroupwisePreshuffleTensor",
+    "Int4PreshuffledTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 22 additions & 5 deletions
@@ -67,8 +67,8 @@
     LinearActivationWeightObservedTensor,
 )
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
-from torchao.quantization.quantize_ import (
-    Int4GroupwisePreshuffleTensor,
+from torchao.quantization.quantize_.workflows import (
+    Int4PreshuffledTensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -2056,6 +2056,7 @@ class FbgemmConfig(AOBaseConfig):
         weight_dtype (torch.dtype): weight dtype of the kernel
         output_dtype (torch.dtype): output dtype of the kernel
         group_size (int): The group size for weight
+        preshuffle (bool): whether preshuffle the weights or not
     """
 
     input_dtype: torch.dtype
@@ -2073,7 +2074,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
 
     _SUPPORTED_DTYPES = {
         (torch.bfloat16, torch.int4, torch.bfloat16),
-        (e4m3_dtype, e4m3_dtype, torch.bfloat16),
+        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16),
     }
 
     if (
@@ -2082,8 +2083,10 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.output_dtype == torch.bfloat16)
     ):
         if config.preshuffle:
-            weight = Int4GroupwisePreshuffleTensor.from_float(
-                module.weight, config.block_size
+            weight = Int4PreshuffledTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype=torch.bfloat16,
             )
         else:
             weight = to_fbgemm_int4(
@@ -2093,6 +2096,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
+    if (
+        (config.input_dtype == e4m3_dtype)
+        and (config.weight_dtype == torch.int4)
+        and (config.output_dtype == torch.bfloat16)
+    ):
+        if config.preshuffle:
+            weight = Int4PreshuffledTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype=torch.float8_e4m3fn,
+            )
+        module.weight = torch.nn.Parameter(weight, requires_grad=False)
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)
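For reference, the new dispatch branch ultimately reduces to constructing the weight subclass directly with an activation dtype. A minimal sketch of the two calls, mirroring the arguments shown in the diff above (whether from_float is intended to be called outside the config handler is an assumption; the weight shape is illustrative):

import torch

from torchao.quantization.quantize_.workflows import Int4PreshuffledTensor

weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

# fp8 activation: quantization parameters are group_scale + row_scale
fp8_weight = Int4PreshuffledTensor.from_float(
    weight,
    [1, 128],  # block_size, same as in FbgemmConfig
    activation_dtype=torch.float8_e4m3fn,
)

# bf16 activation: quantization parameters are group_scale + group_zero
bf16_weight = Int4PreshuffledTensor.from_float(
    weight,
    [1, 128],
    activation_dtype=torch.bfloat16,
)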
Lines changed: 0 additions & 9 deletions
@@ -1,9 +0,0 @@
-from .int4_groupwise_preshuffle_tensor import (
-    Int4GroupwisePreshuffleTensor,
-)
-
-Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization"
-
-__all__ = [
-    "Int4GroupwisePreshuffleTensor",
-]
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .int4.int4_preshuffled_tensor import (
+    Int4PreshuffledTensor,
+)
+
+__all__ = [
+    "Int4PreshuffledTensor",
+]

torchao/quantization/quantize_/workflows/int4/__init__.py

Whitespace-only changes.
