From b0b690f1b1497a98d203ac73fbb9cf1da86e8644 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Thu, 27 Feb 2025 22:57:26 +0100 Subject: [PATCH] Fix wrong scale eps applied --- test/quantization/test_quant_primitives.py | 60 ++++++++++++++++++++++ torchao/quantization/quant_primitives.py | 19 ++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 861ebe5e94..974f554000 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -961,6 +961,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): torch.testing.assert_close(expected_quantized, quantized) torch.testing.assert_close(expected_dequantized, dequantized) + @parameterized.expand( + [ + torch.float64, + torch.float32, + torch.bfloat16, + torch.float16, + ] + ) + def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype): + # Fixed by #1770, the test will fail for all the variants + # before that fix, and will pass afterwards. + # + # The scale value must be forcefully clamped, within + # _choose_qparams_affine() function, (that + # choose_qparams_affine() and others call into) to a large + # enough number so that its reciprocal does not become Inf. + # Otherwise during the quantization, by multiplying with scale + # reciprocal, all the values will be quantized to Inf value, + # except for zero value that would produce NaN (0*Inf) as + # quantized value. + # + # The minimal normalized value for given floating point data + # type is given by torch.finfo(hp_dtype).tiny - let's call + # this value "tiny". It could be seen by checking, that for + # all of torch.float64, torch.float32, torch.float16 and + # torch.bfloat16, denormalized number that is equal to tiny/4 + # will produce Inf as its reciprocal.
+ # + # Thus, to reproduce the problem, one would create a tensor + # with such values that their absolute maximum, after being + # divided by the range of quantized data (that is 57344 for + # torch.float8_e5m2), would produce scale smaller than tiny/4. + # Also, eps parameter should be set to value no greater than + # tiny/4, as scale is clamped from below to that value. With + # such inputs, choose_qparams_affine() will produce a + # scale whose reciprocal is Inf. + # + # Note that this may seem like a contrived reproducer. However, + # there are cases with existing code that would pass + # torch.finfo(torch.float32).eps as eps value, regardless of + # scale_dtype. The float16 has rather small range, so this + # value is well below torch.finfo(torch.float32).eps, and for + # such eps value, the code below would produce Inf scale reciprocal even + # for float16 tensor that has 0.5 as its maximum value. + float8_dtype = torch.float8_e5m2 + tiny = torch.finfo(hp_dtype).tiny + x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype) + scale, _ = choose_qparams_affine( + input=x, + mapping_type=MappingType.SYMMETRIC, + block_size=[1, 2], + target_dtype=float8_dtype, + eps=tiny / 4, + scale_dtype=hp_dtype, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.NONE, + ) + scale_reciprocal = scale.reciprocal() + assert not torch.any(torch.isinf(scale_reciprocal)).item() + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index d13ac330a0..8d68fa084b 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -862,6 +862,7 @@ def _choose_qparams_affine( 3.
calculate quantization parameters based on min_val/max_val based on args like `preserve_zero` and `zero_point_domain` """ + quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) assert mapping_type in [ MappingType.SYMMETRIC.name, @@ -909,6 +910,16 @@ def _choose_qparams_affine( min_val_neg = min_val max_val_pos = max_val + # Prevent the reciprocal of scale, calculated below, from becoming Inf. + if torch.is_floating_point(max_val): + # In this case, scale will be calculated below in + # max_val.dtype. + eps = max(eps, torch.finfo(max_val.dtype).tiny) + else: + # In this case, scale will be calculated below in + # torch.float32 dtype. + eps = max(eps, torch.finfo(torch.float32).tiny) + if ( mapping_type == MappingType.SYMMETRIC.name or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name @@ -969,7 +980,13 @@ def _choose_qparams_affine( if zero_point is not None: zero_point = zero_point.to(dtype=zero_point_dtype) - return scale.to(dtype=scale_dtype), zero_point + scale = scale.to(dtype=scale_dtype) + if torch.is_floating_point(scale): + # Again, prevent the scale reciprocal from becoming Inf. + scale = scale.clamp( + min=torch.finfo(scale_dtype).tiny, max=torch.finfo(scale_dtype).max + ) + return scale, zero_point def choose_qparams_and_quantize_affine_qqq(