[not for land] towards QAT with exact forward pass #2529

Open · wants to merge 1 commit into main
211 changes: 211 additions & 0 deletions torchao/prototype/qat_exact/main.py
@@ -0,0 +1,211 @@
"""
Prototype of QAT with exact (instead of emulated) forward pass using
integer matrix multiply.

Quant spec:
* int4 symmetric weights w/ group size 32 or 256,
* int8 asymmetric per-token dynamic activations

"""

import copy

import fire
import torch
import torch.nn as nn

from torchao.float8.float8_utils import compute_error
from torchao.prototype.qat_exact.reference_gemm import (
cpu_x_token_assym_fp8_w_group_sym_int4_gemm,
naive_x_token_assym_fp8_w_group_sym_int4_gemm,
)
from torchao.prototype.qat_exact.triton_gemm import int8_matmul_triton
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.quant_primitives import (
_DTYPE_TO_QVALUE_BOUNDS,
MappingType,
)
from torchao.quantization.utils import (
_get_per_token_block_size,
)

torch.manual_seed(0)


def quantize_x(x_fp32):
# Dynamic quantization of activation
x_mapping_type = MappingType.ASYMMETRIC
per_token_block_size = _get_per_token_block_size(x_fp32)
x_quant_min, x_quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
x_eps = torch.finfo(torch.float32).eps
x_scales_type = torch.float32
x_zero_points_type = torch.int32
x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
x_fp32,
x_mapping_type.name,
per_token_block_size,
torch.int8,
x_quant_min,
x_quant_max,
x_eps,
x_scales_type,
x_zero_points_type,
)
x_i8 = torch.ops.torchao.quantize_affine(
x_fp32,
per_token_block_size,
x_scale,
x_zero_point,
torch.int8,
x_quant_min,
x_quant_max,
)
return x_i8, x_scale, x_zero_point
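

# Hypothetical debugging helper, not part of the original prototype: undo
# `quantize_x` to recover an fp32 approximation of the input. It only makes the
# per-token asymmetric scheme explicit: x ~= (x_i8 - x_zero_point) * x_scale,
# with the per-token scale/zero_point broadcast along the token dimension.
def _dequantize_x_for_debug(x_i8, x_scale, x_zero_point):
    x_scale = x_scale.view(-1, 1)
    x_zero_point = x_zero_point.view(-1, 1)
    return (x_i8.to(torch.float32) - x_zero_point.to(torch.float32)) * x_scale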


class Int8PerTokenActivationInt4PerGroupWeightLinear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
gemm_mode = kwargs.pop("gemm_mode")
assert gemm_mode in (
"int8_naive_reference",
"int8_cpu_reference",
"int8_triton",
)
super().__init__(*args, **kwargs)
# manually create fake quantizer configs
activation_config = FakeQuantizeConfig(
torch.int8, "per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# manually create fake quantizers
# reference: `FakeQuantizedLinear` (https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/qat/linear.py)
self.activation_fq = FakeQuantizer(activation_config)
self.weight_fq = FakeQuantizer(weight_config)
self.gemm_mode = gemm_mode

def forward(self, input):
# quantize x
input_i8, input_scale, input_zp = quantize_x(input)

# quantize w
_ = self.weight_fq(self.weight)
w_qmin, w_qmax = _DTYPE_TO_QVALUE_BOUNDS[torch.int4]
w_granularity = self.weight_fq.config.granularity
w_group_size = w_granularity.group_size
w_block_size = (1, w_group_size)
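        # note: the quantized weight values are restricted to the int4 range
        # (w_qmin..w_qmax) but are stored in an int8 container below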
weight_int4 = torch.ops.torchao.quantize_affine(
self.weight,
w_block_size,
self.weight_fq.scale,
self.weight_fq.zero_point,
torch.int8,
w_qmin,
w_qmax,
)

if self.gemm_mode == "int8_naive_reference":
# original reference
q_output = naive_x_token_assym_fp8_w_group_sym_int4_gemm(
input_i8.to(torch.int32),
input_scale,
input_zp,
weight_int4.to(torch.int32),
self.weight_fq.scale,
w_group_size,
)
elif self.gemm_mode == "int8_cpu_reference":
            # Kimish's groupwise CPU reference implementation
q_output = cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
input_i8.cpu(),
input_scale.cpu(),
input_zp.cpu(),
weight_int4.cpu(),
self.weight_fq.scale.cpu(),
self.weight_fq.zero_point.cpu(),
self.bias,
self.weight_fq.config.granularity.group_size,
).cuda()
elif self.gemm_mode == "int8_triton":
            # triton int8 gemm
q_output = int8_matmul_triton(
input_i8,
weight_int4.t(),
input_scale,
input_zp,
self.weight_fq.scale.t(),
w_group_size,
)

return q_output

@classmethod
def from_float(cls, mod: torch.nn.Linear, gemm_mode: str):
new_mod = cls(mod.in_features, mod.out_features, gemm_mode=gemm_mode)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


def run():
M, K, N = 32, 64, 128

# TODO(before land): also implement bias=True
m_hp = nn.Sequential(nn.Linear(K, N, bias=False)).cuda()
mq_ref = copy.deepcopy(m_hp)
mq_naive = copy.deepcopy(m_hp)
mq_cpu = copy.deepcopy(m_hp)
mq_triton = copy.deepcopy(m_hp)

# create a baseline: QAT with fake quants. Our exact QAT's output should
# be close to this
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
mq_ref,
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# create the experiment: forward pass with an integer gemm
mq_naive[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
mq_naive[0], "int8_naive_reference"
)
mq_cpu[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
mq_cpu[0], "int8_cpu_reference"
)
mq_triton[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
mq_triton[0], "int8_triton"
)

x_hp = torch.randn(M, K, device="cuda")
xq_ref = copy.deepcopy(x_hp)
xq = copy.deepcopy(x_hp)

with torch.no_grad():
y_hp = m_hp(x_hp)
yq_ref = mq_ref(xq_ref)
yq_naive = mq_naive(xq)
yq_cpu = mq_cpu(xq)
yq_triton = mq_triton(xq)

sqnr_hp_qref = compute_error(y_hp, yq_ref)
sqnr_hp_qnaive = compute_error(y_hp, yq_naive)
sqnr_qref_qnaive = compute_error(yq_ref, yq_naive)
sqnr_qcpu_qnaive = compute_error(yq_cpu, yq_naive)
sqnr_qcpu_qtriton = compute_error(yq_cpu, yq_triton)
sqnr_qnaive_qtriton = compute_error(yq_naive, yq_triton)
print("sqnr_hp_qref", sqnr_hp_qref)
print("sqnr_hp_qnaive", sqnr_hp_qnaive)
print("sqnr_qref_qnaive", sqnr_qref_qnaive)
print("sqnr_qcpu_qnaive", sqnr_qcpu_qnaive)
print("sqnr_qcpu_triton", sqnr_qcpu_qtriton)
print("sqnr_qnaive_qtriton", sqnr_qnaive_qtriton)


if __name__ == "__main__":
fire.Fire(run)
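
# Expected usage (a sketch; assumes a CUDA device and the reference/triton gemms
# from this prototype directory are available):
#   python torchao/prototype/qat_exact/main.py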
132 changes: 132 additions & 0 deletions torchao/prototype/qat_exact/reference_gemm.py
@@ -0,0 +1,132 @@
import torch
from torch._higher_order_ops.out_dtype import out_dtype


def cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
x_i8,
x_scale,
x_zero_point,
weight_int4,
weight_scale,
weight_zero_point,
bias_fp32,
group_size,
):
    # For groupwise quantization, we need to handle the computation differently
    # weight_int4 shape: [out_features, in_features]
    # weight_scale shape: [out_features, in_features // group_size]
    # weight_zero_point shape: [out_features, in_features // group_size]
out_features, in_features = weight_int4.shape
num_groups = in_features // group_size

# scales in xnnpack are stored as bf16 and converted to fp32 for computation
weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
# Reshape for group-wise processing
# x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
batch_size = x_i8.shape[0]
x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)

# weight: [out_features, in_features] -> [out_features, num_groups, group_size]
weight_i4_grouped = weight_int4.view(out_features, num_groups, group_size)

    # Convert to int32 for computation
x_i32_grouped = x_i8_grouped.to(torch.int32)
weight_i32_grouped = weight_i4_grouped.to(torch.int32)

# Perform groupwise integer linear operation
acc_fp32 = torch.zeros(
batch_size, out_features, dtype=torch.float32, device=x_i8.device
)

for group_idx in range(num_groups):
# Extract current group
x_group = x_i32_grouped[:, group_idx, :] # [batch_size, group_size]
weight_group = weight_i32_grouped[:, group_idx, :] # [out_features, group_size]
weight_group_col_sum = weight_group.sum(dim=-1) # [out_features]

# Get scale for this group
weight_scale_group = weight_scale[:, group_idx] # [out_features]

        # Integer matmul via aten.linear: [batch_size, group_size] @ [out_features, group_size].T -> [batch_size, out_features]
group_acc = out_dtype(
torch.ops.aten.linear.default,
torch.int32,
x_group,
weight_group,
None,
)

        # The output must be scaled by x_scale * weight_scale_group. We apply
        # weight_scale_group (the per-group weight scale) here and apply x_scale
        # once at the end, since x_scale is shared across all groups.
acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
1, -1
)

        # we must also subtract x_zero_point * weight_group_col_sum (scaled by
        # weight_scale_group), since (X - x_zero_point) @ W.T = X @ W.T - x_zero_point * colsum(W)
weights_col_sum_adjusted = (
weight_group_col_sum.to(torch.float32).view(1, -1)
* x_zero_point.view(-1, 1)
* weight_scale_group.view(1, -1)
)
acc_fp32 = acc_fp32 - weights_col_sum_adjusted
x_scale_multiplier = x_scale.view(-1, 1)
out_fp32 = acc_fp32 * x_scale_multiplier
if bias_fp32 is not None:
out_fp32 = out_fp32 + bias_fp32

return out_fp32
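

# Hypothetical vectorized equivalent of the groupwise loop above, not part of
# the original prototype. It is a sketch meant to make the math easier to read;
# it materializes a [batch, out_features, num_groups, group_size] intermediate,
# so it is only suitable for small shapes. The unused weight zero point argument
# is omitted, since the weights are symmetric.
def _vectorized_cpu_gemm_sketch(
    x_i8,
    x_scale,
    x_zero_point,
    weight_int4,
    weight_scale,
    bias_fp32,
    group_size,
):
    out_features, in_features = weight_int4.shape
    num_groups = in_features // group_size
    batch_size = x_i8.shape[0]

    # match the reference above: weight scales are rounded through bf16
    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

    x_g = x_i8.to(torch.int32).view(batch_size, num_groups, group_size)
    w_g = weight_int4.to(torch.int32).view(out_features, num_groups, group_size)

    # per-group integer dot products: [batch_size, out_features, num_groups]
    acc = (x_g.unsqueeze(1) * w_g.unsqueeze(0)).sum(dim=-1).to(torch.float32)

    # zero-point correction: subtract x_zero_point[m] * sum_k w_q[n, g, k]
    w_col_sum = w_g.sum(dim=-1).to(torch.float32)  # [out_features, num_groups]
    acc = acc - x_zero_point.view(-1, 1, 1).to(torch.float32) * w_col_sum.unsqueeze(0)

    # per-group weight scale, sum over groups, then per-token activation scale
    out_fp32 = (acc * weight_scale.unsqueeze(0)).sum(dim=-1) * x_scale.view(-1, 1)
    if bias_fp32 is not None:
        out_fp32 = out_fp32 + bias_fp32
    return out_fp32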


def naive_x_token_assym_fp8_w_group_sym_int4_gemm(
act_q,
act_scale,
act_zp,
w_q,
w_scale,
w_group_size,
) -> torch.Tensor:
    #
    # At this point we have the scales/zero_points/quant values for both gemm
    # operands. Below is a slow, manual gemm with integer operands and float
    # rescaling, implemented with eager PyTorch ops. It should closely (but not
    # exactly) match a real int8,int8->int32 gemm with rescaling, the only
    # difference being that the sum inside the dot product is not accumulated
    # in int32 here.
    #
q_output = torch.zeros(
act_q.shape[0],
w_q.shape[0],
dtype=torch.float32,
device=act_q.device,
)
for m_idx in range(act_q.shape[0]):
for n_idx in range(w_q.shape[0]):
for g_idx in range(w_q.shape[1] // w_group_size):
k_start = g_idx * w_group_size
k_end = k_start + w_group_size
act_chunk = act_q[m_idx][k_start:k_end]
w_chunk = w_q[n_idx][k_start:k_end]

# (act_q - act_zp) * w_q
# result still in int32
elem_int32 = (act_chunk - act_zp[m_idx]) * w_chunk

# sum((act_q - act_zp) * w_q)
# this is in float32, so likely a small deviation from the real
# kernel, where the entire dot product would be in int32
sum_float32 = torch.sum(elem_int32)

                # scale (the weight scale is rounded through bf16 to match the cpu reference)
act_scale_tmp = act_scale[m_idx].squeeze(-1)
w_scale_tmp = w_scale[n_idx][g_idx].squeeze(-1).bfloat16().float()
sum_scaled = sum_float32 * act_scale_tmp * w_scale_tmp

# accumulate
q_output[m_idx][n_idx] += sum_scaled

return q_output
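

# Hypothetical quick consistency check, not part of the original prototype: with
# random quantized operands, the naive loop and the groupwise CPU reference above
# should agree to within float rounding. Shapes follow the conventions in main.py
# ([M, 1] per-token activation scale/zero_point, [N, K // group_size] weight scale).
def _self_check(M=4, K=64, N=8, group_size=32):
    x_q = torch.randint(-128, 128, (M, K), dtype=torch.int32)
    x_scale = torch.rand(M, 1) * 0.1 + 0.01
    x_zp = torch.randint(-10, 10, (M, 1), dtype=torch.int32)
    w_q = torch.randint(-8, 8, (N, K), dtype=torch.int32)
    w_scale = torch.rand(N, K // group_size) * 0.1 + 0.01
    w_zp = torch.zeros(N, K // group_size, dtype=torch.int32)

    out_naive = naive_x_token_assym_fp8_w_group_sym_int4_gemm(
        x_q, x_scale, x_zp, w_q, w_scale, group_size
    )
    out_cpu = cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
        x_q, x_scale, x_zp, w_q, w_scale, w_zp, None, group_size
    )
    torch.testing.assert_close(out_naive, out_cpu, rtol=1e-3, atol=1e-3)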