Commit d3b06ea

Reference representation of dqlinear int4 for xnnpack
Summary:
This diff adds dynamic quantized linear's integer arithmetic representation. This is quite close to how the arithmetic is done in xnnpack. Basic tests are added against the q/dq representation to make sure things are sane.

Followups:
- See if such a graph is traceable.
- Optimize the implementation if needed.

Test Plan: added

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 0f79f1c
Pull Request resolved: #2520
1 parent 58cb352 commit d3b06ea
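
The summary mentions basic tests against the q/dq representation; a minimal standalone sanity check along those lines might look as follows. This is an illustrative sketch, not the commit's test: the import path follows the file location in this diff, the shapes mirror _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS below, and the positive scales plus loose tolerance are assumptions (the reference path rounds weight scales to bf16, so the two outputs are close but not bit-identical).

import torch
from torchao.quantization.pt2e.reference_representation_rewrite import (
    _qdq_dynamic_quantized_linear_4bit_groupwise,
    _reference_dynamic_quantized_linear_4bit_groupwise,
)

# Shapes follow the example inputs registered below; positive scales are an
# assumption so the numeric comparison is well behaved.
x_fp32 = torch.randn(2, 32)
x_eps = torch.finfo(torch.float32).eps
weight_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)
weight_scale = torch.rand(8, 4) * 0.1 + 0.01  # [out_features, num_groups]
weight_zero_point = torch.zeros(8, 4, dtype=torch.int)
bias_fp32 = torch.randn(8)
group_size = 8

expected = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
actual = _reference_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
# Loose tolerance (assumed): the reference path casts weight scales to bf16.
assert torch.allclose(expected, actual, atol=1e-2, rtol=1e-2)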

2 files changed (+693, -1 lines)

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 255 additions & 1 deletion
@@ -8,7 +8,7 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
@@ -23,15 +23,22 @@
     _replace_literals_with_new_placeholders,
     remove_tensor_overload_for_qdq_ops,
 )
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
 
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext
 
+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
 __all__ = [
     "reference_representation_rewrite",
+    "_qdq_dynamic_quantized_linear_4bit_groupwise",
+    "_reference_dynamic_quantized_linear_4bit_groupwise",
 ]
 
 
@@ -203,6 +210,221 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # Scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # The output has to be scaled by x_scale * weight_scale_group.
+        # However, we first scale by weight_scale_group, accounting only for the
+        # weight's scale, and then scale by x_scale at the end, because x_scale
+        # applies to all groups.
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # We must also subtract x_zero_point * weight_group_col_sum,
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -739,6 +961,18 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+        torch.randn((2, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +987,26 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
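
The per-group loop in _reference_dqlinear_int4 leans on two pieces of algebra: the activation zero point can be folded in after the exact int32 matmul via the group's weight column sums, since (X - x_zero_point) @ W.T = X @ W.T - x_zero_point * colsum(W), and applying each group's weight scale to its int32 partial product and summing over groups equals a single matmul against the groupwise-dequantized weight. A small standalone check of both facts, with shapes chosen to match the example inputs (the shapes and the float64 comparison are assumptions for illustration, not part of the commit):

import torch

batch, out_features, in_features, group_size = 2, 8, 32, 8
num_groups = in_features // group_size  # 4, matching weight_scale [8, 4]

x_q = torch.randint(-128, 128, (batch, in_features), dtype=torch.int32)
w_q = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int32)
x_zp = torch.randint(-128, 128, (batch, 1), dtype=torch.int32)
w_scale = torch.rand(out_features, num_groups) + 0.1

# Zero-point folding: (X - zp) @ W.T == X @ W.T - zp * colsum(W), exact in int32.
lhs = (x_q - x_zp) @ w_q.T
rhs = x_q @ w_q.T - x_zp * w_q.sum(dim=1).view(1, -1)
assert torch.equal(lhs, rhs)

# Groupwise accumulation: scale each group's int32 partial product, sum over
# groups, and compare against one matmul with the dequantized weight.
# float64 is used here only to make the comparison numerically tight.
acc = torch.zeros(batch, out_features, dtype=torch.float64)
for g in range(num_groups):
    cols = slice(g * group_size, (g + 1) * group_size)
    part = (x_q[:, cols] - x_zp) @ w_q[:, cols].T  # exact int32 per-group matmul
    acc += part.double() * w_scale[:, g].double().view(1, -1)
w_deq = w_q.double() * w_scale.double().repeat_interleave(group_size, dim=1)
assert torch.allclose(acc, (x_q - x_zp).double() @ w_deq.T)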

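For context on how the new _RewriteInfo entry gets exercised: reference_representation_rewrite takes the GraphModule produced by the pt2e convert step and replaces each matched q/dq pattern with its reference counterpart, now including the 4-bit groupwise dynamic linear. A minimal, hypothetical call site (the surrounding export/quantization flow is assumed and not shown in this diff):

from torch.fx import GraphModule
from torchao.quantization.pt2e.reference_representation_rewrite import (
    reference_representation_rewrite,
)

def lower_to_reference_representation(converted_model: GraphModule) -> GraphModule:
    # `converted_model` is assumed to be the graph produced by the pt2e convert
    # step, with q/dq ops still present; the rewrite swaps the matched patterns
    # for their integer-arithmetic reference representation.
    return reference_representation_rewrite(converted_model)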