
Commit ee9fb34

Add support for Int4GroupwisePreshuffleTensor for fbgemm
Summary: Note: slicing is not working yet; the other operations work.
Test Plan: python test/dtypes/test_int4_groupwise_preshuffle.py
Reviewers:
Subscribers:
Tasks:
Tags:
stack-info: PR: #2421, branch: jerryzh168/stack/1
1 parent 5a50667 commit ee9fb34

File tree

6 files changed, +624 −5 lines changed

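For readers skimming the diff, here is a minimal usage sketch of the path this commit adds, pieced together from the new test and the FbgemmConfig change below. It is not part of the diff itself, and per the test's skip conditions it assumes PyTorch 2.8+ and a CUDA GPU with sm90+:

    import torch
    from torchao.quantization import (
        FbgemmConfig,
        Int4GroupwisePreshuffleTensor,
        quantize_,
    )

    # bf16 activations, int4 groupwise weights (group size 128), preshuffled for fbgemm
    config = FbgemmConfig(
        input_dtype=torch.bfloat16,
        weight_dtype=torch.int4,
        output_dtype=torch.bfloat16,
        block_size=[1, 128],
        preshuffle=True,
    )

    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, config)  # replaces linear.weight with the new tensor subclass

    assert isinstance(linear.weight, Int4GroupwisePreshuffleTensor)
    out = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

With preshuffle=False (the default), the same config keeps the existing to_fbgemm_int4 path, as shown in the quant_api.py hunk below.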
test/dtypes/test_int4_groupwise_preshuffle.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_90,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4GroupwisePreshuffleTensor(TestCase):
    def setUp(self):
        self.config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 128],
            preshuffle=True,
        )
        self.bmm_config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 1, 128],
            preshuffle=True,
        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    def test_linear(self):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, self.config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

    @unittest.skip("WIP: this doesn't work yet")
    def test_slice(self):
        dtype = torch.bfloat16
        device = "cuda"
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 64), requires_grad=False
        )
        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

        quantize_(dummy, self.config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)
        # check that the slicing operation is correctly performed on the constituent tensors
        self.assertEqual(
            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
        )
        self.assertEqual(weight1.group_scale, dummy.weight.group_scale.narrow(1, 0, 64))
        self.assertEqual(
            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
        )
        self.assertEqual(weight2.group_scale, dummy.weight.group_scale.narrow(0, 0, 1))

        # check that 1. the sliced bf16 weight and 2. the sliced quantized weight
        # produce similar results for matmul on the same input tensor

        input = torch.randn(2, 256, dtype=dtype, device=device)
        res_ref = dummy1(input)
        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
        res = dummy(input)
        sqnr = compute_error(res, res_ref)
        assert sqnr > 20, f"Got: {sqnr}"

        input = torch.randn(2, 128, dtype=dtype, device=device)
        res_ref = dummy2(input)
        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
        res = dummy(input)
        sqnr = compute_error(res, res_ref)
        assert sqnr > 15, f"Got: {sqnr}"

    def test_slice_and_copy_(self):
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
        quantize_(l, self.config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        assert (
            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
        )
        assert param.data.group_scale.data_ptr() == param_data.group_scale.data_ptr()
        assert param.data.group_zero.data_ptr() == param_data.group_zero.data_ptr()
        orig_value = param.data.packed_weight[0][0].item()

        # dummy_l has random input (shouldn't be 0)
        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        quantize_(dummy_l, self.config)
        quantized = dummy_l.weight
        quantized = quantized.narrow(0, 0, 512)

        param_data.copy_(quantized)

        # making sure param.data is updated
        assert param.data.packed_weight[0][0] != orig_value

    def test_bmm(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.bmm(x, self.weight)

        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

    def test_to_device(self):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)

    def test_module_path(self):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear, self.config)
        self.assertEqual(str(type(linear.weight)), "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>")


if __name__ == "__main__":
    run_tests()

torchao/dtypes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -69,4 +69,5 @@
     "to_fbgemm_fp8",
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",
+    "Int4GroupwisePreshuffleTensor",
 ]

torchao/quantization/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -102,6 +102,9 @@
     compute_error,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
+from .quantize_ import (
+    Int4GroupwisePreshuffleTensor,
+)

 # TODO: remove after migration of APIs are done
 AOPerModuleConfig = ModuleFqnToConfig
@@ -149,6 +152,8 @@
     "AOPerModuleConfig",
     "ModuleFqnToConfig",
     "FbgemmConfig",
+    # tensor subclasses
+    "Int4GroupwisePreshuffleTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 14 additions & 5 deletions
@@ -75,6 +75,9 @@
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
+from torchao.quantization.quantize_ import (
+    Int4GroupwisePreshuffleTensor,
+)
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
@@ -2046,6 +2049,7 @@ class FbgemmConfig(AOBaseConfig):
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
     transpose_input: bool = False
+    preshuffle: bool = False


 @register_quantize_module_handler(FbgemmConfig)
@@ -2070,11 +2074,16 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.weight_dtype == torch.int4)
         and (config.output_dtype == torch.bfloat16)
     ):
-        weight = to_fbgemm_int4(
-            module.weight,
-            config.block_size,
-            config.transpose_input,
-        )
+        if config.preshuffle:
+            weight = Int4GroupwisePreshuffleTensor.from_float(
+                module.weight, config.block_size
+            )
+        else:
+            weight = to_fbgemm_int4(
+                module.weight,
+                config.block_size,
+                config.transpose_input,
+            )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
torchao/quantization/quantize_/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from .int4_groupwise_preshuffle_tensor import (
    Int4GroupwisePreshuffleTensor,
)

Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization"

__all__ = [
    "Int4GroupwisePreshuffleTensor",
]

0 commit comments