@@ -45,6 +45,22 @@ class FloatArgs:
45
45
dtype : Optional [torch .dtype ] = None
46
46
47
47
48
@dataclass
class FloatArgsFP4E2M1(FloatArgs):
    """FloatArgs variant that can snap tensors onto the FP4 E2M1 value grid."""

    def cast_to_fp4(self, x):
        """
        Round every element of ``x`` to one of the magnitudes
        {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}, keeping the original sign.

        Magnitudes above 5.0 are mapped to 6.0. The result is built from
        ``torch.abs(x)`` (a new tensor), so ``x`` itself is not written to.

        :param x: tensor of values to round
        :return: tensor of rounded magnitudes multiplied by ``torch.sign(x)``
        """
        signs = torch.sign(x)
        magnitude = torch.abs(x)

        # (lower, upper, target): both edges belong to the bin, so a value
        # sitting exactly on an edge is rounded to this target.
        closed_bins = (
            (0.0, 0.25, 0.0),
            (0.75, 1.25, 1.0),
            (1.75, 2.5, 2.0),
            (3.5, 5.0, 4.0),
        )
        # (lower, upper, target): neither edge belongs to the bin; the edges
        # are claimed by the neighbouring closed bins above.
        open_bins = (
            (0.25, 0.75, 0.5),
            (1.25, 1.75, 1.5),
            (2.5, 3.5, 3.0),
        )

        # The bins are disjoint and every target lies inside its own bin, so
        # a value written by one step is never re-matched by a later step.
        for lower, upper, target in closed_bins:
            magnitude[(magnitude >= lower) & (magnitude <= upper)] = target
        for lower, upper, target in open_bins:
            magnitude[(magnitude > lower) & (magnitude < upper)] = target

        # Everything beyond the last closed bin goes to the largest magnitude.
        magnitude[magnitude > 5.0] = 6.0

        return magnitude * signs
62
+
63
+
48
64
# TODO: Remove soon in favour of a more descriptive FloatArgs
49
65
FP8_DTYPE = torch .float8_e4m3fn
50
66
@@ -56,7 +72,8 @@ class FloatArgs:
56
72
min = torch .finfo (torch .float8_e4m3fn ).min ,
57
73
dtype = torch .float8_e4m3fn ,
58
74
)
59
- FP4_E2M1_DATA = FloatArgs (exponent = 2 , mantissa = 1 , bits = 4 , max = 6.0 , min = - 6.0 )
75
+
76
+ FP4_E2M1_DATA = FloatArgsFP4E2M1 (exponent = 2 , mantissa = 1 , bits = 4 , max = 6.0 , min = - 6.0 )
60
77
61
78
62
79
class QuantizationType (str , Enum ):
@@ -265,9 +282,7 @@ def pytorch_dtype(self) -> torch.dtype:
265
282
return FP8_E4M3_DATA .dtype
266
283
else :
267
284
assert self .num_bits == 4
268
- # TODO: Use the look-up?
269
- # TODO: will return None for now until updated in FloatArgs
270
- return FP4_NVFP4_DATA .dtype
285
+ raise NotImplementedError ("Not supported for FP4" )
271
286
elif self .type == QuantizationType .INT :
272
287
if self .num_bits <= 8 :
273
288
return torch .int8
@@ -300,9 +315,7 @@ def round_to_quantized_type(
300
315
rounded = tensor .to (FP8_E4M3_DATA .dtype )
301
316
else :
302
317
assert args .num_bits == 4
303
- # TODO: Use the FP4_NVFP4_DATA class to use a look-up table
304
- # TODO: cast to whatever value we want fp4 to be post quantization/clamping
305
- rounded = tensor .to (FP4_NVFP4_DATA .dtype )
318
+ rounded = FP4_E2M1_DATA .cast_to_fp4 (tensor )
306
319
elif args .type == QuantizationType .INT :
307
320
rounded = torch .round (tensor )
308
321
else :
0 commit comments