Skip to content

Commit b2b95b7

Browse files
authored
[Quantization] Update group quantization (#336)
* update
* fix conditions
1 parent c27a22b commit b2b95b7

File tree

1 file changed

+35
-24
lines changed
  • src/compressed_tensors/quantization/lifecycle

1 file changed

+35
-24
lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 35 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -227,31 +227,42 @@ def _process_quantization(
227227
perm = torch.argsort(g_idx)
228228
x = safe_permute(x, perm, dim=1)
229229

230-
# TODO: experiment with vectorizing for loop for performance
231-
end = 0
232-
for index, group_count in enumerate(group_sizes):
233-
sc = scale[:, index].view(-1, 1)
234-
zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
235-
236-
start = end
237-
end = start + group_count
238-
if do_quantize:
239-
output[:, start:end] = _quantize(
240-
x=x[:, start:end],
241-
scale=sc,
242-
zero_point=zp,
243-
q_min=q_min,
244-
q_max=q_max,
245-
args=args,
246-
dtype=dtype,
247-
global_scale=global_scale,
248-
)
230+
x = torch.reshape(
231+
x,
232+
(
233+
x.shape[0],
234+
ceil(x.shape[1] / group_size),
235+
group_size,
236+
),
237+
)
249238

250-
if do_dequantize:
251-
input = output[:, start:end] if do_quantize else x[:, start:end]
252-
output[:, start:end] = _dequantize(
253-
x_q=input, scale=sc, zero_point=zp, global_scale=global_scale
254-
)
239+
if do_quantize:
240+
output = _quantize(
241+
x=x,
242+
scale=scale.unsqueeze(-1),
243+
zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
244+
dtype=dtype,
245+
global_scale=global_scale,
246+
q_min=q_min,
247+
q_max=q_max,
248+
args=args,
249+
)
250+
251+
if do_dequantize:
252+
input = output if do_quantize else x
253+
output = _dequantize(
254+
x_q=input,
255+
scale=scale.unsqueeze(-1),
256+
zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
257+
global_scale=global_scale,
258+
)
259+
260+
output = torch.reshape(
261+
output,
262+
(output.shape[0], output.shape[1] * output.shape[2]),
263+
)
264+
265+
output = output.to(output_dtype)
255266

256267
if not is_column_order:
257268
output = safe_permute(output, torch.argsort(perm), dim=1)

0 commit comments

Comments (0)