
Commit 44021ad

[not for land] towards QAT with exact forward pass
Summary: Exploration for QAT with exact (not emulated) forward pass; WIP and not ready for review.
1 parent c1e84cc commit 44021ad

File tree

2 files changed: +619, -0 lines changed

torchao/prototype/qat_exact/main.py

Lines changed: 320 additions & 0 deletions
@@ -0,0 +1,320 @@
"""
Prototype of QAT with exact (instead of emulated) forward pass using
integer matrix multiply.

Quant spec:
* int4 symmetric weights w/ group size 32 or 256,
* int8 asymmetric per-token dynamic activations

"""

import copy

import fire
import torch
import torch.nn as nn
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.float8.float8_utils import compute_error
from torchao.prototype.qat_exact.triton_gemm import int8_matmul_triton
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
)
from torchao.quantization.utils import (
    _get_per_token_block_size,
)

torch.manual_seed(0)

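# Illustrative sketch, not used by the prototype below: the difference between
# an "emulated" quantized forward pass, which dequantizes both operands and
# runs a float gemm, and an "exact" one, which runs an integer gemm and
# rescales the accumulator afterwards. This helper uses a toy symmetric
# per-tensor int8 scheme rather than the int8/int4 groupwise scheme in this
# file, so treat it as an assumption-level example only.
def _toy_emulated_vs_exact_example():
    x = torch.randn(4, 8)
    w = torch.randn(16, 8)
    x_scale = x.abs().max() / 127
    w_scale = w.abs().max() / 127
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127)
    w_q = torch.clamp(torch.round(w / w_scale), -127, 127)
    # emulated: dequantize both operands, then do a float gemm
    y_emulated = (x_q * x_scale) @ (w_q * w_scale).t()
    # exact: integer gemm first, then rescale the integer accumulator
    # (int64 here to keep the toy example simple; the prototype below targets
    # an int32-accumulating gemm)
    y_exact = (x_q.to(torch.int64) @ w_q.to(torch.int64).t()).float() * (
        x_scale * w_scale
    )
    return y_emulated, y_exact
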
def quantize_x(x_fp32):
    # Dynamic quantization of activation
    x_mapping_type = MappingType.ASYMMETRIC
    per_token_block_size = _get_per_token_block_size(x_fp32)
    x_quant_min, x_quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
    x_eps = torch.finfo(torch.float32).eps
    x_scales_type = torch.float32
    x_zero_points_type = torch.int32
    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
        x_fp32,
        x_mapping_type.name,
        per_token_block_size,
        torch.int8,
        x_quant_min,
        x_quant_max,
        x_eps,
        x_scales_type,
        x_zero_points_type,
    )
    x_i8 = torch.ops.torchao.quantize_affine(
        x_fp32,
        per_token_block_size,
        x_scale,
        x_zero_point,
        torch.int8,
        x_quant_min,
        x_quant_max,
    )
    return x_i8, x_scale, x_zero_point

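# Illustrative sketch, not used by the prototype: quantize_x can be sanity
# checked by dequantizing its output by hand. For a 2D input, per-token
# quantization produces one (scale, zero_point) pair per row, so both
# broadcast against the last dimension. This helper is an assumption about
# how one might verify the round trip, not part of the original flow.
def _check_quantize_x_roundtrip(x_fp32):
    x_i8, x_scale, x_zero_point = quantize_x(x_fp32)
    x_dq = (x_i8.to(torch.float32) - x_zero_point.to(torch.float32)) * x_scale
    # the worst-case error should be on the order of half a quantization step
    return (x_fp32 - x_dq).abs().max()
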
def _xnnpack_integer_arithmetic_dqlinear_4bit_groupwise(
    x_i8,
    x_scale,
    x_zero_point,
    weight_int4,
    weight_scale,
    weight_zero_point,
    bias_fp32,
    group_size,
):
    # For groupwise quantization, we need to handle the computation differently
    # weight_int4 shape: [out_features, in_features]
    # weight_scale shape: [out_features, in_features // group_size]
    # weight_zero_point shape: [out_features, in_features // group_size]
    out_features, in_features = weight_int4.shape
    num_groups = in_features // group_size

    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

    assert x_i8.dim() == 2, "x_i8 must be a 2D tensor"
    # Reshape for group-wise processing
    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
    batch_size = x_i8.shape[0]
    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)

    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
    weight_i4_grouped = weight_int4.view(out_features, num_groups, group_size)

    # Convert to int32 for computation
    x_i32_grouped = x_i8_grouped.to(torch.int32)
    weight_i32_grouped = weight_i4_grouped.to(torch.int32)

    # Perform the groupwise integer linear operation
    acc_fp32 = torch.zeros(
        batch_size, out_features, dtype=torch.float32, device=x_i8.device
    )

    for group_idx in range(num_groups):
        # Extract the current group
        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]

        # Get the scale for this group
        weight_scale_group = weight_scale[:, group_idx]  # [out_features]

        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
        group_acc = out_dtype(
            torch.ops.aten.linear.default,
            torch.int32,
            x_group,
            weight_group,
            None,
        )

        # The output has to be scaled by x_scale * weight_scale_group.
        # However, we first scale by weight_scale_group only (accounting for
        # the weight scale), and apply x_scale once at the end, because
        # x_scale applies to all groups.
        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
            1, -1
        )

        # We must also subtract x_zero_point * weight_group_col_sum,
        # since (X - x_zero_point) * W = X * W - x_zero_point * W
        weights_col_sum_adjusted = (
            weight_group_col_sum.to(torch.float32).view(1, -1)
            * x_zero_point.view(-1, 1)
            * weight_scale_group.view(1, -1)
        )
        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
    x_scale_multiplier = x_scale.view(-1, 1)
    out_fp32 = acc_fp32 * x_scale_multiplier
    if bias_fp32 is not None:
        out_fp32 = out_fp32 + bias_fp32

    return out_fp32

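# Illustrative sketch, not used by the prototype: the zero_point handling in
# the function above relies on the identity
#   (X - zp) @ W.T == X @ W.T - zp * W.sum(dim=-1),
# which lets a kernel run a plain integer gemm and fold the activation
# zero_point in afterwards. A quick numerical check of that identity on
# random integer data (assumption-level example only):
def _check_zero_point_identity():
    X = torch.randint(-128, 128, (4, 32), dtype=torch.int64)
    W = torch.randint(-8, 8, (16, 32), dtype=torch.int64)
    zp = torch.randint(-128, 128, (4, 1), dtype=torch.int64)
    lhs = (X - zp) @ W.t()
    rhs = X @ W.t() - zp * W.sum(dim=-1).view(1, -1)
    assert torch.equal(lhs, rhs)
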
def vasiliy_reference_i8i8f32_gemm(
    act_q,
    act_scale,
    act_zp,
    w_q,
    w_scale,
    w_group_size,
) -> torch.Tensor:
    #
    # now we have the scales/zero_points/quant values for both gemm operands
    # below is a manual slow gemm with integer operands and float rescaling,
    # implemented using eager PyTorch ops. This should be slow but should
    # closely (though not exactly) match a real int8,int8->int32 gemm with
    # rescaling, with the only difference being that the sum inside of the
    # dot product is done in float32 right now.
    #
    q_output = torch.zeros(
        act_q.shape[0],
        w_q.shape[0],
        dtype=torch.float32,
        device=act_q.device,
    )
    for m_idx in range(act_q.shape[0]):
        for n_idx in range(w_q.shape[0]):
            for g_idx in range(w_q.shape[1] // w_group_size):
                k_start = g_idx * w_group_size
                k_end = k_start + w_group_size
                act_chunk = act_q[m_idx][k_start:k_end]
                w_chunk = w_q[n_idx][k_start:k_end]

                # (act_q - act_zp) * w_q
                # result still in int32
                elem_int32 = (act_chunk - act_zp[m_idx]) * w_chunk

                # sum((act_q - act_zp) * w_q)
                # this is in float32, so likely a small deviation from the real
                # kernel, where the entire dot product would be in int32
                sum_float32 = torch.sum(elem_int32)

                # scale
                act_scale_tmp = act_scale[m_idx].squeeze(-1)
                w_scale_tmp = w_scale[n_idx][g_idx].squeeze(-1).bfloat16().float()
                sum_scaled = sum_float32 * act_scale_tmp * w_scale_tmp

                # accumulate
                q_output[m_idx][n_idx] += sum_scaled

    return q_output

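# Illustrative sketch, not used by the prototype: a vectorized equivalent of
# the triple loop above, written only to make the math easier to read. It
# assumes act_scale/act_zp carry one value per row of act_q and that w_scale
# has out_features * (in_features // w_group_size) entries; treat it as an
# assumption-level restatement, not a drop-in replacement.
def _vasiliy_reference_i8i8f32_gemm_vectorized(
    act_q,
    act_scale,
    act_zp,
    w_q,
    w_scale,
    w_group_size,
) -> torch.Tensor:
    M, K = act_q.shape
    N = w_q.shape[0]
    G = K // w_group_size
    # subtract the per-row activation zero_point, then split K into groups
    a = (act_q.float() - act_zp.float().view(M, 1)).view(M, G, w_group_size)
    w = w_q.float().view(N, G, w_group_size)
    # per-group partial dot products: [M, N, G]
    partial = torch.einsum("mgk,ngk->mng", a, w)
    # mirror the bf16 round trip applied to the weight scales above
    ws = w_scale.bfloat16().float().reshape(N, G)
    return (partial * ws.view(1, N, G)).sum(dim=-1) * act_scale.view(M, 1)
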
class Int8PerTokenActivationInt4PerGroupWeightLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # manually create fake quantizer configs
        activation_config = FakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False
        )
        weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

        # manually create fake quantizers
        # reference: `FakeQuantizedLinear` (https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/qat/linear.py)
        self.activation_fq = FakeQuantizer(activation_config)
        self.weight_fq = FakeQuantizer(weight_config)

    def forward(self, input):
        # quantize x
        input_i8, input_scale, input_zp = quantize_x(input)

        # quantize w
        _ = self.weight_fq(self.weight)
        w_qmin, w_qmax = _DTYPE_TO_QVALUE_BOUNDS[torch.int4]
        w_granularity = self.weight_fq.config.granularity
        w_group_size = w_granularity.group_size
        w_block_size = (1, w_group_size)
        weight_int4 = torch.ops.torchao.quantize_affine(
            self.weight,
            w_block_size,
            self.weight_fq.scale,
            self.weight_fq.zero_point,
            torch.int8,
            w_qmin,
            w_qmax,
        )

        # original reference
        q_output = vasiliy_reference_i8i8f32_gemm(
            input_i8.to(torch.int32),
            input_scale,
            input_zp,
            weight_int4.to(torch.int32),
            self.weight_fq.scale,
            w_group_size,
        )

        # now also check Kimish's implementation
        q_output2 = _xnnpack_integer_arithmetic_dqlinear_4bit_groupwise(
            input_i8.cpu(),
            input_scale.cpu(),
            input_zp.cpu(),
            weight_int4.cpu(),
            self.weight_fq.scale.cpu(),
            self.weight_fq.zero_point.cpu(),
            self.bias,
            self.weight_fq.config.granularity.group_size,
        ).cuda()
        sqnr = compute_error(q_output, q_output2)
        print("sqnr vasiliy_reference vs kimish_reference", sqnr)

        # finally, check vs triton gemm
        q_output3 = int8_matmul_triton(
            input_i8,
            weight_int4.t(),
            input_scale,
            input_zp,
            self.weight_fq.scale.t(),
            w_group_size,
        )

        sqnr3 = compute_error(q_output, q_output3)
        print("sqnr vasiliy_reference vs triton", sqnr3)

        sqnr4 = compute_error(q_output2, q_output3)
        print("sqnr kimish_reference vs triton", sqnr4)

        return q_output, q_output2

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new_mod = cls(mod.in_features, mod.out_features)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod

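# Illustrative usage sketch, not called anywhere: the module above is meant to
# replace an existing nn.Linear via from_float, and its forward returns the
# two reference outputs so the caller can compare them, which is why run()
# below unpacks two values. Names here are assumptions for illustration only.
def _example_swap_in_exact_linear():
    lin = nn.Linear(64, 128, bias=False).cuda()
    exact_lin = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(lin)
    x = torch.randn(32, 64, device="cuda")
    y_vasiliy_ref, y_kimish_ref = exact_lin(x)
    return y_vasiliy_ref, y_kimish_ref
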
def run():
    M, K, N = 32, 64, 128

    # TODO(before land): also implement bias=True
    m_hp = nn.Sequential(nn.Linear(K, N, bias=False)).cuda()
    mq_ref = copy.deepcopy(m_hp)
    mq = copy.deepcopy(m_hp)

    # create a baseline: QAT with fake quants. Our exact QAT's output should
    # be close to this
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    quantize_(
        mq_ref,
        IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
    )

    # create the experiment: forward pass with an integer gemm
    mq[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(mq[0])

    x_hp = torch.randn(M, K, device="cuda")
    xq_ref = copy.deepcopy(x_hp)
    xq = copy.deepcopy(x_hp)

    with torch.no_grad():
        y_hp = m_hp(x_hp)
        yq_ref = mq_ref(xq_ref)
        yq, yq2 = mq(xq)

    sqnr_hp_qref = compute_error(y_hp, yq_ref)
    sqnr_hp_q = compute_error(y_hp, yq)
    sqnr_qref_q = compute_error(yq_ref, yq)
    print("sqnr_hp_qref", sqnr_hp_qref)
    print("sqnr_hp_q", sqnr_hp_q)
    print("sqnr_qref_q", sqnr_qref_q)

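# For interpreting the printed numbers: compute_error reports an SQNR-style
# metric in dB, where higher means the two tensors agree more closely. A
# hedged sketch of that kind of metric (not necessarily torchao's exact
# implementation):
def _sqnr_db(ref, candidate):
    return 20 * torch.log10(ref.norm() / (ref - candidate).norm())
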
if __name__ == "__main__":
    fire.Fire(run)

0 commit comments