Commit f64d5a1

Add bias support for Int8DynActInt4WeightLinear (#1845)
**Summary:** Previously, when we saw a linear layer with a bias, we simply did not swap it to `Int8DynActInt4WeightLinear` and left it as is. Now we do swap it; the bias is not quantized and is passed to F.linear in full precision.

Fixes #1821

**Test Plan:**
python test/quantization/test_quant_api.py -k test_8da4w_quantizer_linear_bias
1 parent 073b4f0 commit f64d5a1
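
For reference, here is a minimal usage sketch mirroring the new test_8da4w_quantizer_linear_bias test: a two-layer model whose linears have biases is quantized with Int8DynActInt4WeightQuantizer, the linears are swapped to Int8DynActInt4WeightLinear, and the biases stay in full precision. TwoLayerModel is a hypothetical stand-in for the test's ToyLinearModel helper.

import torch
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer


class TwoLayerModel(torch.nn.Module):
    """Stand-in for the ToyLinearModel helper used by the new test."""

    def __init__(self, m=64, n=32, k=64, bias=True):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TwoLayerModel(bias=True).eval()
example_input = torch.randn(1, 64)

# Linears with bias are now swapped too; the bias itself is not quantized.
quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
model = quantizer.quantize(model)

assert isinstance(model.linear1, Int8DynActInt4WeightLinear)
assert isinstance(model.linear2, Int8DynActInt4WeightLinear)
model(example_input)  # bias is applied in full precision inside F.linear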

4 files changed (+40, -30 lines):

test/quantization/test_qat.py
test/quantization/test_quant_api.py
torchao/quantization/GPTQ.py
torchao/quantization/qat/linear.py

test/quantization/test_qat.py

Lines changed: 5 additions & 16 deletions
@@ -1043,22 +1043,10 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     )
     def test_replace_linear_8da4w(self):
         module = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features=256, out_features=50, bias=True)]
-        )
-        _replace_linear_8da4w(
-            module,
-            256,
-            False,
-            torch.float32,
-            torch.float32,
-            Int8DynActInt4WeightQATLinear,
-            copy_weights=True,
-        )
-        assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(
-            module[0], torch.nn.Linear
-        )
-        module = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features=256, out_features=50, bias=False)]
+            [
+                torch.nn.Linear(in_features=256, out_features=50, bias=True),
+                torch.nn.Linear(in_features=256, out_features=50, bias=False),
+            ]
         )
         _replace_linear_8da4w(
             module,
@@ -1070,6 +1058,7 @@ def test_replace_linear_8da4w(self):
             copy_weights=True,
         )
         assert isinstance(module[0], Int8DynActInt4WeightQATLinear)
+        assert isinstance(module[1], Int8DynActInt4WeightQATLinear)

     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"

test/quantization/test_quant_api.py

Lines changed: 18 additions & 3 deletions
@@ -115,10 +115,10 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:


 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    def __init__(self, m=64, n=32, k=64, bias=False):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

     def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
         return (
@@ -272,6 +272,21 @@ def test_8da4w_quantizer(self):
         assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
         m(*example_inputs)

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower"
+    )
+    def test_8da4w_quantizer_linear_bias(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
+        m = ToyLinearModel(bias=True).eval()
+        example_inputs = m.example_inputs()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+        m(*example_inputs)
+
     # TODO: save model weights as artifacts and re-enable in CI
     # For now, to run this test, you will need to download the weights from HF
     # and run this script to convert them:

torchao/quantization/GPTQ.py

Lines changed: 14 additions & 10 deletions
@@ -923,6 +923,7 @@ def quantize(
 def linear_forward_8da4w(
     x,
     weight_int8,
+    bias,
     scales,
     zeros,
     out_features,
@@ -956,7 +957,7 @@ def linear_forward_8da4w(

     # x = x.to(torch.float16)
     # w_dq = w_dq.to(torch.float16)
-    c = torch.nn.functional.linear(x, w_dq)
+    c = torch.nn.functional.linear(x, w_dq, bias)

     # new_shape = origin_x_size[:-1] + (out_features,)
     # c = c.reshape(new_shape)
@@ -970,6 +971,7 @@ class Int8DynActInt4WeightLinear(torch.nn.Module):
     in_features: int
     out_features: int
     weight: torch.Tensor
+    bias: torch.Tensor

     """
     This module implements a dynamic quantized linear layer with int4 weight.
@@ -1003,7 +1005,6 @@ def __init__(
         # )
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         # TODO: align groupsize naming
         self.groupsize = groupsize
         # Precision of the activation which also indicates
@@ -1034,13 +1035,19 @@ def __init__(
             ),
         )

+        if bias:
+            self.register_buffer("bias", torch.zeros(out_features, dtype=precision))
+        else:
+            self.bias = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(self.precision)
         # padding is removed for perf
         # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_8da4w(
             input,
             self.weight,
+            self.bias,
             self.scales,
             self.zeros,
             self.out_features,
@@ -1062,18 +1069,15 @@ def _replace_linear_8da4w(
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
-        # TODO: support linear bias
-        return (
-            isinstance(child, nn.Linear)
-            and child.bias is None
-            and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+        return isinstance(child, nn.Linear) and (
+            _check_linear_int4_k(child.in_features, groupsize) or padding_allowed
         )

     def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         new_linear = linear_class(
             child.in_features,
             child.out_features,
-            bias=False,
+            bias=child.bias is not None,
             device=child.weight.device,
             groupsize=groupsize,
             precision=precision,
@@ -1084,6 +1088,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         # copy the weights, and doing so will result in an error
         if copy_weights and child.weight.device != torch.device("meta"):
             new_linear.weight = child.weight
+            new_linear.bias = child.bias
         return new_linear

     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
@@ -1130,7 +1135,7 @@ def _create_quantized_state_dict(
     ) -> Dict[str, torch.Tensor]:
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Linear) and mod.bias is None:
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 # assert out_features % 8 == 0, "require out_features % 8 == 0"
@@ -1172,7 +1177,6 @@ def _create_quantized_state_dict(
                 cur_state_dict[f"{fqn}.weight"] = weight_int8.to(self.device)
                 cur_state_dict[f"{fqn}.scales"] = scales.to(self.device)
                 cur_state_dict[f"{fqn}.zeros"] = zeros.to(self.device)
-                # TODO: support bias?

         return cur_state_dict

torchao/quantization/qat/linear.py

Lines changed: 3 additions & 1 deletion
@@ -208,7 +208,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 quantized_linear = Int8DynActInt4WeightLinear(
                     child.in_features,
                     child.out_features,
-                    bias=False,
+                    child.bias is not None,
                     groupsize=config.group_size,
                     precision=child.weight.dtype,
                     scales_precision=config.scale_precision,
@@ -237,6 +237,8 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 quantized_linear.weight = q_weight
                 quantized_linear.scales = s
                 quantized_linear.zeros = zp
+                if child.bias is not None:
+                    quantized_linear.bias = child.bias
             else:
                 self._convert_qat_linear_8da4w(child)
