Commit b82465e

Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig
Summary: att, we will deprecate FbgemmConfig since it's a single kernel; we'd like to categorize things by derived dtype + packing format.

Test Plan: python test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Reviewers:
Subscribers:
Tasks:
Tags:

stack-info: PR: #2474, branch: jerryzh168/stack/10
1 parent 66f4137 commit b82465e
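
For reference, a minimal usage sketch of the two configs this commit introduces, assuming a CUDA device with fbgemm-gpu-genai kernels available; the model and layer sizes are illustrative, and the argument names mirror the updated tests below.

import torch
from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# toy bfloat16 model; any module containing nn.Linear works the same way
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# bf16 activation + int4 weight, preshuffled packing
# (replaces FbgemmConfig(input_dtype=torch.bfloat16, weight_dtype=torch.int4, ..., preshuffle=True))
quantize_(model, Int4WeightOnlyConfig(group_size=128, packing_format="preshuffled"))

# fp8 activation + int4 weight, preshuffled packing
# (replaces FbgemmConfig(input_dtype=e4m3_dtype, weight_dtype=torch.int4, ..., preshuffle=True))
# quantize_(model, Float8ActivationInt4WeightConfig(group_size=128, packing_format="preshuffled"))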

11 files changed: +170 −118 lines


test/integration/test_serialization_bc.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@

 _MODEL_NAMES = [
     "torchao-testing/opt-125m-float8dq-row-fbgemm",
+    "torchao-testing/opt-125m-int4wo-preshuffle",
 ]


test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 12 additions & 41 deletions
@@ -15,9 +15,9 @@
     run_tests,
 )

-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -27,44 +27,15 @@
     is_sm_at_least_90,
 )

-if TORCH_VERSION_AT_LEAST_2_8:
-    BF16_ACT_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    BF16_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-else:
-    BF16_ACT_CONFIG = None
-    BF16_ACT_BMM_CONFIG = None
-    FP8_ACT_CONFIG = None
-    FP8_ACT_BMM_CONFIG = None
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    packing_format="preshuffled",
+)
+
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    packing_format="preshuffled",
+)


 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +61,7 @@ def test_linear(self, config):

     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):

test/dtypes/test_fbgemm_int4.py renamed to test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 11 additions & 25 deletions
@@ -13,7 +13,7 @@
 )

 from torchao.quantization import (
-    FbgemmConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,19 +26,11 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-class TestFbgemmInt4Tensor(TestCase):
+class TestInt4Tensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
+        self.config = Int4WeightOnlyConfig(
+            group_size=128,
+            packing_format="plain",
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -68,13 +60,9 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
-        self.assertEqual(
-            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
-        )
+        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(
-            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
-        )
+        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

         # check for sliced weight, before and after float8 quantization
@@ -100,12 +88,10 @@ def test_slice_and_copy_(self):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert (
-            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
-        )
+        assert param.data._data.data_ptr() == param_data._data.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data.packed_weight[0][0].item()
+        orig_value = param.data._data[0][0].item()

         # dummy_l has random input (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
@@ -116,7 +102,7 @@ def test_slice_and_copy_(self):
         param_data.copy_(quantized)

         # making sure param.data is updated
-        assert param.data.packed_weight[0][0] != orig_value
+        assert param.data._data[0][0] != orig_value

     def test_bmm(self):
         class M(torch.nn.Module):
@@ -135,7 +121,7 @@ def forward(self, x):
         original = m(input)
         # we need to transpose the weight first for bmm
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
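
A short sketch of the renamed internals this test now exercises, assuming a CUDA device with the fbgemm int4 kernels installed: the packed int4 payload is exposed as `_data` (previously `packed_weight`), next to `scale` and `zero_point`, and slicing the subclass slices the payload consistently.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

linear = torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128, packing_format="plain"))

w = linear.weight
print(w._data.shape, w.scale.shape, w.zero_point.shape)  # packed payload + quant params

# narrowing the wrapper narrows the packed payload the same way, as asserted above
rows = w.narrow(0, 0, 64)
assert torch.equal(rows._data, w._data.narrow(0, 0, 64))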

torchao/dtypes/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -9,7 +9,6 @@
     to_affine_quantized_intx_static,
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
-from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -64,8 +63,6 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
-    "to_fbgemm_int4",
-    "FbgemmInt4Tensor",
     "to_fbgemm_fp8",
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",

torchao/quantization/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -90,6 +91,7 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -141,6 +143,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
@@ -154,6 +157,7 @@
     "ModuleFqnToConfig",
     "FbgemmConfig",
     # tensor subclasses
+    "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
     # smooth quant - subject to change

torchao/quantization/quant_api.py

Lines changed: 66 additions & 3 deletions
@@ -49,7 +49,6 @@
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -72,6 +71,8 @@
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
+    PackingFormat,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1108,6 +1109,8 @@ class Int4WeightOnlyConfig(AOBaseConfig):
         `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
         `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
+        `packing_format`: the packing format for int4 tensor, this is the new structure we plan to use
+         existing layout based structure will use `_legacy`
     """

     group_size: int = 128
@@ -1116,6 +1119,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    # since not all tensors are migrated to the new structure yet,
+    # we use `_legacy` to represent the previous layout
+    packing_format: PackingFormat = "_legacy"


 # for BC
@@ -1133,15 +1139,33 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    packing_format = config.packing_format

     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight

+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if packing_format == "preshuffled":
+        new_weight = Int4PreshuffledTensor.from_float(
+            weight,
+            block_size,
+            activation_dtype=torch.bfloat16,
+        )
+        return new_weight
+    elif packing_format == "plain":
+        new_weight = Int4Tensor.from_float(
+            weight,
+            block_size,
+        )
+        return new_weight
+
+    assert packing_format == "_legacy", f"Unsupported packing_format: {packing_format}"
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1213,6 +1237,38 @@ def _int4_weight_only_transform(
     return module


+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    packing_format: PackingFormat = "preshuffled"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _float8_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying int8 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    group_size = config.group_size
+    packing_format = config.packing_format
+
+    assert packing_format == "preshuffled", (
+        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    )
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4PreshuffledTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype=torch.float8_e4m3fn,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
@@ -1610,6 +1666,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     activation_scale_ub: float = 1200.0
+    packing_format: PackingFormat = "plain"
     set_inductor_config: bool = True

     def __post_init__(self):
@@ -1631,8 +1688,13 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mm_config = config.mm_config
+    packing_format = config.packing_format
     activation_scale_ub = config.activation_scale_ub

+    assert packing_format == "plain", (
+        f"Only plain packing_format is supported, got {packing_format}"
+    )
+
     # Ensure works on device
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity
@@ -1641,6 +1703,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight
+
     if isinstance(weight_granularity, PerRow):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
@@ -2083,7 +2146,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype=torch.bfloat16,
         )
     else:
-        weight = to_fbgemm_int4(
+        weight = Int4Tensor.from_float(
            module.weight,
            config.block_size,
        )
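
To make the new dispatch concrete, here is a small standalone sketch (illustrative only) of the block-size derivation and packing-format branching that `_int4_weight_only_quantize_tensor` performs above; the real function additionally falls through to the existing `_legacy` layout-based path.

import torch

def derive_block_size(weight: torch.Tensor, group_size: int) -> tuple:
    # groupwise quantization: block of 1 along every leading dim, group_size along the last
    return tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

w2d = torch.randn(1024, 1024, dtype=torch.bfloat16)      # regular linear weight
w3d = torch.randn(8, 1024, 1024, dtype=torch.bfloat16)   # stacked bmm weights
print(derive_block_size(w2d, 128))  # (1, 128)
print(derive_block_size(w3d, 128))  # (1, 1, 128)

# packing_format then picks the tensor subclass:
#   "preshuffled" -> Int4PreshuffledTensor.from_float(weight, block_size, activation_dtype=...)
#   "plain"       -> Int4Tensor.from_float(weight, block_size)
#   "_legacy"     -> previous AffineQuantizedTensor / layout-based path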

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -4,8 +4,16 @@
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
+from .int4.int4_tensor import (
+    Int4Tensor,
+)
+from .packing_format import (
+    PackingFormat,
+)

 __all__ = [
+    "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
+    "PackingFormat",
 ]

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .int4_preshuffled_tensor import Int4PreshuffledTensor
+from .int4_tensor import Int4Tensor
+
+__all__ = [
+    "Int4PreshuffledTensor",
+    "Int4Tensor",
+]

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def from_float(
         if quantize_int4_preshuffle is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

-        assert all(x == 1 for x in block_size[:-1]), (
+        assert all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1, (
             "Only groupwise quant is supported right now"
         )
         group_size = block_size[-1]
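
The tightened assertion above now also rejects a trailing group of 1: every leading entry of block_size must be 1 and the last entry must not be 1 for groupwise quantization. A minimal illustration of the predicate:

def is_groupwise(block_size) -> bool:
    # mirrors the check in Int4PreshuffledTensor.from_float
    return all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1

assert is_groupwise((1, 128))      # per-group along the last dim of a 2D weight
assert is_groupwise((1, 1, 128))   # 3D (bmm-style) weight
assert not is_groupwise((1, 1))    # a last-dim group of 1 is now rejected
assert not is_groupwise((2, 128))  # non-1 leading block is not groupwise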
