
Commit 4a3c77e

Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig
Summary: att, we will deprecate FbgemmConfig since it's a single kernel. We'd like to categorize things by derived dtype + packed format.
Test Plan: python test/quantization/quantize_/test_int4_groupwise_preshuffle.py
Reviewers:
Subscribers:
Tasks:
Tags:
stack-info: PR: #2474, branch: jerryzh168/stack/10
1 parent 76ba25d commit 4a3c77e
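
Usage after this change, as a minimal sketch based on the updated tests in this commit (the toy linear layer and its shapes are hypothetical; per the new handler, Float8ActivationInt4WeightConfig currently only supports use_preshuffle=True with the fbgemm kernel):

    import torch
    from torchao.quantization import (
        Float8ActivationInt4WeightConfig,
        Int4WeightOnlyConfig,
        quantize_,
    )

    # hypothetical bf16 linear layer on CUDA (sm90+ with fbgemm-gpu-genai installed)
    m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

    # bf16 activation + int4 weight with preshuffled packing, replacing the old FbgemmConfig
    quantize_(m, Int4WeightOnlyConfig(group_size=128, use_preshuffle=True))

    # alternatively, fp8 activation + int4 weight with preshuffled packing:
    # quantize_(m, Float8ActivationInt4WeightConfig(group_size=128, use_preshuffle=True))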

File tree

10 files changed, +109 -88 lines changed

test/integration/test_serialization_bc.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 
 _MODEL_NAMES = [
     "torchao-testing/opt-125m-float8dq-row-fbgemm",
+    "torchao-testing/opt-125m-int4wo-preshuffle",
 ]
 
 

test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 12 additions & 41 deletions
@@ -15,9 +15,9 @@
     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -27,44 +27,15 @@
     is_sm_at_least_90,
 )
 
-if TORCH_VERSION_AT_LEAST_2_8:
-    BF16_ACT_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    BF16_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-else:
-    BF16_ACT_CONFIG = None
-    BF16_ACT_BMM_CONFIG = None
-    FP8_ACT_CONFIG = None
-    FP8_ACT_BMM_CONFIG = None
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    use_preshuffle=True,
+)
+
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    use_preshuffle=True,
+)
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +61,7 @@ def test_linear(self, config):
 
     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):

test/dtypes/test_fbgemm_int4.py renamed to test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 7 additions & 14 deletions
@@ -13,7 +13,7 @@
 )
 
 from torchao.quantization import (
-    FbgemmConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,19 +26,12 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-class TestFbgemmInt4Tensor(TestCase):
+class TestInt4Tensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
+        self.config = Int4WeightOnlyConfig(
+            group_size=128,
+            use_preshuffle=False,
+            gemm_kernel_choice="fbgemm",
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
@@ -135,7 +128,7 @@ def forward(self, x):
         original = m(input)
         # we need to transpose the weight first for bmm
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 

torchao/dtypes/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -9,7 +9,6 @@
     to_affine_quantized_intx_static,
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
-from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -64,8 +63,6 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
-    "to_fbgemm_int4",
-    "FbgemmInt4Tensor",
     "to_fbgemm_fp8",
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",

torchao/quantization/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -90,6 +91,7 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -141,6 +143,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
@@ -154,6 +157,7 @@
     "ModuleFqnToConfig",
     "FbgemmConfig",
     # tensor subclasses
+    "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
     # smooth quant - subject to change

torchao/quantization/quant_api.py

Lines changed: 53 additions & 3 deletions
@@ -49,7 +49,6 @@
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -72,6 +71,7 @@
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1116,6 +1116,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    use_preshuffle: bool = False
 
 
 # for BC
@@ -1133,15 +1134,31 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    use_preshuffle = config.use_preshuffle
 
     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight
 
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if use_preshuffle:
+        new_weight = Int4PreshuffledTensor.from_float(
+            weight,
+            block_size,
+            activation_dtype="bf16",
+        )
+        return new_weight
+    else:
+        new_weight = Int4Tensor.from_float(
+            weight,
+            block_size,
+        )
+        return new_weight
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1213,6 +1230,39 @@ def _int4_weight_only_transform(
     return module
 
 
+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    use_preshuffle: bool = False
+    kernel: str = "fbgemm"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _(module: torch.nn.Module, config: Int4WeightOnlyConfig) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying int8 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    group_size = config.group_size
+    use_preshuffle = config.use_preshuffle
+    kernel = config.kernel
+
+    assert use_preshuffle, (
+        f"only use_preshuffle == True is supported right now, got: {use_preshuffle}"
+    )
+    assert kernel == "fbgemm", f"only fbgemm kernel is supported, got: {kernel}"
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4PreshuffledTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype="fp8",
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
@@ -2067,7 +2117,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype=torch.bfloat16,
        )
     else:
-        weight = to_fbgemm_int4(
+        weight = Int4Tensor.from_float(
            module.weight,
            config.block_size,
        )
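
For reference, the block_size that the new Int4WeightOnlyConfig path derives from group_size is groupwise along the last weight dimension only; a small sketch of what the expression yields (weight shapes are made up for illustration):

    import torch

    group_size = 128
    weight_2d = torch.randn(256, 512, dtype=torch.bfloat16)     # e.g. a linear weight
    weight_3d = torch.randn(8, 256, 512, dtype=torch.bfloat16)  # e.g. stacked weights for bmm

    # same expression as in _int4_weight_only_quantize_tensor above
    block_size_2d = tuple([1 for _ in range(weight_2d.ndim - 1)] + [group_size])
    block_size_3d = tuple([1 for _ in range(weight_3d.ndim - 1)] + [group_size])

    print(block_size_2d)  # (1, 128)    -- matches the old FbgemmConfig block_size=[1, 128]
    print(block_size_3d)  # (1, 1, 128) -- matches the old block_size=[1, 1, 128]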

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -3,12 +3,15 @@
 )
 from .int4 import (
     Int4PreshuffledTensor,
+    Int4Tensor,
 )
 
+Int4Tensor.__module__ = "torchao.quantization"
 Int4PreshuffledTensor.__module__ = "torchao.quantization"
 Float8Tensor.__module__ = "torchao.quantization"
 
 __all__ = [
+    "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
 ]
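
With Int4Tensor exported here and its __module__ rewritten to the public package, the subclass can be referenced from the top level; a quick sanity check, assuming a build that includes this commit:

    from torchao.quantization import Int4PreshuffledTensor, Int4Tensor

    # both tensor subclasses now present themselves under torchao.quantization
    assert Int4Tensor.__module__ == "torchao.quantization"
    assert Int4PreshuffledTensor.__module__ == "torchao.quantization"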

torchao/quantization/quantize_/workflows/int4/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,7 @@
 from .int4_preshuffled_tensor import Int4PreshuffledTensor
+from .int4_tensor import Int4Tensor
 
 __all__ = [
     "Int4PreshuffledTensor",
+    "Int4Tensor",
 ]

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def from_float(
         if quantize_int4_preshuffle is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
 
-        assert all(x == 1 for x in block_size[:-1]), (
+        assert all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1, (
             "Only groupwise quant is supported right now"
         )
         group_size = block_size[-1]
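
The tightened assertion now also rejects a trailing group of 1, not just non-1 leading blocks; a standalone sketch of which block_size values pass the updated check (values are illustrative):

    def is_groupwise(block_size):
        # mirrors the updated check in Int4PreshuffledTensor.from_float
        return all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1

    print(is_groupwise((1, 128)))     # True  -- groupwise quant over the last dim
    print(is_groupwise((1, 1, 128)))  # True  -- bmm-style 3D weight
    print(is_groupwise((1, 1)))       # False -- newly rejected: last-dim group of 1
    print(is_groupwise((2, 128)))     # False -- non-1 leading block, rejected as before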
