
Commit 0afa4c1
Add dynamic quantization support to gemlite layout (#2327)
* fix get_plain() with FMA mode
* update
* fix in_features/out_features meta-data mismatch
* update gemlite slice test
* add packing_bitwidth support
* add packing_bitwidth support and cleanup
* update default gemlite layout
* cleanup
* fix symmetric use-case and relax _same_meta_data
* _copy() meta data
* fix (4,) in autoquant
* Add dynamic mode in gemlite layout
* mode explanation
* use weights_only instead of static

Signed-off-by: mobicham <hicham@mobiuslabs.com>
1 parent 03c850a commit 0afa4c1

File tree

torchao/dtypes/uintx/gemlite_layout.py
torchao/quantization/autoquant.py
torchao/quantization/quant_api.py

3 files changed: +36 −5 lines changed
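
The headline change is a new "dynamic" mode for the gemlite layout: instead of only quantizing weights and running the matmul against fp16 activations ("weight_only"), activations are also quantized to int8 at runtime. A minimal, illustrative sketch of the difference in plain PyTorch (not the gemlite kernels this commit actually dispatches to; the scaling scheme and shapes are assumptions):

import torch

def weight_only_matmul(x, w_int8, w_scale):
    # "weight_only": only the weights are quantized; activations keep their
    # floating-point dtype and the weights are dequantized before the matmul.
    return x @ (w_int8.to(x.dtype) * w_scale.to(x.dtype)).T

def dynamic_matmul(x, w_int8, w_scale):
    # "dynamic": activations are quantized to int8 at runtime with a per-row
    # absmax scale, the matmul accumulates in int32, and the result is
    # rescaled back to the activation dtype.
    x_scale = (x.float().abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    x_int8 = (x.float() / x_scale).round().clamp(-128, 127).to(torch.int8)
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).T
    return (acc.float() * x_scale * w_scale.float().T).to(x.dtype)

# toy per-output-channel int8 weights
w = torch.randn(4, 8)
w_scale = w.abs().amax(dim=-1, keepdim=True) / 127.0
w_int8 = (w / w_scale).round().clamp(-128, 127).to(torch.int8)
x = torch.randn(2, 8)
print(weight_only_matmul(x, w_int8, w_scale))
print(dynamic_matmul(x, w_int8, w_scale))  # close, plus activation rounding error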

torchao/dtypes/uintx/gemlite_layout.py

Lines changed: 20 additions & 2 deletions
@@ -85,6 +85,7 @@ def get_gemlite_aqt_kwargs(
     group_size=64,
     bit_width=4,
     packing_bitwidth=None,
+    mode="weight_only",
     use_hqq=True,
 ):
     if gemlite is None:
@@ -108,6 +109,10 @@ def get_gemlite_aqt_kwargs(
         f"Invalid packing bitwidth, got {packing_bitwidth}"
     )

+    assert mode in ["weight_only", "dynamic"], (
+        f"Invalid mode: should be either weight_only or dynamic, got {mode}"
+    )
+
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size

@@ -116,6 +121,7 @@ def get_gemlite_aqt_kwargs(
         group_size=group_size,
         bit_width=bit_width,
         packing_bitwidth=packing_bitwidth,
+        mode=mode,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs
@@ -126,6 +132,7 @@ class GemlitePackedLayout(Layout):
     group_size: Optional[int] = 128
     bit_width: int = 4
     packing_bitwidth: Optional[int] = None
+    mode: Optional[str] = "weight_only"


 @register_layout(GemlitePackedLayout)
@@ -202,13 +209,24 @@ def from_plain(
         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape
         packing_bitwidth = _layout.packing_bitwidth
+        mode = _layout.mode

         if bit_width == 8 and group_size == in_features:
-            gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
+            processor = (
+                gemlite.helper.A8W8_int8_dynamic
+                if mode == "dynamic"
+                else gemlite.helper.A16W8
+            )
+            gemlite_linear = processor(device=int_data.device).from_weights(
                 int_data, scales=scale, bias=None
             )
         else:
-            gemlite_linear = gemlite.helper.A16Wn(
+            processor = (
+                gemlite.helper.A8Wn_dynamic
+                if mode == "dynamic"
+                else gemlite.helper.A16Wn
+            )
+            gemlite_linear = processor(
                 device=int_data.device, packing_bitwidth=packing_bitwidth
             ).from_weights(
                 int_data, scale, zero_point, bit_width, group_size, bias=None
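
For reference, a sketch of how the new field reaches from_plain() through the public helper, assuming this patched torchao plus gemlite are installed and a CUDA device is available (the function and layout names come from this file, but treat the exact contents of the returned dict and the weight shape as assumptions):

import torch
from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
aqt_kwargs = get_gemlite_aqt_kwargs(
    weight,
    group_size=128,
    bit_width=4,
    packing_bitwidth=None,
    mode="dynamic",  # new in this commit; "weight_only" is the default
    use_hqq=True,
)
# The GemlitePackedLayout inside aqt_kwargs now carries mode="dynamic", so
# from_plain() above picks gemlite.helper.A8Wn_dynamic instead of A16Wn.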

torchao/quantization/autoquant.py

Lines changed: 7 additions & 1 deletion
@@ -742,10 +742,16 @@ def from_float(cls, weight):

         bit_width = 4
         packing_bitwidth = None
+        mode = "weight_only"
         use_hqq = True

         aqt_kwargs = get_gemlite_aqt_kwargs(
-            weight, cls.group_size, bit_width, packing_bitwidth, use_hqq
+            weight,
+            group_size=cls.group_size,
+            bit_width=bit_width,
+            packing_bitwidth=packing_bitwidth,
+            mode=mode,
+            use_hqq=use_hqq,
         )
         weight = to_affine_quantized_intx(weight, **aqt_kwargs)
         input_quant_func = _to_float16
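
The switch from positional to keyword arguments here is not cosmetic: the new mode parameter sits before use_hqq in the signature of get_gemlite_aqt_kwargs, so the old positional call would silently bind the use_hqq value to mode (and in the patched function the assert on mode would then fail). A stripped-down stand-in (not the real function) showing the hazard:

def get_kwargs(weight, group_size=64, bit_width=4, packing_bitwidth=None,
               mode="weight_only", use_hqq=True):
    return {"mode": mode, "use_hqq": use_hqq}

# old positional style: True lands in `mode`, not `use_hqq`
print(get_kwargs("w", 128, 4, None, True))
# -> {'mode': True, 'use_hqq': True}

# keyword style, as in the patch, keeps every argument where it belongs
print(get_kwargs("w", group_size=128, bit_width=4, packing_bitwidth=None,
                 mode="weight_only", use_hqq=True))
# -> {'mode': 'weight_only', 'use_hqq': True}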

torchao/quantization/quant_api.py

Lines changed: 9 additions & 2 deletions
@@ -986,13 +986,14 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
         size is more fine grained
     `bit_width`: bit width of the quantized weight.
     `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware.
-    `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
+    `mode`: if set to "dynamic", activations are quantized at runtime; default is "weight_only" (weight-only quantization).
     `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
     """

     group_size: Optional[int] = 128
     bit_width: int = 4
     packing_bitwidth: Optional[int] = None
+    mode: Optional[str] = "weight_only"
     set_inductor_config: bool = True


@@ -1007,6 +1008,7 @@ def _gemlite_uintx_weight_only_transform(
     group_size = config.group_size
     bit_width = config.bit_width
     packing_bitwidth = config.packing_bitwidth
+    mode = config.mode
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

@@ -1018,7 +1020,12 @@ def _gemlite_uintx_weight_only_transform(
     new_weight = to_affine_quantized_intx(
         weight,
         **get_gemlite_aqt_kwargs(
-            weight, group_size, bit_width, packing_bitwidth, use_hqq
+            weight,
+            group_size=group_size,
+            bit_width=bit_width,
+            packing_bitwidth=packing_bitwidth,
+            mode=mode,
+            use_hqq=use_hqq,
         ),
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
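
Putting it together, a sketch of how a user might enable the new mode end to end through the public config. quantize_ is torchao's standard entry point but does not appear in this diff, and the model shape and CUDA device are assumptions:

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

# gemlite targets GPU kernels, so a CUDA device and fp16 weights are assumed
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# mode="dynamic" flows through _gemlite_uintx_weight_only_transform() into
# GemlitePackedLayout, selecting the A8W8/A8Wn dynamic helpers at pack time
quantize_(
    model,
    GemliteUIntXWeightOnlyConfig(group_size=128, bit_width=4, mode="dynamic"),
)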
