
Commit 9254821

update datatype/look-up table
1 parent d49830d

2 files changed: +21, -8 lines


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     g_idx: Optional[torch.Tensor] = None,
-    global_scale: Optiona[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
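The change above is a one-character typo fix (Optiona -> Optional) in fake_quantize's signature; behaviour is unchanged. As the docstring says, fake quantization quantizes and immediately dequantizes, so the tensor never leaves its original dtype. A minimal sketch of that pattern for FP4, assuming a per-tensor scale and reusing the cast_to_fp4 lookup introduced in the diff below (the helper name here is hypothetical, not the library's API):

import torch

def fake_quantize_fp4_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch, not the library's exact code path:
    # scale into the E2M1 range, snap to the representable grid, scale back.
    q = FP4_E2M1_DATA.cast_to_fp4(x / scale)  # grid: {0, +-0.5, +-1, +-1.5, +-2, +-3, +-4, +-6}
    return q * scale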

src/compressed_tensors/quantization/quant_args.py

Lines changed: 20 additions & 7 deletions
@@ -45,6 +45,22 @@ class FloatArgs:
     dtype: Optional[torch.dtype] = None
 
 
+@dataclass
+class FloatArgsFP4E2M1(FloatArgs):
+    def cast_to_fp4(self, x):
+        sign = torch.sign(x)
+        x = torch.abs(x)
+        x[(x >= 0.0) & (x <= 0.25)] = 0.0
+        x[(x > 0.25) & (x < 0.75)] = 0.5
+        x[(x >= 0.75) & (x <= 1.25)] = 1.0
+        x[(x > 1.25) & (x < 1.75)] = 1.5
+        x[(x >= 1.75) & (x <= 2.5)] = 2.0
+        x[(x > 2.5) & (x < 3.5)] = 3.0
+        x[(x >= 3.5) & (x <= 5.0)] = 4.0
+        x[x > 5.0] = 6.0
+        return x * sign
+
+
 # TODO: Remove soon in favour of a more descriptive FloatArgs
 FP8_DTYPE = torch.float8_e4m3fn

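The bucket boundaries above implement round-to-nearest over the eight non-negative FP4 E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}: each threshold sits at the midpoint between two adjacent representable values, ties resolve to the value whose mantissa bit is 0 (round-half-to-even), and anything above 5.0 clamps to the max magnitude 6.0. A quick check of that behaviour (assuming FP4_E2M1_DATA is importable from this module; note cast_to_fp4 writes into a fresh torch.abs copy, so the input tensor is not mutated):

import torch

x = torch.tensor([0.25, 0.26, 0.75, 2.5, 3.5, 5.0, 7.0, -1.6])
print(FP4_E2M1_DATA.cast_to_fp4(x))
# tensor([ 0.0000, 0.5000, 1.0000, 2.0000, 4.0000, 4.0000, 6.0000, -1.5000])
# 0.25 and 2.5 are midpoints and round down to the even-mantissa value;
# 0.75 and 3.5 are midpoints and round up to it; 7.0 clamps to 6.0.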
@@ -56,7 +72,8 @@ class FloatArgs:
     min=torch.finfo(torch.float8_e4m3fn).min,
     dtype=torch.float8_e4m3fn,
 )
-FP4_E2M1_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)
+
+FP4_E2M1_DATA = FloatArgsFP4E2M1(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)
 
 
 class QuantizationType(str, Enum):
@@ -265,9 +282,7 @@ def pytorch_dtype(self) -> torch.dtype:
                 return FP8_E4M3_DATA.dtype
             else:
                 assert self.num_bits == 4
-                # TODO: Use the look-up?
-                # TODO: will return None for now until updated in FloatArgs
-                return FP4_NVFP4_DATA.dtype
+                raise NotImplementedError("Not supported for FP4")
         elif self.type == QuantizationType.INT:
             if self.num_bits <= 8:
                 return torch.int8
@@ -300,9 +315,7 @@ def round_to_quantized_type(
             rounded = tensor.to(FP8_E4M3_DATA.dtype)
         else:
             assert args.num_bits == 4
-            # TODO: Use the FP4_NVFP4_DATA class to use a look-up table
-            # TODO: cast to whatever value we want fp4 to be post quantization/clamping
-            rounded = tensor.to(FP4_NVFP4_DATA.dtype)
+            rounded = FP4_E2M1_DATA.cast_to_fp4(tensor)
     elif args.type == QuantizationType.INT:
         rounded = torch.round(tensor)
     else:

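Since PyTorch has no native 4-bit float dtype, round_to_quantized_type now snaps FP4 values onto the E2M1 grid while leaving the tensor in its original dtype, which is also why pytorch_dtype raises NotImplementedError for the 4-bit case instead of returning a dtype. A sketch of the new path, assuming the (tensor, args) signature from the hunk header and a QuantizationArgs configured for 4-bit float (constructor arguments here are an assumption):

import torch

args = QuantizationArgs(type=QuantizationType.FLOAT, num_bits=4)  # assumed field names, taken from the diff
t = torch.randn(8, dtype=torch.bfloat16)
rounded = round_to_quantized_type(t, args)
assert rounded.dtype == t.dtype  # no FP4 torch dtype: values snap to the grid, dtype is unchanged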