
Commit ff2153c

[not for land] towards QAT with exact forward pass
Summary: Exploration for QAT with an exact (not emulated) forward pass. WIP and not ready for review.
1 parent c1e84cc commit ff2153c

3 files changed: +642 additions, 0 deletions

torchao/prototype/qat_exact/main.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
"""
Prototype of QAT with exact (instead of emulated) forward pass using
integer matrix multiply.

Quant spec:
* int4 symmetric weights w/ group size 32 or 256,
* int8 asymmetric per-token dynamic activations

"""

import copy

import fire
import torch
import torch.nn as nn

from torchao.float8.float8_utils import compute_error
from torchao.prototype.qat_exact.reference_gemm import (
    cpu_x_token_assym_fp8_w_group_sym_int4_gemm,
    naive_x_token_assym_fp8_w_group_sym_int4_gemm,
)
from torchao.prototype.qat_exact.triton_gemm import int8_matmul_triton
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
)
from torchao.quantization.utils import (
    _get_per_token_block_size,
)

torch.manual_seed(0)


def quantize_x(x_fp32):
    # Dynamic quantization of activation
    x_mapping_type = MappingType.ASYMMETRIC
    per_token_block_size = _get_per_token_block_size(x_fp32)
    x_quant_min, x_quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
    x_eps = torch.finfo(torch.float32).eps
    x_scales_type = torch.float32
    x_zero_points_type = torch.int32
    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
        x_fp32,
        x_mapping_type.name,
        per_token_block_size,
        torch.int8,
        x_quant_min,
        x_quant_max,
        x_eps,
        x_scales_type,
        x_zero_points_type,
    )
    x_i8 = torch.ops.torchao.quantize_affine(
        x_fp32,
        per_token_block_size,
        x_scale,
        x_zero_point,
        torch.int8,
        x_quant_min,
        x_quant_max,
    )
    return x_i8, x_scale, x_zero_point


class Int8PerTokenActivationInt4PerGroupWeightLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        gemm_mode = kwargs.pop("gemm_mode")
        assert gemm_mode in (
            "int8_naive_reference",
            "int8_cpu_reference",
            "int8_triton",
        )
        super().__init__(*args, **kwargs)
        # manually create fake quantizer configs
        activation_config = FakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False
        )
        weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

        # manually create fake quantizers
        # reference: `FakeQuantizedLinear` (https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/qat/linear.py)
        self.activation_fq = FakeQuantizer(activation_config)
        self.weight_fq = FakeQuantizer(weight_config)
        self.gemm_mode = gemm_mode

    def forward(self, input):
        # quantize x
        input_i8, input_scale, input_zp = quantize_x(input)

        # quantize w
        _ = self.weight_fq(self.weight)
        w_qmin, w_qmax = _DTYPE_TO_QVALUE_BOUNDS[torch.int4]
        w_granularity = self.weight_fq.config.granularity
        w_group_size = w_granularity.group_size
        w_block_size = (1, w_group_size)
        weight_int4 = torch.ops.torchao.quantize_affine(
            self.weight,
            w_block_size,
            self.weight_fq.scale,
            self.weight_fq.zero_point,
            torch.int8,
            w_qmin,
            w_qmax,
        )

        if self.gemm_mode == "int8_naive_reference":
            # original reference
            q_output = naive_x_token_assym_fp8_w_group_sym_int4_gemm(
                input_i8.to(torch.int32),
                input_scale,
                input_zp,
                weight_int4.to(torch.int32),
                self.weight_fq.scale,
                w_group_size,
            )
        elif self.gemm_mode == "int8_cpu_reference":
            # now also check Kimish's implementation
            q_output = cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
                input_i8.cpu(),
                input_scale.cpu(),
                input_zp.cpu(),
                weight_int4.cpu(),
                self.weight_fq.scale.cpu(),
                self.weight_fq.zero_point.cpu(),
                self.bias,
                self.weight_fq.config.granularity.group_size,
            ).cuda()
        elif self.gemm_mode == "int8_triton":
            # finally, check vs triton gemm
            q_output = int8_matmul_triton(
                input_i8,
                weight_int4.t(),
                input_scale,
                input_zp,
                self.weight_fq.scale.t(),
                w_group_size,
            )

        return q_output

    @classmethod
    def from_float(cls, mod: torch.nn.Linear, gemm_mode: str):
        new_mod = cls(mod.in_features, mod.out_features, gemm_mode=gemm_mode)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


def run():
    M, K, N = 32, 64, 128

    # TODO(before land): also implement bias=True
    m_hp = nn.Sequential(nn.Linear(K, N, bias=False)).cuda()
    mq_ref = copy.deepcopy(m_hp)
    mq_naive = copy.deepcopy(m_hp)
    mq_cpu = copy.deepcopy(m_hp)
    mq_triton = copy.deepcopy(m_hp)

    # create a baseline: QAT with fake quants. Our exact QAT's output should
    # be close to this
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    quantize_(
        mq_ref,
        IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
    )

    # create the experiment: forward pass with an integer gemm
    mq_naive[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
        mq_naive[0], "int8_naive_reference"
    )
    mq_cpu[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
        mq_cpu[0], "int8_cpu_reference"
    )
    mq_triton[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
        mq_triton[0], "int8_triton"
    )

    x_hp = torch.randn(M, K, device="cuda")
    xq_ref = copy.deepcopy(x_hp)
    xq = copy.deepcopy(x_hp)

    with torch.no_grad():
        y_hp = m_hp(x_hp)
        yq_ref = mq_ref(xq_ref)
        yq_naive = mq_naive(xq)
        yq_cpu = mq_cpu(xq)
        yq_triton = mq_triton(xq)

    sqnr_hp_qref = compute_error(y_hp, yq_ref)
    sqnr_hp_qnaive = compute_error(y_hp, yq_naive)
    sqnr_qref_qnaive = compute_error(yq_ref, yq_naive)
    sqnr_qcpu_qnaive = compute_error(yq_cpu, yq_naive)
    sqnr_qcpu_qtriton = compute_error(yq_cpu, yq_triton)
    sqnr_qnaive_qtriton = compute_error(yq_naive, yq_triton)
    print("sqnr_hp_qref", sqnr_hp_qref)
    print("sqnr_hp_qnaive", sqnr_hp_qnaive)
    print("sqnr_qref_qnaive", sqnr_qref_qnaive)
    print("sqnr_qcpu_qnaive", sqnr_qcpu_qnaive)
    print("sqnr_qcpu_qtriton", sqnr_qcpu_qtriton)
    print("sqnr_qnaive_qtriton", sqnr_qnaive_qtriton)


if __name__ == "__main__":
    fire.Fire(run)
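
Since the script ends in `fire.Fire(run)`, it can presumably be invoked directly; a CUDA device is assumed (the modules and inputs are moved to CUDA), and the `int8_triton` path additionally requires Triton. A hypothetical invocation, which prints the six SQNR comparisons from `run()`:

python torchao/prototype/qat_exact/main.py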
torchao/prototype/qat_exact/reference_gemm.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import torch

from torch._higher_order_ops.out_dtype import out_dtype


def cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
    x_i8,
    x_scale,
    x_zero_point,
    weight_int4,
    weight_scale,
    weight_zero_point,
    bias_fp32,
    group_size,
):
    # For groupwise quantization, we need to handle the computation differently
    # weight_i4 shape: [out_features, in_features]
    # weight_scale shape: [out_features, in_features // group_size]
    # weight_zero_point shape: [out_features, in_features // group_size]
    out_features, in_features = weight_int4.shape
    num_groups = in_features // group_size

    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
    # Reshape for group-wise processing
    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
    batch_size = x_i8.shape[0]
    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)

    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
    weight_i4_grouped = weight_int4.view(out_features, num_groups, group_size)

    # Convert to int32 for computation
    x_i32_grouped = x_i8_grouped.to(torch.int32)
    weight_i32_grouped = weight_i4_grouped.to(torch.int32)

    # Perform groupwise integer linear operation
    acc_fp32 = torch.zeros(
        batch_size, out_features, dtype=torch.float32, device=x_i8.device
    )

    for group_idx in range(num_groups):
        # Extract current group
        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]

        # Get scale for this group
        weight_scale_group = weight_scale[:, group_idx]  # [out_features]

        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
        group_acc = out_dtype(
            torch.ops.aten.linear.default,
            torch.int32,
            x_group,
            weight_group,
            None,
        )

        # Output has to be scaled by x_scale * weight_scale_group
        # However we will first scale by weight_scale_group, that is accounting
        # only for scale of weight, and then scale by x_scale at the end because
        # x_scale applies to all groups
        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
            1, -1
        )

        # we must also subtract x_zero_point * weight_group_col_sum
        # since (X - x_zero_point) @ W.t() = X @ W.t() - x_zero_point * sum(W)
        weights_col_sum_adjusted = (
            weight_group_col_sum.to(torch.float32).view(1, -1)
            * x_zero_point.view(-1, 1)
            * weight_scale_group.view(1, -1)
        )
        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
    x_scale_multiplier = x_scale.view(-1, 1)
    out_fp32 = acc_fp32 * x_scale_multiplier
    if bias_fp32 is not None:
        out_fp32 = out_fp32 + bias_fp32

    return out_fp32


def naive_x_token_assym_fp8_w_group_sym_int4_gemm(
    act_q,
    act_scale,
    act_zp,
    w_q,
    w_scale,
    w_group_size,
) -> torch.Tensor:
    #
    # now we have the scales/zero_points/quant values for both gemm operands
    # below is a manual slow gemm with integer operands and float rescaling,
    # implemented using eager PyTorch ops. This should closely (though not
    # exactly) match a real int8,int8->int32 gemm with rescaling, with the
    # only difference being that the sum inside of the dot product is done
    # in float32 right now.
    #
    q_output = torch.zeros(
        act_q.shape[0],
        w_q.shape[0],
        dtype=torch.float32,
        device=act_q.device,
    )
    for m_idx in range(act_q.shape[0]):
        for n_idx in range(w_q.shape[0]):
            for g_idx in range(w_q.shape[1] // w_group_size):
                k_start = g_idx * w_group_size
                k_end = k_start + w_group_size
                act_chunk = act_q[m_idx][k_start:k_end]
                w_chunk = w_q[n_idx][k_start:k_end]

                # (act_q - act_zp) * w_q
                # result still in int32
                elem_int32 = (act_chunk - act_zp[m_idx]) * w_chunk

                # sum((act_q - act_zp) * w_q)
                # this is in float32, so likely a small deviation from the real
                # kernel, where the entire dot product would be in int32
                sum_float32 = torch.sum(elem_int32)

                # scale
                act_scale_tmp = act_scale[m_idx].squeeze(-1)
                w_scale_tmp = w_scale[n_idx][g_idx].squeeze(-1).bfloat16().float()
                sum_scaled = sum_float32 * act_scale_tmp * w_scale_tmp

                # accumulate
                q_output[m_idx][n_idx] += sum_scaled

    return q_output
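
Both reference gemms implement the same arithmetic: an integer matmul, a zero-point correction using the per-group weight column sums (since (x - x_zp) @ w.t() = x @ w.t() - x_zp * sum(w)), and a final rescale by x_scale * weight_scale. The sketch below is not part of the commit; it is a minimal standalone check of that identity against dequantize-then-matmul, using hypothetical toy shapes and float64 to keep rounding noise out of the comparison:

import torch

# hypothetical toy problem: 2 tokens, 4 output channels, a single weight group (K == group_size)
M, N, K = 2, 4, 8

x_i8 = torch.randint(-128, 128, (M, K), dtype=torch.int32)  # int8 range, held in int32
x_zp = torch.randint(-10, 10, (M, 1), dtype=torch.int32)    # per-token zero point
x_scale = torch.rand(M, 1, dtype=torch.float64)             # per-token scale

w_i4 = torch.randint(-8, 8, (N, K), dtype=torch.int32)      # int4 range, held in int32
w_scale = torch.rand(N, 1, dtype=torch.float64)             # per-group (here per-row) scale

# reference: dequantize both operands, then matmul
ref = ((x_i8 - x_zp).double() * x_scale) @ (w_i4.double() * w_scale).t()

# integer-domain formulation: x @ w.t() - x_zp * sum(w), then rescale by both scales
acc = x_i8.double() @ w_i4.t().double() - x_zp.double() * w_i4.sum(dim=-1).double().view(1, -1)
out = acc * x_scale * w_scale.view(1, -1)

torch.testing.assert_close(out, ref)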
