Commit 5dbe548

Reference representation of dqlinear int4 for xnnpack
Summary:
This diff adds an integer-arithmetic reference representation of dynamically quantized linear (dqlinear int4, groupwise). The arithmetic is quite close to how it is done in xnnpack. Basic tests are added against the q/dq decomposition to make sure things are sane.

Followups:
- See if such a graph is traceable.
- Optimize the implementation if needed.

Test Plan: added

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 756a9e9
Pull Request resolved: #2520
1 parent 5670711 commit 5dbe548

File tree

2 files changed: +694 −1 lines changed

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 256 additions & 1 deletion
@@ -8,7 +8,7 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional

 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
@@ -24,14 +24,22 @@
     remove_tensor_overload_for_qdq_ops,
 )

+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
+
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext

+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)

 __all__ = [
     "reference_representation_rewrite",
+    "_qdq_dynamic_quantized_linear_4bit_groupwise",
+    "_reference_dynamic_quantized_linear_4bit_groupwise",
 ]

@@ -203,6 +211,221 @@ def _reference_dynamic_quantized_linear(
     return out_fp32


+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric quantization)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group
+        # However we will first scale by weight_scale_group, that is accounting
+        # only for scale of weight, and then scale by x_scale at the end because
+        # x_scale applies to all groups
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # we must also subtract x_zero_point * weight_group_col_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
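
The per-group accumulation in _reference_dqlinear_int4 relies on the identity (X - x_zero_point) @ W^T = X @ W^T - x_zero_point * colsum(W), which lets the integer matmul run on the raw quantized activations and fold the zero-point correction in afterwards. A minimal sketch (not part of this diff) that checks the identity on small integer tensors:

# Not part of the diff: illustrative check of the zero-point correction identity
# used inside the group loop above.
import torch

torch.manual_seed(0)
x_i8 = torch.randint(-128, 128, (2, 8), dtype=torch.int32)  # quantized activations, one group
w_i4 = torch.randint(-8, 8, (4, 8), dtype=torch.int32)      # one weight group [out_features, group_size]
x_zp = torch.randint(0, 10, (2, 1), dtype=torch.int32)      # per-token zero points

lhs = (x_i8 - x_zp) @ w_i4.t()                               # subtract zero point first
rhs = x_i8 @ w_i4.t() - x_zp * w_i4.sum(dim=-1).view(1, -1)  # integer matmul plus column-sum correction
assert torch.equal(lhs, rhs)
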
@@ -739,6 +962,18 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )

+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+        torch.randn((2, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +988,26 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),