Commit 7c7520f

Reference representation of dqlinear int4 for xnnpack
Summary:
This diff adds dynamic quantized linear's integer arithmetic representation. This is quite close to how the arithmetic is done in XNNPACK. Basic tests are added against the q/dq variant to make sure things are sane.

Followups:
- See if such a graph is traceable.
- Optimize the implementation if needed.

Test Plan: added
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: deb3efa
Pull Request resolved: #2520
1 parent 58cb352 commit 7c7520f
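
The summary mentions basic tests against the q/dq pattern. As a rough illustration (not the test added in this commit), the two new Python entry points could be compared directly; the shapes below mirror the example inputs further down in the diff, and the import usage and tolerances are assumptions:

import torch
from torchao.quantization.pt2e.reference_representation_rewrite import (
    _qdq_dynamic_quantized_linear_4bit_groupwise,
    _reference_dynamic_quantized_linear_4bit_groupwise,
)

x_fp32 = torch.randn(2, 32)
x_eps = torch.finfo(torch.float32).eps
weight_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)  # 4-bit values stored as int8
weight_scale = torch.rand(8, 4) * 0.1                        # [out_features, num_groups]
weight_zero_point = torch.zeros(8, 4, dtype=torch.int)       # unused (symmetric weights)
bias_fp32 = torch.randn(8)
group_size = 8

out_qdq = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
out_ref = _reference_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
# The integer path rounds weight scales to bf16, so a loose tolerance is used here.
torch.testing.assert_close(out_qdq, out_ref, atol=1e-1, rtol=1e-1)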

File tree

2 files changed: +691 −1 lines changed


torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 253 additions & 1 deletion
@@ -8,7 +8,7 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
@@ -23,12 +23,17 @@
     _replace_literals_with_new_placeholders,
     remove_tensor_overload_for_qdq_ops,
 )
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
 
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext
 
+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
 __all__ = [
     "reference_representation_rewrite",
@@ -203,6 +208,221 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert weight_i4.shape[1] % group_size == 0, (
+        "Weight must be divisible by group_size"
+    )
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group
+        # However we will first scale by weight_scale_group, that is accounting
+        # only for scale of weight, and then scale by x_scale at the end because
+        # x_scale applies to all groups
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # we must also subtract x_zero_point * weight_group_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
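
The per-group accumulation above relies on the identity (X - x_zero_point) * W = X * W - x_zero_point * W, which lets the matmul run on the raw quantized activations and fold the activation zero point into a per-output-channel correction via the weight column sums. A tiny standalone check of that identity (illustrative shapes only, not part of the commit):

import torch

# One token, one group of size 8, 3 output channels (illustrative only).
x_q = torch.randint(-128, 128, (1, 8), dtype=torch.int32)  # quantized activations
w_q = torch.randint(-8, 8, (3, 8), dtype=torch.int32)      # 4-bit weights for one group
zp = torch.tensor([5], dtype=torch.int32)                   # activation zero point

# Direct form: subtract the zero point before the matmul.
direct = (x_q - zp) @ w_q.t()

# Rearranged form used by the reference op: matmul on raw x_q, then subtract
# zp * (column sum of the weights) for each output channel.
rearranged = x_q @ w_q.t() - zp.view(-1, 1) * w_q.sum(dim=-1).view(1, -1)

assert torch.equal(direct, rearranged)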
@@ -739,6 +959,18 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+        torch.randn((2, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +985,26 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
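
Since _reference_dqlinear_int4 is registered through _register_custom_op, the rewritten graph calls it as torch.ops.torchao.reference_dqlinear_int4 with the group size passed in (1, group_size) form, as the wrapper above shows. A small sketch of invoking the op directly, assuming that importing the module performs the registration (shapes match the example inputs):

import torch
# Importing the module is assumed to register torch.ops.torchao.reference_dqlinear_int4.
import torchao.quantization.pt2e.reference_representation_rewrite  # noqa: F401

x_fp32 = torch.randn(2, 32)
weight_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)
weight_scale = torch.rand(8, 4) * 0.1
weight_zero_point = torch.zeros(8, 4, dtype=torch.int)
bias_fp32 = torch.randn(8)

out = torch.ops.torchao.reference_dqlinear_int4(
    x_fp32,
    torch.finfo(torch.float32).eps,
    weight_i4,
    weight_scale,
    weight_zero_point,
    bias_fp32,
    (1, 8),  # group_size passed in block-size form, matching the wrapper
)
print(out.shape)  # expected: torch.Size([2, 8])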
