Commit a66f34d

Fix wrong scale eps applied

1 parent 2a3fbff commit a66f34d

2 files changed: +37 -1 lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 9 additions & 1 deletion
@@ -21,6 +21,9 @@
     quantize_affine,
     quantize_affine_floatx,
 )
+from torchao.quantization.utils import (
+    calculate_scale_eps_for_dtype,
+)
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
@@ -350,14 +353,19 @@ def from_hp_to_floatx(
     ):
         """Convert a high precision tensor to a float8 quantized tensor."""
         if target_dtype in FP8_TYPES:
+            eps = (
+                calculate_scale_eps_for_dtype(input_float.dtype)
+                if torch.is_floating_point(input_float)
+                else torch.finfo(torch.float32).eps
+            )
             return cls.from_hp_to_intx(
                 input_float=input_float,
                 mapping_type=MappingType.SYMMETRIC,
                 block_size=block_size,
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                eps=torch.finfo(torch.float32).eps,
+                eps=eps,
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
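
Note (explanatory, not part of the commit): the hard-coded float32 eps appears to be the "wrong scale eps" the commit title refers to. For a float16 input, a scale clamped to torch.finfo(torch.float32).eps overflows to inf when inverted, since 1/eps (about 8.4e+06) exceeds the float16 maximum of 65504. A minimal sketch of that failure mode, assuming a float16 input tensor:

import torch

# float32 eps (~1.19e-07) is still representable in float16 (as a subnormal),
# but its reciprocal (~8.4e+06) exceeds the float16 max (65504) and becomes inf.
scale = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float16)
print(scale.reciprocal())  # tensor(inf, dtype=torch.float16)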

torchao/quantization/utils.py

Lines changed: 28 additions & 0 deletions
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import functools
 import importlib.util
 from typing import Dict, List, Optional
 
@@ -41,6 +42,7 @@
     "per_token_dynamic_quant",
     "get_group_qparams_symmetric",
     "recommended_inductor_config_setter",
+    "calculate_scale_eps_for_dtype",
 ]
 
 _lm_eval_available = importlib.util.find_spec("lm_eval") is not None
@@ -587,3 +589,29 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
+
+
+@functools.lru_cache
+def calculate_scale_eps_for_dtype(dtype: torch.dtype):
+    assert torch.is_floating_point(torch.empty(0, dtype=dtype))
+
+    def predecessor(x: torch.Tensor):
+        assert x.numel() == 1
+
+        dtype = x.dtype
+        if dtype == torch.float16:
+            zero = torch.tensor(0, dtype=dtype)
+        else:
+            zero = torch.tensor(0.0, dtype=dtype)
+        return torch.nextafter(x, zero)
+
+    x = torch.tensor(torch.finfo(dtype).max, dtype=dtype)
+    x_rec = 1.0 / x
+    while True:
+        if torch.any(torch.isinf(x_rec.reciprocal())).item():
+            x = predecessor(x)
+            x_rec = 1.0 / x
+        else:
+            break
+
+    return x_rec.item()
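
Note (illustrative usage, not from the commit): calculate_scale_eps_for_dtype starts from 1 / finfo(dtype).max and nudges the candidate eps upward (by stepping x toward zero with nextafter) until 1/eps no longer overflows in that dtype, caching the result per dtype via functools.lru_cache. A hedged usage sketch:

import torch
from torchao.quantization.utils import calculate_scale_eps_for_dtype

# Smallest eps in the given dtype whose reciprocal is still finite, so a scale
# clamped to it can be safely inverted during quantization.
eps = calculate_scale_eps_for_dtype(torch.float16)
assert torch.isfinite(torch.tensor(eps, dtype=torch.float16).reciprocal())
print(eps)  # roughly 1.5e-05 for float16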
