
Commit 878ec7a

Add linear bias support for QAT (#1755)
**Summary:** Add linear bias support for QAT. Previously, applying QAT to a linear module with bias failed with the following unintuitive error message:

```
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
```

Note that we still do not fake quantize the bias; we only support applying QAT to linear modules that have a bias.

**Test Plan:**
python test/quantization/test_qat.py -k test_qat_linear_bias
1 parent e0f7148 commit 878ec7a
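The root cause of the old error is worth spelling out: both `nn.Linear.__init__` and the old `FakeQuantizedLinear.__init__` evaluate the `bias` argument as a boolean, so passing the bias *tensor* (as `from_linear` did via `mod.bias`) trips PyTorch's ambiguous-truth-value check. Below is a minimal standalone sketch of the failure and of the `is not None` pattern used in this commit; it is an illustration only, not code from the diff.

```python
import torch

src = torch.nn.Linear(512, 256, bias=True)

# Passing the bias tensor where a bool flag is expected hits `if bias:` inside
# nn.Linear.__init__ and raises:
# "RuntimeError: Boolean value of Tensor with more than one value is ambiguous"
try:
    torch.nn.Linear(src.in_features, src.out_features, src.bias)
except RuntimeError as e:
    print(e)

# The fix: pass an explicit bool, then copy the bias parameter over separately.
dst = torch.nn.Linear(src.in_features, src.out_features, src.bias is not None)
dst.bias = src.bias
```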

2 files changed (+42, −6 lines)


test/quantization/test_qat.py

Lines changed: 34 additions & 0 deletions
@@ -133,6 +133,21 @@ def forward(self, x):
         return x
 
 
+class ModelWithLinearBias(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(512, 256, bias=True)
+        self.linear2 = torch.nn.Linear(256, 512, bias=True)
+
+    def example_inputs(self):
+        return (torch.randn(1, 512),)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
 class TestQAT(unittest.TestCase):
     SEED = 123
 
@@ -1366,6 +1381,25 @@ def test_fake_quantizer_repr(self):
         self.assertTrue("PerGroup" in fake_quantizer_repr)
         self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_linear_bias(self):
+        """
+        Test that QAT supports linear bias.
+        """
+        m = ModelWithLinearBias()
+        activation_config = FakeQuantizeConfig(
+            torch.int8, "per_token", is_symmetric=False
+        )
+        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)
+        quantize_(
+            m,
+            intx_quantization_aware_training(activation_config, weight_config),
+        )
+        example_inputs = m.example_inputs()
+        m(*example_inputs)
+
 
 if __name__ == "__main__":
     unittest.main()
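Beyond the regression test above, the typical end-to-end QAT flow that this change unblocks for biased linears looks roughly like the following. This is a sketch assuming the import paths and the `from_intx_quantization_aware_training` convert step documented in the torchao QAT README; neither appears in this diff.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

# A toy model whose linears have bias=True, which previously errored out.
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=True))

# Prepare: swap nn.Linear for FakeQuantizedLinear (the bias is carried through,
# it is not fake quantized).
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int8, group_size=32)  # int8 per-group weights, for simplicity
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# ... fine-tune `model` as usual ...

# Convert: swap FakeQuantizedLinear back to plain nn.Linear, bias included.
quantize_(model, from_intx_quantization_aware_training())
```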

torchao/quantization/qat/linear.py

Lines changed: 8 additions & 6 deletions
@@ -75,9 +75,6 @@ def __init__(
             *args,
             **kwargs,
         )
-        if bias:
-            raise NotImplementedError("bias not supported yet")
-
         # initialize activation fake quantizer
         if activation_config is not None:
             self.activation_fake_quantizer = FakeQuantizer(activation_config)
@@ -103,17 +100,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w = self.weight_fake_quantizer(self.weight)
         else:
             w = self.weight
-        return F.linear(x, w)
+        return F.linear(x, w, self.bias)
 
     def to_linear(self) -> torch.nn.Linear:
         new_linear = torch.nn.Linear(
-            self.in_features, self.out_features, self.bias, device=self.weight.device
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            device=self.weight.device,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
         # copy the weights, and doing so will result in an error
         if self.weight.device != torch.device("meta"):
             new_linear.weight = self.weight
+            new_linear.bias = self.bias
         return new_linear
 
     @classmethod
@@ -126,7 +127,7 @@ def from_linear(
         new_linear = FakeQuantizedLinear(
             mod.in_features,
             mod.out_features,
-            mod.bias,
+            mod.bias is not None,
             activation_config=activation_config,
             weight_config=weight_config,
             device=mod.weight.device,
@@ -136,6 +137,7 @@ def from_linear(
         # copy the weights, and doing so will result in an error
         if mod.weight.device != torch.device("meta"):
             new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
         return new_linear
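One detail behind the `forward` change: `F.linear` already accepts `None` for its bias argument, so passing `self.bias` directly covers both the biased and bias-free cases without a branch. A quick plain-PyTorch illustration, not part of the diff:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 512)
w = torch.randn(256, 512)
b = torch.randn(256)

y_biased = F.linear(x, w, b)       # shape (2, 256), bias added
y_unbiased = F.linear(x, w, None)  # shape (2, 256), same as F.linear(x, w)
```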
