
Commit 0640474

Fix slicing and get_plain() in GemLite (#2288)
* fix get_plain() with FMA mode
* update
* fix in_features/out_feature meta-data mismatch
* update gemlite slice test
* add packing_bitwidth support
* add packing_bitwidth support and cleanup
* update default gemlite layout
* cleanup
* fix symmetric use-case and relax _same_meta_data
* _copy() meta data
* fix (4,) in autoquant
1 parent 72ea2fc commit 0640474
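
For context, here is a minimal sketch of how the new packing_bitwidth knob introduced by this commit is exercised from the user-facing config. It follows the pattern of the updated slice test below; the import paths and the specific 4-bit / group_size=64 / packing_bitwidth=32 values are assumptions for illustration, and a CUDA device plus the gemlite package are required.

    import torch
    import torch.nn as nn

    from torchao.quantization import quantize_
    from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

    # 4-bit weight-only GemLite quantization. packing_bitwidth=None (the new
    # default) leaves the packed width up to GemLite; 8/16/32 pin it explicitly,
    # per the assert added in gemlite_layout.py.
    linear = nn.Linear(256, 512, bias=False, dtype=torch.float16, device="cuda")
    quantize_(
        linear,
        GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=64, packing_bitwidth=32),
    )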

4 files changed (+178, −31 lines)

test/dtypes/test_affine_quantized.py

Lines changed: 71 additions & 2 deletions
@@ -371,12 +371,81 @@ def test_slice_gemlite(self, device, dtype):
         # in_feature not divisible by 1024
         # out_feature not divisible by 8
         # to test slice + padding for int4 weight only quantization
-        dummy = nn.Linear(256, 512, dtype=dtype, device=device)
-        quantize_(dummy, GemliteUIntXWeightOnlyConfig())
+        in_features, out_features, group_size, bit_width = 256, 512, 64, 4
+        orig_shape = [out_features, in_features]
+        dummy = nn.Linear(
+            in_features, out_features, bias=False, dtype=dtype, device=device
+        )
+        quantize_(
+            dummy,
+            GemliteUIntXWeightOnlyConfig(bit_width=bit_width, group_size=group_size),
+        )
+        W_group_mode = dummy.weight.tensor_impl.gemlite_kwargs["meta_args"][10]
+
         # make sure these run without error
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)

+        # Dequant op
+        import gemlite
+
+        def dequant(input_layer, in_features, orig_shape):
+            int_data = input_layer.tensor_impl.packed_weight
+            scale = input_layer.tensor_impl.scale
+            zero_point = input_layer.tensor_impl.zero_point
+
+            W_q = (
+                gemlite.bitpack.unpack_over_rows(
+                    int_data,
+                    W_nbits=bit_width,
+                    num_output_rows=in_features,
+                    dtype=torch.uint8,
+                )
+                .T.contiguous()
+                .view([-1, group_size])
+            )
+
+            s = scale.t().contiguous().view(-1, 1)
+            z = zero_point.t().contiguous().view(-1, 1)
+
+            if W_group_mode == 4:  # FMA
+                W_deq = (W_q * s + z).view(orig_shape)
+            else:
+                W_deq = ((W_q - z) * s).view(orig_shape)
+
+            return W_deq
+
+        W_r = dequant(dummy.weight, dummy.in_features, orig_shape)
+
+        # Slicing in half
+        for slice_axis, start, end in [
+            (0, 0, 256),
+            (0, 256, 256),
+            (1, 0, 128),
+            (1, 128, 128),
+        ]:
+            layer_sliced = dummy.weight.narrow(slice_axis, start, end)
+
+            if slice_axis == 0:
+                num_rows, out_shape = (
+                    dummy.in_features,
+                    (orig_shape[0] // 2, orig_shape[1]),
+                )
+            else:
+                num_rows, out_shape = (
+                    dummy.in_features // 2,
+                    (orig_shape[0], orig_shape[1] // 2),
+                )
+
+            W_slice = dequant(layer_sliced, num_rows, out_shape)
+
+            W_slice_ref = (
+                W_r[start : start + end, :]
+                if slice_axis == 0
+                else W_r[:, start : start + end]
+            )
+            self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)
+
     @common_utils.parametrize("device", ["cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
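
The W_group_mode == 4 branch in the test above covers GemLite's FMA convention, where the stored zero-point is additive (W ≈ W_q * s + z) rather than subtractive (W ≈ (W_q - z) * s). A toy check of the identity that relates the two conventions, using made-up numbers rather than project code:

    import torch

    W_q = torch.tensor([[3.0, 7.0], [1.0, 15.0]])  # unpacked integer codes
    s = torch.tensor([[0.02], [0.05]])             # per-group scales
    z = torch.tensor([[-0.06], [-0.40]])           # FMA-style additive zero-points

    fma = W_q * s + z           # W_group_mode == 4 convention
    std = (W_q - (-z / s)) * s  # subtractive convention with z' = -z / s
    assert torch.allclose(fma, std)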

torchao/dtypes/uintx/gemlite_layout.py

Lines changed: 98 additions & 24 deletions
@@ -25,7 +25,6 @@
 except:
     gemlite = None

-
 aten = torch.ops.aten


@@ -35,7 +34,12 @@ def _same_metadata(
 ) -> bool:
     kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs)
     for k, v in self.gemlite_kwargs.items():
-        if k != "scale_activations":
+        if k in [
+            "in_features",
+            "out_features",
+            "packing_bitwidth",
+            "elements_per_sample",
+        ]:
             kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k])

     return (
@@ -80,6 +84,7 @@ def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
+    packing_bitwidth=None,
     use_hqq=True,
 ):
     if gemlite is None:
@@ -99,6 +104,9 @@ def get_gemlite_aqt_kwargs(
     assert group_size is None or bit_width != 8, (
         "gemlite only works with group_size=None for bit_width=8"
     )
+    assert packing_bitwidth in [8, 16, 32, None], (
+        f"Invalid packing bitwidth, got {packing_bitwidth}"
+    )

     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
@@ -107,15 +115,17 @@ def get_gemlite_aqt_kwargs(
     aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
+        packing_bitwidth=packing_bitwidth,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs


 @dataclass(frozen=True)
 class GemlitePackedLayout(Layout):
-    group_size: Optional[int] = 64
+    group_size: Optional[int] = 128
     bit_width: int = 4
+    packing_bitwidth: Optional[int] = None


 @register_layout(GemlitePackedLayout)
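
With the layout change above, packing_bitwidth becomes a field of GemlitePackedLayout and is threaded into the kwargs built for to_affine_quantized_intx. A sketch of the lower-level entry point (an illustration only: it assumes gemlite is installed, a CUDA float16 weight, and arbitrarily chosen values):

    import torch

    from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

    weight = torch.randn(512, 256, dtype=torch.float16, device="cuda")
    aqt_kwargs = get_gemlite_aqt_kwargs(
        weight, group_size=128, bit_width=4, packing_bitwidth=32, use_hqq=True
    )
    # aqt_kwargs["_layout"] should now be
    # GemlitePackedLayout(group_size=128, bit_width=4, packing_bitwidth=32);
    # packing_bitwidth must be one of 8, 16, 32 or None.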
@@ -191,24 +201,36 @@ def from_plain(

         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape
+        packing_bitwidth = _layout.packing_bitwidth

         if bit_width == 8 and group_size == in_features:
             gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
                 int_data, scales=scale, bias=None
             )
         else:
-            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
+            gemlite_linear = gemlite.helper.A16Wn(
+                device=int_data.device, packing_bitwidth=packing_bitwidth
+            ).from_weights(
                 int_data, scale, zero_point, bit_width, group_size, bias=None
             )

+        meta_args = gemlite_linear.get_meta_args()
         gemlite_kwargs = {
             "in_features": in_features,
             "out_features": out_features,
-            "meta_args": gemlite_linear.get_meta_args(),
+            "packing_bitwidth": packing_bitwidth,
+            "data_contiguous": gemlite_linear.data_contiguous,
+            "elements_per_sample": gemlite_linear.elements_per_sample,
+            "W_group_mode": gemlite_linear.W_group_mode,
+            "meta_args": meta_args,
         }

         packed_weight, scale, zero_point = gemlite_linear.get_tensor_args()
         packed_weight = packed_weight.to(device)
+        if zero_point is None:
+            zero_point = torch.tensor(
+                [[]], device=packed_weight.device, dtype=torch.int32
+            )

         return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout)

@@ -235,18 +257,39 @@ def _apply_fn_to_data(self, fn):
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         device = self.packed_weight.device
         int_data = (
-            gemlite.bitpack.unpack_over_rows(
-                self.packed_weight.cuda(),
-                W_nbits=self._layout.bit_width,
-                num_output_rows=self.gemlite_kwargs["out_features"],
-                dtype=torch.uint8,
+            (
+                gemlite.bitpack.unpack_over_rows(
+                    self.packed_weight.cuda(),
+                    W_nbits=self._layout.bit_width,
+                    num_output_rows=self.gemlite_kwargs["in_features"],
+                    dtype=torch.uint8,
+                )
             )
+            .to(device)
             .t()
-            .contiguous()
-        ).to(device)
+        )
+
+        # Preserve col-row major layout
+        if self.gemlite_kwargs["data_contiguous"]:
+            int_data = int_data.contiguous()
+
+        # Handle FMA mode: W_q * s + z -> (W_q - z) * s
+        if self.gemlite_kwargs["W_group_mode"] == 4:
+            scale_min_val = 1e-8
+            scale = self.scale.clone().float()
+            scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = (
+                scale_min_val
+            )
+            scale[
+                torch.logical_and(scale < 0, scale.abs() <= scale_min_val)
+            ] = -scale_min_val
+            zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100)
+            zero_point = zero_point.to(self.scale.dtype)
+        else:
+            zero_point = self.zero_point

         scale = self.scale.t().contiguous()
-        zero_point = self.zero_point.t().contiguous()
+        zero_point = zero_point.t().contiguous()

         return int_data, scale, zero_point

@@ -274,30 +317,47 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             assert step == 1, "Only step == 1 is supported in slicing right now"

             if dim in [0, 1]:
-                int_data, scale, zero_point = self.get_plain()
-                data_len = int_data.shape[dim]
+                # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T
+                dim = 1 - dim
+                packed_weight = self.packed_weight
+                scale = self.scale
+                zero_point = self.zero_point
+
+                gemlite_kwargs = self.gemlite_kwargs.copy()
+                orig_shape = [
+                    gemlite_kwargs["in_features"],
+                    gemlite_kwargs["out_features"],
+                ]
+                elements_per_sample = gemlite_kwargs["elements_per_sample"]
+                data_len = orig_shape[dim]
                 scale_len = scale.shape[dim]
                 ratio = data_len / scale_len
                 start_scale = int(start / ratio)
                 end_scale = int(end / ratio)

-                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # For packing only the K dimension. This should be flipped for N-dim packing.
+                div = elements_per_sample if dim == 0 else 1
+                packed_weight = aten.slice.Tensor(
+                    packed_weight, dim, start // div, end // div, step
+                )
+
+                # Update in_features/out_features
+                gemlite_kwargs["in_features"] = (
+                    packed_weight.shape[0] * elements_per_sample
+                )
+                gemlite_kwargs["out_features"] = packed_weight.shape[1]
+
                 scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
                 if zero_point is not None and zero_point.numel() > 0:
                     zero_point = aten.slice.Tensor(
                         zero_point, dim, start_scale, end_scale, step
                     )
                 else:
                     zero_point = None
-                # this is to handle padding
-                int_data, scale, zero_point = self._layout.post_process(
-                    int_data, scale, zero_point, self.block_size
-                )
-
-                sliced = self.from_plain(
-                    int_data, scale, zero_point, self._layout
-                )  # Will be transposed again

+                sliced = GemliteAQTTensorImpl(
+                    packed_weight, scale, zero_point, gemlite_kwargs, self._layout
+                )
                 return return_and_correct_aliasing(func, args, kwargs, sliced)

             else:
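
Since the slice path above now slices the packed weight directly instead of round-tripping through get_plain()/from_plain(), indices along the packed K dimension are rescaled by elements_per_sample, i.e. how many quantized values share one packed element. A toy version of that index math, assuming 4-bit codes packed into 32-bit words:

    elements_per_sample = 32 // 4  # 8 four-bit codes per packed 32-bit word
    start, end = 0, 256            # requested slice over unpacked rows

    div = elements_per_sample      # dim == 0 is the packed K dimension
    packed_start, packed_end = start // div, end // div                  # -> 0, 32
    new_in_features = (packed_end - packed_start) * elements_per_sample  # -> 256
    print(packed_start, packed_end, new_in_features)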
@@ -308,10 +368,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.copy_.default:
             self = args[0]
             src = args[1]
+
+            # Handle zero_point = None with symmetric quant
+            if self.zero_point is None:
+                self.zero_point = torch.tensor(
+                    [[]], device=self.packed_weight.device, dtype=torch.int32
+                )
+
+            if src.zero_point is None:
+                src.zero_point = torch.tensor(
+                    [[]], device=src.packed_weight.device, dtype=torch.int32
+                )
+
             if _same_metadata(self, src):
                 self_tensors = self.__tensor_flatten__()[0]
                 for tensor_name in self_tensors:
                     getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                for key in self.gemlite_kwargs:
+                    self.gemlite_kwargs[key] = src.gemlite_kwargs[key]
                 return
             raise ValueError(
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
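
The extra keys that from_plain() now records in gemlite_kwargs (and that the copy_ / _same_metadata changes above keep in sync) can be inspected on any GemLite-quantized weight. A small sketch, reusing the hypothetical `linear` module from the first example:

    impl = linear.weight.tensor_impl
    for key in (
        "in_features",
        "out_features",
        "packing_bitwidth",
        "elements_per_sample",
        "W_group_mode",
        "data_contiguous",
    ):
        print(key, impl.gemlite_kwargs[key])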

torchao/quantization/autoquant.py

Lines changed: 3 additions & 3 deletions
@@ -741,11 +741,11 @@ def from_float(cls, weight):
         weight = weight.to(torch.float16)

         bit_width = 4
-        packing_bitwidth = 32
-        contiguous = None
+        packing_bitwidth = None
         use_hqq = True
+
         aqt_kwargs = get_gemlite_aqt_kwargs(
-            weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+            weight, cls.group_size, bit_width, packing_bitwidth, use_hqq
         )
         weight = to_affine_quantized_intx(weight, **aqt_kwargs)
         input_quant_func = _to_float16

torchao/quantization/quant_api.py

Lines changed: 6 additions & 2 deletions
@@ -990,8 +990,9 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
     """

-    group_size: Optional[int] = 64
+    group_size: Optional[int] = 128
     bit_width: int = 4
+    packing_bitwidth: Optional[int] = None
     set_inductor_config: bool = True


@@ -1005,6 +1006,7 @@ def _gemlite_uintx_weight_only_transform(
 ):
     group_size = config.group_size
     bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

@@ -1015,7 +1017,9 @@ def _gemlite_uintx_weight_only_transform(
     use_hqq = True if bit_width == 4 else False
     new_weight = to_affine_quantized_intx(
         weight,
-        **get_gemlite_aqt_kwargs(weight, group_size, bit_width, use_hqq),
+        **get_gemlite_aqt_kwargs(
+            weight, group_size, bit_width, packing_bitwidth, use_hqq
+        ),
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
