Commit 0b33f12

Add support for fbgemm int4 mm kernel (#2255)
* Add support for fbgemm int4 mm kernel

  Summary: we also plan to expose other kernels such as fp8xint4, bf16xfp8, and fp8xfp8 to compare with existing torchao kernels.

  Test Plan: test/dtypes/test_fbgemm_int4_tensor.py

* fix and test
* fix dtype
* use importlib
* add links to fbgemm code
* update io_dtype type
* renaming
* remove enum
* serializability update
* format
* fix tests
* skip fbgemm config tests for 2.5 and below
1 parent 63f2e51 commit 0b33f12
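The tests added below exercise the new config through `quantize_`; as a minimal sketch of the intended usage (shapes and group size are illustrative, and a CUDA SM90+ device with `fbgemm_gpu` installed is assumed):

import torch
from torchao.quantization import FbgemmConfig, quantize_

# illustrative shapes; requires CUDA SM90+ and the fbgemm_gpu gen_ai kernels
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
config = FbgemmConfig(
    input_dtype=torch.bfloat16,   # bf16 activations
    weight_dtype=torch.int4,      # int4 packed weights
    output_dtype=torch.bfloat16,  # bf16 outputs
    block_size=[1, 128],          # last element is the quantization group size
)
quantize_(linear, config)         # weight is converted to the new FbgemmInt4Tensor subclass
out = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

As in the rest of torchao, the last element of block_size is the group size along the input-feature dimension (see `from_float` in the new tensor subclass below).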

File tree: 10 files changed, +355 −7 lines
test/dtypes/test_fbgemm_quantized.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import is_sm_at_least_90


class TestFbgemmInt4Tensor(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
    def test_linear(self):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=(1, 128),
        )
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)


if __name__ == "__main__":
    run_tests()

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_6,
    is_sm_at_least_90,
)


class TestFbgemmInt4Tensor(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6")
    def test_linear(self):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 128],
        )
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)


if __name__ == "__main__":
    run_tests()
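The `compute_error(original, quantized) > 20` check in both tests is an SQNR-style accuracy bound; a rough sketch of what it measures, assuming torchao's usual dB-scale definition in `torchao.quantization.utils`:

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantized-noise ratio in dB; larger is better, and > 20 dB
    # means the int4 path stays close to the bf16 reference output.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))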

test/quantization/test_config_serialization.py

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,7 @@
     config_to_dict,
 )
 from torchao.quantization.quant_api import (
+    FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     FPXWeightOnlyConfig,
@@ -34,11 +35,13 @@
     UIntXWeightOnlyConfig,
 )
 from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 # Define test configurations as fixtures
 configs = [
     Float8DynamicActivationFloat8WeightConfig(),
     Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
+    Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
     Float8WeightOnlyConfig(
         weight_dtype=torch.float8_e4m3fn,
     ),
@@ -78,6 +81,9 @@
     ),
 ]

+if TORCH_VERSION_AT_LEAST_2_6:
+    configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])]
+

 # Create ids for better test naming
 def get_config_ids(configs):
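A sketch of the round trip these configs are expected to survive (serialization helpers taken from `torchao.core.config`, matching the imports above; torch >= 2.6 and a list `block_size` are assumed):

import torch
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import FbgemmConfig

# block_size must be a list: the encoder below rejects tuples outright
config = FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])
reloaded = config_from_dict(config_to_dict(config))
assert isinstance(reloaded, FbgemmConfig)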

torchao/_models/llama/generate.py

Lines changed: 12 additions & 1 deletion
@@ -439,6 +439,17 @@ def ffn_or_attn_only(mod, fqn):
                    f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
                )
            quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+        elif "fbgemm" in quantization:
+            from torchao.quantization import FbgemmConfig
+
+            _, precision, group_size = quantization.split("-")
+            group_size = int(group_size)
+            if precision == "int4":
+                quantize_(model, FbgemmConfig("bf16i4bf16", group_size))
+            else:
+                raise NotImplementedError(
+                    f"FbgemmConfig({precision=}) not supported yet"
+                )
        elif "int4dq-" in quantization:
            from torchao.dtypes import CutlassInt4PackedLayout

@@ -1163,7 +1174,7 @@ def callback(x):
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
            + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>"
+            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>, fbgemm-int4-<group_size>"
        ),
    )
    parser.add_argument(
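With the updated help string, the new path can be exercised from the llama generate script roughly as follows (hypothetical invocation; the checkpoint path and the exact flag spelling depend on your local setup):

python torchao/_models/llama/generate.py --checkpoint_path <path/to/model.pth> --quantization fbgemm-int4-128

The option string is split on "-", so the trailing field becomes the group size; only the int4 precision is wired up so far.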

torchao/core/config.py

Lines changed: 12 additions & 1 deletion
@@ -132,6 +132,12 @@ def default(self, o):
         if isinstance(o, list):
             return [self.encode_value(item) for item in o]

+        elif isinstance(o, tuple):
+            raise NotImplementedError(
+                "Tuples will be serialized as List in JSON, so we recommend to use "
+                f"Lists instead to avoid surprises. got: {o}"
+            )
+
         if isinstance(o, dict):
             return {k: self.encode_value(v) for k, v in o.items()}

@@ -250,13 +256,18 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
             # Recursively handle nested configs
             processed_data[key] = config_from_dict(value)
         elif isinstance(value, list):
-            # Handle lists of possible configs
+            # Handle lists or tuples of possible configs
             processed_data[key] = [
                 config_from_dict(item)
                 if isinstance(item, dict) and "_type" in item and "_data" in item
                 else item
                 for item in value
             ]
+        elif isinstance(value, tuple):
+            raise NotImplementedError(
+                "Tuples will be serialized as List in JSON, so we recommend to use "
+                f"Lists instead to avoid surprises. got: {value}"
+            )
         elif isinstance(value, dict):
             # Handle dicts of possible configs
             processed_data[key] = {
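For context on why tuples are rejected rather than silently converted: JSON has no tuple type, so a tuple written through `json.dumps` comes back as a list on load, silently changing the field's type. A minimal standalone illustration:

import json

# A tuple serializes to a JSON array and deserializes to a Python list,
# which is why configs are asked to use lists up front.
roundtripped = json.loads(json.dumps((1, 128)))
assert roundtripped == [1, 128]
assert type(roundtripped) is list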

torchao/dtypes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
+from .fbgemm_quantized_tensor import to_fbgemm_quantized
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -61,4 +62,5 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
+    "to_fbgemm_quantized",
 ]
torchao/dtypes/fbgemm_quantized_tensor.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


import importlib.util
from typing import List

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import TorchAOBaseTensor

__all__ = [
    "to_fbgemm_quantized",
]

aten = torch.ops.aten


if importlib.util.find_spec("fbgemm_gpu") is None:
    int4_row_quantize_zp = None
    pack_int4 = None
else:
    from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4


class FbgemmInt4Tensor(TorchAOBaseTensor):
    tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
    tensor_attributes = ["group_size"]

    def __new__(cls, packed_weight, scale, zero_point, group_size):
        shape = packed_weight.shape
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, packed_weight, scale, zero_point, group_size):
        self.packed_weight = packed_weight
        self.scale = scale
        self.zero_point = zero_point
        self.group_size = group_size

    def __tensor_flatten__(self):
        return self.tensor_data_attrs, [
            getattr(self, attr) for attr in self.tensor_attributes
        ]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
            *tensor_attributes,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
            *[getattr(self, attr) for attr in self.tensor_attributes],
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, "
            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
        )

    @classmethod
    def from_float(
        cls,
        w: torch.Tensor,
        input_dtype: torch.dtype,
        weight_dtype: torch.dtype,
        output_dtype: torch.dtype,
        block_size: List[int],
    ):
        assert len(block_size) == w.ndim, (
            f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
        )
        group_size = block_size[-1]

        assert (input_dtype, weight_dtype, output_dtype) == (
            torch.bfloat16,
            torch.int4,
            torch.bfloat16,
        )

        if w.ndim >= 3:
            wq, scale, zero_point = zip(
                *[int4_row_quantize_zp(i, group_size) for i in w], strict=False
            )
            wq = torch.stack([pack_int4(i) for i in wq], dim=0)
            scale = torch.stack(scale, dim=0)
            zero_point = torch.stack(zero_point, dim=0)
        else:
            wq, scale, zero_point = int4_row_quantize_zp(w, group_size)
            wq = pack_int4(wq)

        scale = scale.to(w.dtype)
        zero_point = zero_point.to(w.dtype)

        del w
        return FbgemmInt4Tensor(
            packed_weight=wq,
            scale=scale,
            zero_point=zero_point,
            group_size=group_size,
        )


implements = FbgemmInt4Tensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if not input_tensor.is_floating_point():
        raise NotImplementedError(
            f"{func} is not implemented for non floating point input"
        )

    orig_act_size = input_tensor.size()
    orig_out_features = weight_tensor.shape[-2]

    res = torch.ops.fbgemm.bf16i4bf16_rowwise(
        input_tensor,
        weight_tensor.packed_weight,
        weight_tensor.scale,
        weight_tensor.zero_point,
    )
    if bias is not None:
        res = res + bias
    return res.reshape(*orig_act_size[:-1], orig_out_features)


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements([aten.clone.default, aten.copy_.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later
to_fbgemm_quantized = FbgemmInt4Tensor.from_float
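A sketch of calling the conversion entry point directly (hypothetical shapes; `fbgemm_gpu` and an SM90+ GPU are required, since the linear dispatch above calls `torch.ops.fbgemm.bf16i4bf16_rowwise`):

import torch
from torchao.dtypes import to_fbgemm_quantized

# hypothetical 256x128 weight; block_size length must match w.ndim and its
# last element (128) is the per-row quantization group size
w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
qw = to_fbgemm_quantized(
    w,
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
)
x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
y = torch.nn.functional.linear(x, qw)  # routes to bf16i4bf16_rowwise above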

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
 )
 from .quant_api import (
     CutlassInt4PackedLayout,
+    FbgemmConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -148,6 +149,7 @@
     "FPXWeightOnlyConfig",
     "GemliteUIntXWeightOnlyConfig",
     "ModuleFqnToConfig",
+    "FbgemmConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
