Commit e664a1e

Add support for Int4GroupwisePreshuffleTensor for fbgemm
Summary:
Note: slice is not working yet, others are working.

Test Plan:
python test/dtypes/test_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2421, branch: jerryzh168/stack/1
1 parent 5a50667 commit e664a1e

7 files changed: +493 −15 lines changed
test/dtypes/test_int4_groupwise_preshuffle.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    _is_fbgemm_genai_gpu_available,
    is_sm_at_least_90,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(
    not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
class TestInt4GroupwisePreshuffleTensor(TestCase):
    def setUp(self):
        self.config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 128],
            preshuffle=True,
        )
        self.bmm_config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 1, 128],
            preshuffle=True,
        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    def test_linear(self):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, self.config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

    def test_bmm(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.bmm(x, self.weight)

        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

    def test_to_device(self):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)

    def test_module_path(self):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear, self.config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
        )


if __name__ == "__main__":
    run_tests()

torchao/dtypes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -69,4 +69,5 @@
     "to_fbgemm_fp8",
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",
+    "Int4GroupwisePreshuffleTensor",
 ]

torchao/quantization/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -87,6 +87,9 @@
     dequantize_affine,
     quantize_affine,
 )
+from .quantize_ import (
+    Int4GroupwisePreshuffleTensor,
+)
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
     SmoothFakeDynQuantMixin,
@@ -149,6 +152,8 @@
     "AOPerModuleConfig",
     "ModuleFqnToConfig",
     "FbgemmConfig",
+    # tensor subclasses
+    "Int4GroupwisePreshuffleTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 16 additions & 15 deletions
@@ -15,7 +15,6 @@
 and mixed GEMM kernels
 """
 
-import importlib.util
 import logging
 import types
 import warnings
@@ -68,6 +67,9 @@
     LinearActivationWeightObservedTensor,
 )
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
+from torchao.quantization.quantize_ import (
+    Int4GroupwisePreshuffleTensor,
+)
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
@@ -79,7 +81,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    is_fbcode,
+    _is_fbgemm_genai_gpu_available,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -2046,18 +2048,12 @@ class FbgemmConfig(AOBaseConfig):
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
     transpose_input: bool = False
+    preshuffle: bool = False
 
 
 @register_quantize_module_handler(FbgemmConfig)
 def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
-    # TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when
-    # https://github.com/pytorch/FBGEMM/issues/4198 is fixed
-    if importlib.util.find_spec("fbgemm_gpu") is None:
-        raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
-
-    import fbgemm_gpu.experimental.gen_ai  # noqa: F401
-
-    if not is_fbcode() and fbgemm_gpu.__version__ < "1.2.0":
+    if not _is_fbgemm_genai_gpu_available():
         raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
 
     _SUPPORTED_DTYPES = {
@@ -2070,11 +2066,16 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.weight_dtype == torch.int4)
         and (config.output_dtype == torch.bfloat16)
     ):
-        weight = to_fbgemm_int4(
-            module.weight,
-            config.block_size,
-            config.transpose_input,
-        )
+        if config.preshuffle:
+            weight = Int4GroupwisePreshuffleTensor.from_float(
+                module.weight, config.block_size
+            )
+        else:
+            weight = to_fbgemm_int4(
+                module.weight,
+                config.block_size,
+                config.transpose_input,
+            )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
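
For reference, a minimal usage sketch of the preshuffle path added above, mirroring test_linear from the new test file (assumes a CUDA sm90+ device with fbgemm-gpu-genai >= 1.2.0 installed):

import torch

from torchao.quantization import FbgemmConfig, quantize_

# With preshuffle=True, the FbgemmConfig handler above converts the weight to
# an Int4GroupwisePreshuffleTensor instead of the plain fbgemm int4 tensor.
config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)
out = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))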
torchao/quantization/quantize_/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from .int4_groupwise_preshuffle_tensor import (
    Int4GroupwisePreshuffleTensor,
)

Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization"

__all__ = [
    "Int4GroupwisePreshuffleTensor",
]
