Skip to content

Commit 19c009d

Browse files
authored
TorchAO new observers (#2508)
TorchAO new observers (#2508). Summary: Added two observers, AffineQuantizedFixedQParamObserver (which allows manual range setting) and AffineQuantizedMSEObserver (which implements MSE range setting during the first forward pass). Also includes a bugfix in quant_primitives. Reviewed By: jerryzh168. Differential Revision: D77906174
1 parent c1e84cc commit 19c009d

File tree

3 files changed

+236
-12
lines changed

3 files changed

+236
-12
lines changed

test/quantization/test_observer.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,14 @@
1414
from torch.testing._internal import common_utils
1515
from torch.testing._internal.common_utils import TestCase
1616

17-
from torchao.quantization.granularity import (
18-
PerAxis,
19-
PerTensor,
20-
)
17+
from torchao.quantization.granularity import PerAxis, PerTensor
2118
from torchao.quantization.observer import (
19+
AffineQuantizedFixedQParamObserver,
2220
AffineQuantizedMinMaxObserver,
21+
AffineQuantizedMSEObserver,
2322
)
24-
from torchao.quantization.quant_api import (
25-
insert_observers_,
26-
)
27-
from torchao.quantization.quant_primitives import (
28-
MappingType,
29-
ZeroPointDomain,
30-
)
23+
from torchao.quantization.quant_api import insert_observers_
24+
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
3125

3226

3327
class TestQuantFlow(TestCase):
@@ -145,6 +139,56 @@ def test_block_size_row_errors(self):
145139
for example_input in example_inputs:
146140
obs(example_input)
147141

142+
def test_mse_observer(self):
    """Check that the MSE observer's range is no worse than the min/max range.

    An ``AffineQuantizedMSEObserver`` and an ``AffineQuantizedMinMaxObserver``
    are built with the same symmetric int8 ``PerAxis(0)`` configuration and
    run on the same random input. The MSE observer's ``loss_fn`` is then
    evaluated for both ranges and the line-searched range must not lose to
    the plain min/max range.
    """
    obs = AffineQuantizedMSEObserver(
        MappingType.SYMMETRIC,
        torch.int8,
        granularity=PerAxis(0),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float,
        zero_point_dtype=torch.int,
        zero_point_domain=ZeroPointDomain.NONE,
        steps=100,
        run_once=True,
    )
    example_input = torch.randn(10, 2048)
    obs(example_input)

    _, zero_point = obs.calculate_qparams()
    self.assertIsNone(zero_point)

    minmax_obs = AffineQuantizedMinMaxObserver(
        MappingType.SYMMETRIC,
        torch.int8,
        granularity=PerAxis(0),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float,
        zero_point_dtype=torch.int,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    minmax_obs(example_input)
    min_val, max_val = minmax_obs.min_val, minmax_obs.max_val

    mse_loss = obs.loss_fn(example_input, obs.min_val, obs.max_val)
    minmax_loss = obs.loss_fn(example_input, min_val, max_val)
    # Allow only floating-point noise. The earlier slack of 1e6 made this
    # comparison succeed for any finite loss, so it verified nothing.
    self.assertTrue(torch.all(mse_loss <= minmax_loss + 1e-6))
175+
176+
def test_fixed_qparams_observer(self):
177+
obs = AffineQuantizedFixedQParamObserver(
178+
MappingType.SYMMETRIC,
179+
torch.float8_e4m3fn,
180+
granularity=PerAxis(0),
181+
eps=torch.finfo(torch.float32).eps,
182+
scale_dtype=torch.float,
183+
zero_point_dtype=torch.int,
184+
zero_point_domain=ZeroPointDomain.NONE,
185+
)
186+
example_input = torch.randn(10, 2048)
187+
obs(example_input)
188+
obs.set_qparams(torch.ones(2048))
189+
scale, zero_point = obs.calculate_qparams()
190+
self.assertTrue(torch.allclose(scale, torch.ones(2048)))
191+
148192

149193
class TestLinearObserver(TestCase):
150194
@common_utils.parametrize("observe_weight", [True, False])

torchao/quantization/observer.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212

13+
from torchao.quantization.quant_primitives import _fake_quantize_affine
1314
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
1415

1516
from .granularity import (
@@ -193,6 +194,185 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
193194
)
194195

195196

197+
class AffineQuantizedFixedQParamObserver(AffineQuantizedObserverBase):
    """
    Observer that allows manual setting of fixed quantization parameters.

    The observer does not inspect the tensors passed through it: ``forward``
    returns its input unchanged, and ``calculate_qparams`` returns the
    ``scale`` / ``zero_point`` supplied at construction time or later through
    ``set_qparams``.
    """

    def __init__(
        self,
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        granularity: Granularity,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
        scale: Optional[torch.Tensor] = None,
        zero_point: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            mapping_type, target_dtype, granularity, quant_min, quant_max, eps,
            scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain:
                forwarded unchanged to ``AffineQuantizedObserverBase.__init__``.
            scale: initial scale. When ``None``, a one-element tensor holding
                1 is used.
            zero_point: initial zero point. When ``None``, zeros shaped like
                ``scale`` are used.
        """
        super().__init__(
            mapping_type,
            target_dtype,
            granularity,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
            preserve_zero,
            zero_point_domain,
        )
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and treats a single zero value as "missing".
        if scale is None:
            scale = torch.ones(1)
        if zero_point is None:
            zero_point = torch.zeros_like(scale)
        self.register_buffer("scale", scale.to(dtype=scale_dtype))
        self.register_buffer("zero_point", zero_point.to(dtype=zero_point_dtype))

    def set_qparams(self, scale, zero_point=None):
        """
        Replace the stored quantization parameters.

        Args:
            scale: new scale tensor; cast to ``self.scale_dtype``.
            zero_point: new zero-point tensor; cast to
                ``self.zero_point_dtype``. When ``None``, zeros shaped like
                ``scale`` are used.
        """
        # Same reasoning as in __init__: an explicit (possibly multi-element
        # or all-zero) zero_point must not be truth-tested.
        if zero_point is None:
            zero_point = torch.zeros_like(scale)
        self.scale = scale.to(dtype=self.scale_dtype)
        self.zero_point = zero_point.to(dtype=self.zero_point_dtype)

    def forward(self, input):
        """Return ``input`` unchanged; nothing is recorded from it."""
        return input

    def calculate_qparams(self):
        """Return the stored ``(scale, zero_point)`` pair."""
        return self.scale, self.zero_point
247+
248+
249+
class AffineQuantizedMSEObserver(AffineQuantizedObserverBase):
    """
    Minimize quantization loss caused by outliers via linear search. More details can be found at https://arxiv.org/pdf/2209.13325

    On ``forward`` the observer tries ``steps`` symmetric clipping ranges
    ``[-t, t]`` and keeps, per reduced element, the range whose fake-quantized
    output has the smallest mean squared error against the input.
    """

    def __init__(
        self,
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        granularity: Granularity,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
        steps: int = 100,
        run_once: bool = False,
    ):
        """
        Args:
            mapping_type, target_dtype, granularity, quant_min, quant_max, eps,
            scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain:
                forwarded unchanged to ``AffineQuantizedObserverBase.__init__``.
            steps: number of candidate clipping thresholds tried by
                ``line_search``.
            run_once: when true, only the first non-empty ``forward`` call
                runs the search; later calls keep the stored range.
        """
        super().__init__(
            mapping_type,
            target_dtype,
            granularity,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
            preserve_zero,
            zero_point_domain,
        )
        self.steps = steps
        self.calibrated = False
        self.run_once = run_once

    def mse(self, pred, expect, block_size):
        """
        Return the squared error between ``pred`` and ``expect``, averaged
        over the reduction dims that ``_get_reduction_params`` derives from
        ``block_size``.
        """
        loss = (pred - expect).abs().pow(2)
        shape_for_reduction, reduction_dims = _get_reduction_params(
            block_size, loss.size()
        )
        loss = loss.view(shape_for_reduction)
        return torch.mean(loss, dim=reduction_dims, keepdim=False)

    def loss_fn(self, x, new_min, new_max):
        """
        Return the MSE between ``x`` and its fake-quantized version when the
        quantization parameters are derived from ``new_min`` / ``new_max``.
        """
        block_size = get_block_size(x.shape, self.granularity)
        scale, zero_point = choose_qparams_affine_with_min_max(
            new_min,
            new_max,
            self.mapping_type,
            [],
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain,
        )
        x_q = _fake_quantize_affine(
            x,
            block_size,
            scale,
            zero_point,
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.zero_point_domain,
        )
        return self.mse(x_q, x, block_size)

    def line_search(self, input):
        """
        Search for the symmetric clipping range with the smallest loss.

        Returns:
            ``(min_val, max_val)`` tensors. They start as the per-block
            ``amin`` / ``amax`` of ``input`` and are overwritten with
            ``-thres`` / ``thres`` wherever a candidate threshold lowers the
            loss.

        Note:
            For an empty ``input`` this returns ``input`` itself rather than
            a pair, so callers must check for emptiness first (``forward``
            does).
        """
        if input.numel() == 0:
            return input

        input_detached = input.detach()
        assert self.granularity is not None, "granularity is None"
        block_size = get_block_size(input_detached.shape, self.granularity)

        shape_for_reduction, reduction_dims = _get_reduction_params(
            block_size, input_detached.size()
        )
        input_detached = input_detached.view(shape_for_reduction)
        min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)

        # Largest magnitude per block; candidate thresholds are fractions of it.
        range_val = torch.max(min_val.abs(), max_val)
        # Large sentinel so the first finite candidate loss always wins.
        optimal_loss = torch.zeros_like(min_val) + 1e9

        # check which clip range could produce smallest loss
        for step in range(1, self.steps + 1):
            thres = range_val / self.steps * step
            current_loss = self.loss_fn(input, -thres, thres)
            # Compute the improvement mask once and reuse it for both bounds.
            improved = current_loss < optimal_loss
            min_val = torch.where(improved, -thres, min_val)
            max_val = torch.where(improved, thres, max_val)
            optimal_loss = torch.min(current_loss, optimal_loss)

        return min_val, max_val

    def forward(self, input):
        """
        Update the stored ``min_val`` / ``max_val`` from ``input`` and return
        ``input`` unchanged.

        Empty inputs are passed through without calibrating: ``line_search``
        has no ``(min_val, max_val)`` pair to return for them, so unpacking
        its result would fail or store meaningless values.
        """
        if input.numel() == 0:
            return input

        if not (self.run_once and self.calibrated):
            self.min_val, self.max_val = self.line_search(input)
            self.calibrated = True

        return input

    def calculate_qparams(self):
        """
        Return the result of ``choose_qparams_affine_with_min_max`` for the
        stored ``min_val`` / ``max_val``.

        Raises:
            AssertionError: if the observer has not stored a range yet.
        """
        assert hasattr(self, "min_val") and hasattr(self, "max_val"), (
            "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
        )
        return choose_qparams_affine_with_min_max(
            self.min_val,
            self.max_val,
            self.mapping_type,
            [],
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain,
        )
374+
375+
196376
if TORCH_VERSION_AT_LEAST_2_5:
197377
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
198378
torch.serialization.add_safe_globals([PerRow, PerTensor])

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ def _do_fake_quantize_affine(
11721172
elif zero_point_domain == ZeroPointDomain.FLOAT:
11731173
_quantize_affine = _quantize_affine_tinygemm_no_dtype_cast
11741174
_dequantize_affine = _dequantize_affine_tinygemm_no_dtype_check
1175-
elif ZeroPointDomain == ZeroPointDomain.NONE:
1175+
elif zero_point_domain == ZeroPointDomain.NONE:
11761176
_quantize_affine = _quantize_affine_no_zero_point_no_dtype_cast
11771177
_dequantize_affine = _dequantize_affine_no_zero_point_no_dtype_check
11781178
else:

0 commit comments

Comments
 (0)