Commit a4b1ce1
Reference representation of dqlinear int4 for xnnpack
Summary: This diff adds an integer-arithmetic reference representation of dynamically quantized linear (int4 weights, groupwise). The arithmetic closely mirrors how it is done in XNNPACK. Basic tests against the q/dq decomposition are added to make sure the numerics are sane.

Followups:
- See if such a graph is traceable.
- Optimize the implementation if needed.

Test Plan: added

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 5108e2c
Pull Request resolved: #2520
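The test plan compares the new integer reference against the q/dq decomposition. A minimal sanity-check sketch of that comparison, reusing the helper names and example shapes from this diff (illustrative only, not the actual test file; assumes the torchao custom ops register on import):

import torch
from torchao.quantization.pt2e.reference_representation_rewrite import (
    _qdq_dynamic_quantized_linear_4bit_groupwise,
    _reference_dynamic_quantized_linear_4bit_groupwise,
)

# Illustrative shapes: 8 output features, 32 input features, group size 8.
x_fp32 = torch.randn(1, 32)
x_eps = torch.finfo(torch.float32).eps
weight_i4 = torch.randint(-8, 7, (8, 32), dtype=torch.int8)  # 4-bit values stored as int8
weight_scale = torch.rand(8, 4) + 0.01  # [out_features, num_groups]
weight_zero_point = torch.zeros(8, 4, dtype=torch.int)  # unused, weights treated as symmetric
bias_fp32 = torch.randn(8)
group_size = 8

# Fp32 path built from q/dq ops vs. the integer reference emulating on-device compute.
out_qdq = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
out_ref = _reference_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)

# The two should agree closely but not bit-exactly: the reference rounds weight
# scales through bf16 and accumulates per group in int32.
print((out_qdq - out_ref).abs().max())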
1 parent 58cb352 commit a4b1ce1

File tree: 2 files changed, +762 −2 lines changed

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 324 additions & 2 deletions
@@ -8,12 +8,13 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torch.fx import GraphModule
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 
 from torchao.quantization.pt2e.export_utils import WrapperModule
@@ -23,12 +24,17 @@
     _replace_literals_with_new_placeholders,
     remove_tensor_overload_for_qdq_ops,
 )
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
 
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext
 
+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
 __all__ = [
     "reference_representation_rewrite",
@@ -203,6 +209,252 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    x_orig_shape = x_i8.shape
+    k_dim = x_i8.shape[-1]
+    x_i8 = x_i8.view(-1, k_dim)
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+    out_shape = list(x_orig_shape)
+    out_shape[-1] = out_features
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group
+        # However we will first scale by weight_scale_group, that is accounting
+        # only for scale of weight, and then scale by x_scale at the end because
+        # x_scale applies to all groups
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # we must also subtract x_zero_point * weight_group_col_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32.view(out_shape)
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
+def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
+    match,
+    original_graph,
+    pattern_graph,
+) -> bool:
+    weight_is_int4 = False
+    act_quant_is_int8 = False
+    for node in match.nodes_map.values():
+        if (
+            isinstance(node, torch.fx.Node)
+            and node.op == "call_function"
+            and node.target == torch.ops.torchao.dequantize_affine.default
+        ):
+            args = node.args
+            if len(args) >= 7:
+                weight_is_int4 = args[5] == -8 and args[6] == 7
+        if (
+            isinstance(node, torch.fx.Node)
+            and node.op == "call_function"
+            and node.target == torch.ops.torchao.quantize_affine.default
+        ):
+            args = node.args
+            if len(args) >= 5:
+                act_quant_is_int8 = args[4] == torch.int8
+    return weight_is_int4 and act_quant_is_int8
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -627,6 +879,9 @@ class _RewriteInfo:
     # post transformation on the exported pattern and replacement GraphModule
     pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+    filter_fn: Optional[
+        list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]]
+    ] = None
     ignore_literals: bool = False
 
@@ -739,6 +994,31 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1 = (
+        torch.randn((1, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
+    # Just saw that we can match against > 2 dim input. Hacky.
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 = (
+        torch.randn((1, 1, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +1033,48 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
+            ignore_literals=True,
+        ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
@@ -835,7 +1157,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
             model,
             pattern,
             replacement,
-            match_filters=None,
+            match_filters=rewrite_info.filter_fn,
             ignore_literals=rewrite_info.ignore_literals,
        )  # type: ignore[arg-type]
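To make the per-group accumulation performed by the reference op concrete, here is a small standalone sketch of what one loop iteration computes, using plain tensor ops (names and shapes are illustrative, not part of the diff):

import torch

# One group: int32 activations [batch, group_size] and int32 4-bit weights
# [out_features, group_size], plus the per-token activation zero point and
# the per-group weight scale.
batch, out_features, group_size = 2, 8, 8
x_group = torch.randint(-128, 128, (batch, group_size), dtype=torch.int32)
w_group = torch.randint(-8, 8, (out_features, group_size), dtype=torch.int32)
x_zero_point = torch.randint(-10, 10, (batch, 1), dtype=torch.int32)
w_scale_group = torch.rand(out_features) + 0.01

# Integer multiply-accumulate in int32, then scale by the per-group weight scale.
acc_i32 = (x_group.unsqueeze(1) * w_group.unsqueeze(0)).sum(dim=-1, dtype=torch.int32)
acc_fp32 = acc_i32.to(torch.float32) * w_scale_group.view(1, -1)

# Zero-point correction: (X - zp) @ W^T == X @ W^T - zp * colsum(W),
# so subtract zp * (column sum of the weight group), scaled the same way.
correction = (
    w_group.sum(dim=-1, dtype=torch.int32).to(torch.float32).view(1, -1)
    * x_zero_point.to(torch.float32)
    * w_scale_group.view(1, -1)
)
acc_fp32 = acc_fp32 - correction
# The full op repeats this for every group, then multiplies by the activation
# scale and adds the bias.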