Commit 97c2ea9

[not for land] towards QAT with exact forward pass
Summary: Exploration for QAT with exact (not emulated) forward pass; WIP and not ready for review.
1 parent c1e84cc commit 97c2ea9
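
For context, the contrast between the emulated and the exact forward pass can be shown with a minimal sketch in plain PyTorch. This sketch is not part of the commit and simplifies the scheme to symmetric per-tensor int8 for both operands; the commit itself uses per-token asymmetric int8 activations and per-group symmetric int4 weights, but the emulated-vs-exact distinction is the same.

import torch

# minimal sketch with symmetric per-tensor int8 quantization for both operands;
# the commit uses per-token asymmetric int8 activations and per-group symmetric
# int4 weights, but the emulated-vs-exact contrast is the same.
x = torch.randn(4, 8)
w = torch.randn(16, 8)

x_scale = x.abs().max() / 127
w_scale = w.abs().max() / 127
x_q = torch.clamp(torch.round(x / x_scale), -127, 127)
w_q = torch.clamp(torch.round(w / w_scale), -127, 127)

# emulated (fake-quant) forward pass: dequantize, then matmul in float
y_emulated = (x_q * x_scale) @ (w_q * w_scale).t()

# exact forward pass: integer matmul, then rescale the accumulator
# (int64 here so CPU matmul supports it; a real kernel accumulates in int32)
y_exact = (x_q.to(torch.int64) @ w_q.to(torch.int64).t()).float() * x_scale * w_scale

# the two agree up to float rounding / accumulation order
print("max abs diff:", (y_emulated - y_exact).abs().max().item())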

3 files changed: +624 −0 lines changed
torchao/prototype/qat_exact/main.py

Lines changed: 193 additions & 0 deletions

"""
Prototype of QAT with exact (instead of emulated) forward pass using
integer matrix multiply.

Quant spec:
* int4 symmetric weights w/ group size 32 or 256,
* int8 asymmetric per-token dynamic activations

"""

import copy

import fire
import torch
import torch.nn as nn

from torchao.float8.float8_utils import compute_error
from torchao.prototype.qat_exact.reference_gemm import (
    cpu_x_token_assym_fp8_w_group_sym_int4_gemm,
    naive_x_token_assym_fp8_w_group_sym_int4_gemm,
)
from torchao.prototype.qat_exact.triton_gemm import int8_matmul_triton
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
)
from torchao.quantization.utils import (
    _get_per_token_block_size,
)

torch.manual_seed(0)


def quantize_x(x_fp32):
    # dynamic quantization of the activation
    x_mapping_type = MappingType.ASYMMETRIC
    per_token_block_size = _get_per_token_block_size(x_fp32)
    x_quant_min, x_quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
    x_eps = torch.finfo(torch.float32).eps
    x_scales_type = torch.float32
    x_zero_points_type = torch.int32
    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
        x_fp32,
        x_mapping_type.name,
        per_token_block_size,
        torch.int8,
        x_quant_min,
        x_quant_max,
        x_eps,
        x_scales_type,
        x_zero_points_type,
    )
    x_i8 = torch.ops.torchao.quantize_affine(
        x_fp32,
        per_token_block_size,
        x_scale,
        x_zero_point,
        torch.int8,
        x_quant_min,
        x_quant_max,
    )
    return x_i8, x_scale, x_zero_point


class Int8PerTokenActivationInt4PerGroupWeightLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # manually create fake quantizer configs
        activation_config = FakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False
        )
        weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

        # manually create fake quantizers
        # reference: `FakeQuantizedLinear` (https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/qat/linear.py)
        self.activation_fq = FakeQuantizer(activation_config)
        self.weight_fq = FakeQuantizer(weight_config)

    def forward(self, input):
        # quantize x
        input_i8, input_scale, input_zp = quantize_x(input)

        # quantize w
        _ = self.weight_fq(self.weight)
        w_qmin, w_qmax = _DTYPE_TO_QVALUE_BOUNDS[torch.int4]
        w_granularity = self.weight_fq.config.granularity
        w_group_size = w_granularity.group_size
        w_block_size = (1, w_group_size)
        weight_int4 = torch.ops.torchao.quantize_affine(
            self.weight,
            w_block_size,
            self.weight_fq.scale,
            self.weight_fq.zero_point,
            torch.int8,
            w_qmin,
            w_qmax,
        )

        # original reference
        q_output = naive_x_token_assym_fp8_w_group_sym_int4_gemm(
            input_i8.to(torch.int32),
            input_scale,
            input_zp,
            weight_int4.to(torch.int32),
            self.weight_fq.scale,
            w_group_size,
        )

        # now also check Kimish's implementation
        q_output2 = cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
            input_i8.cpu(),
            input_scale.cpu(),
            input_zp.cpu(),
            weight_int4.cpu(),
            self.weight_fq.scale.cpu(),
            self.weight_fq.zero_point.cpu(),
            self.bias,
            self.weight_fq.config.granularity.group_size,
        ).cuda()
        sqnr = compute_error(q_output, q_output2)
        print("sqnr vasiliy_reference vs kimish_reference", sqnr)

        # finally, check vs triton gemm
        q_output3 = int8_matmul_triton(
            input_i8,
            weight_int4.t(),
            input_scale,
            input_zp,
            self.weight_fq.scale.t(),
            w_group_size,
        )

        sqnr3 = compute_error(q_output, q_output3)
        print("sqnr vasiliy_reference vs triton", sqnr3)

        sqnr4 = compute_error(q_output2, q_output3)
        print("sqnr kimish_reference vs triton", sqnr4)

        return q_output, q_output2

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new_mod = cls(mod.in_features, mod.out_features)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


def run():
    M, K, N = 32, 64, 128

    # TODO(before land): also implement bias=True
    m_hp = nn.Sequential(nn.Linear(K, N, bias=False)).cuda()
    mq_ref = copy.deepcopy(m_hp)
    mq = copy.deepcopy(m_hp)

    # create a baseline: QAT with fake quants. Our exact QAT's output should
    # be close to this
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    quantize_(
        mq_ref,
        IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
    )

    # create the experiment: forward pass with an integer gemm
    mq[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(mq[0])

    x_hp = torch.randn(M, K, device="cuda")
    xq_ref = copy.deepcopy(x_hp)
    xq = copy.deepcopy(x_hp)

    with torch.no_grad():
        y_hp = m_hp(x_hp)
        yq_ref = mq_ref(xq_ref)
        yq, yq2 = mq(xq)

    sqnr_hp_qref = compute_error(y_hp, yq_ref)
    sqnr_hp_q = compute_error(y_hp, yq)
    sqnr_qref_q = compute_error(yq_ref, yq)
    print("sqnr_hp_qref", sqnr_hp_qref)
    print("sqnr_hp_q", sqnr_hp_q)
    print("sqnr_qref_q", sqnr_qref_q)


if __name__ == "__main__":
    fire.Fire(run)
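
For reference, the gemm paths compared above are all meant to compute the same quantity. With per-token activation scale s_x[m] and zero point z_x[m], per-group weight scale s_w[n, g], and groups of size group_size along K:

    out[m, n] = s_x[m] * sum_g( s_w[n, g] * sum_{k in group g} (x_q[m, k] - z_x[m]) * w_q[n, k] )

The sketch below is a vectorized eager-mode restatement of that formula. It is illustrative only: the helper name and shape assumptions are not from the commit, and it omits the bfloat16 round-trip both reference gemms apply to the weight scales.

import torch

def grouped_int_gemm_sketch(x_q, x_scale, x_zp, w_q, w_scale, group_size):
    # Hypothetical helper, not part of the commit. Shapes assumed:
    #   x_q: [M, K] int8 values; x_scale / x_zp: per-token, broadcastable to [M, 1]
    #   w_q: [N, K] int4 values in an int8/int32 container
    #   w_scale: [N, K // group_size]
    M, K = x_q.shape
    N = w_q.shape[0]
    G = K // group_size

    # subtract the per-token zero point first: (x_q - z_x) * w_q
    x_centered = x_q.to(torch.int32) - x_zp.view(M, 1).to(torch.int32)

    # per-group partial dot products, accumulated in fp32 like the naive
    # reference (a real kernel would accumulate in int32)
    partial = torch.einsum(
        "mgk,ngk->mng",
        x_centered.view(M, G, group_size).float(),
        w_q.to(torch.float32).view(N, G, group_size),
    )

    # apply per-group weight scales, sum over groups, then the per-token scale
    out = (partial * w_scale.view(1, N, G).float()).sum(dim=-1)
    return out * x_scale.view(M, 1).float()

Under these shape assumptions the result should closely track naive_x_token_assym_fp8_w_group_sym_int4_gemm below, differing only by the bfloat16 rounding of the weight scales and float accumulation order.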
torchao/prototype/qat_exact/reference_gemm.py

Lines changed: 132 additions & 0 deletions

import torch
from torch._higher_order_ops.out_dtype import out_dtype


def cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
    x_i8,
    x_scale,
    x_zero_point,
    weight_int4,
    weight_scale,
    weight_zero_point,
    bias_fp32,
    group_size,
):
    # For groupwise quantization, we need to handle the computation differently
    # weight_i4 shape: [out_features, in_features]
    # weight_scale shape: [out_features, in_features // group_size]
    # weight_zero_point shape: [out_features, in_features // group_size]
    out_features, in_features = weight_int4.shape
    num_groups = in_features // group_size

    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
    # Reshape for group-wise processing
    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
    batch_size = x_i8.shape[0]
    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)

    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
    weight_i4_grouped = weight_int4.view(out_features, num_groups, group_size)

    # Convert to int32 for computation
    x_i32_grouped = x_i8_grouped.to(torch.int32)
    weight_i32_grouped = weight_i4_grouped.to(torch.int32)

    # Perform groupwise integer linear operation
    acc_fp32 = torch.zeros(
        batch_size, out_features, dtype=torch.float32, device=x_i8.device
    )

    for group_idx in range(num_groups):
        # Extract current group
        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]

        # Get scale for this group
        weight_scale_group = weight_scale[:, group_idx]  # [out_features]

        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
        group_acc = out_dtype(
            torch.ops.aten.linear.default,
            torch.int32,
            x_group,
            weight_group,
            None,
        )

        # Output has to be scaled by x_scale * weight_scale_group.
        # However, we will first scale by weight_scale_group, accounting
        # only for the weight scale, and then scale by x_scale at the end,
        # because x_scale applies to all groups.
        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
            1, -1
        )

        # we must also subtract x_zero_point * weight_group_col_sum
        # since (X - x_zero_point) * W = X * W - x_zero_point * W
        weights_col_sum_adjusted = (
            weight_group_col_sum.to(torch.float32).view(1, -1)
            * x_zero_point.view(-1, 1)
            * weight_scale_group.view(1, -1)
        )
        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
    x_scale_multiplier = x_scale.view(-1, 1)
    out_fp32 = acc_fp32 * x_scale_multiplier
    if bias_fp32 is not None:
        out_fp32 = out_fp32 + bias_fp32

    return out_fp32


def naive_x_token_assym_fp8_w_group_sym_int4_gemm(
    act_q,
    act_scale,
    act_zp,
    w_q,
    w_scale,
    w_group_size,
) -> torch.Tensor:
    #
    # now we have the scales/zero_points/quant values for both gemm operands
    # below is a manual slow gemm with integer operands and float rescaling,
    # implemented using eager PyTorch ops. This should be slow but closely
    # (though not exactly) match a real int8,int8->int32 gemm with rescaling,
    # with the only difference being that the sum inside of the dot product
    # is done in float32 right now.
    #
    q_output = torch.zeros(
        act_q.shape[0],
        w_q.shape[0],
        dtype=torch.float32,
        device=act_q.device,
    )
    for m_idx in range(act_q.shape[0]):
        for n_idx in range(w_q.shape[0]):
            for g_idx in range(w_q.shape[1] // w_group_size):
                k_start = g_idx * w_group_size
                k_end = k_start + w_group_size
                act_chunk = act_q[m_idx][k_start:k_end]
                w_chunk = w_q[n_idx][k_start:k_end]

                # (act_q - act_zp) * w_q
                # result still in int32
                elem_int32 = (act_chunk - act_zp[m_idx]) * w_chunk

                # sum((act_q - act_zp) * w_q)
                # this is in float32, so likely a small deviation from the real
                # kernel, where the entire dot product would be in int32
                sum_float32 = torch.sum(elem_int32)

                # scale
                act_scale_tmp = act_scale[m_idx].squeeze(-1)
                w_scale_tmp = w_scale[n_idx][g_idx].squeeze(-1).bfloat16().float()
                sum_scaled = sum_float32 * act_scale_tmp * w_scale_tmp

                # accumulate
                q_output[m_idx][n_idx] += sum_scaled

    return q_output
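
The zero-point handling in cpu_x_token_assym_fp8_w_group_sym_int4_gemm relies on distributing the per-token zero point out of the integer matmul: for each group, (x_q - z_x) @ w_q^T == x_q @ w_q^T - z_x * colsum(w_q). A standalone sketch of that identity follows; it is illustrative only, plain PyTorch, with names not taken from the commit.

import torch

torch.manual_seed(0)
M, N, group_size = 4, 8, 32

x_q = torch.randint(-128, 128, (M, group_size))   # int8-range activations
w_q = torch.randint(-8, 8, (N, group_size))       # int4-range weights
x_zp = torch.randint(-10, 10, (M, 1))             # per-token zero points

# direct form: subtract the zero point before the integer matmul
direct = (x_q - x_zp) @ w_q.t()

# decomposed form used above: x_q @ w_q^T minus x_zp times the
# per-output-channel column sum of w_q
decomposed = x_q @ w_q.t() - x_zp * w_q.sum(dim=-1).view(1, -1)

assert torch.equal(direct, decomposed)  # exact in integer arithmetic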
