Commit e918d1e

[NVFP4] Update FloatArgs and NVFP4 (#313)
* add nvfp4 args
* format
* update args
* dont use a dataclass
* remove dataclass

1 parent 5c6fd5d commit e918d1e

File tree: 2 files changed (+66 -2 lines)

src/compressed_tensors/quantization/quant_args.py

Lines changed: 54 additions & 2 deletions
@@ -24,13 +24,57 @@
 
 __all__ = [
     "FP8_DTYPE",
+    "FP8_E4M3_DATA",
+    "FP4_E2M1_DATA",
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
 ]
 
+
+class FloatArgs:
+    exponent: int
+    mantissa: int
+    bits: int
+    max: float
+    min: float
+    dtype: Optional[torch.dtype] = None
+
+
+class FP4_E2M1_DATA(FloatArgs):
+    exponent = 2
+    mantissa = 1
+    bits = 4
+    max = 6.0
+    min = -6.0
+
+    @staticmethod
+    def cast_to_fp4(x):
+        sign = torch.sign(x)
+        x = torch.abs(x)
+        x[(x >= 0.0) & (x <= 0.25)] = 0.0
+        x[(x > 0.25) & (x < 0.75)] = 0.5
+        x[(x >= 0.75) & (x <= 1.25)] = 1.0
+        x[(x > 1.25) & (x < 1.75)] = 1.5
+        x[(x >= 1.75) & (x <= 2.5)] = 2.0
+        x[(x > 2.5) & (x < 3.5)] = 3.0
+        x[(x >= 3.5) & (x <= 5.0)] = 4.0
+        x[x > 5.0] = 6.0
+        return x * sign
+
+
+class FP8_E4M3_DATA(FloatArgs):
+    exponent = 4
+    mantissa = 3
+    bits = 8
+    max = torch.finfo(torch.float8_e4m3fn).max
+    min = torch.finfo(torch.float8_e4m3fn).min
+    dtype = torch.float8_e4m3fn
+
+
+# TODO: Remove soon in favour of a more descriptive FloatArgs
 FP8_DTYPE = torch.float8_e4m3fn
 
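E2M1 can only represent the magnitudes 0, 0.5, 1, 1.5, 2, 3, 4 and 6, and cast_to_fp4 snaps each element to the nearest of these while keeping its sign (torch.abs produces a copy, so the caller's tensor is left untouched). A quick illustrative sketch, not part of this commit, with the expected output worked out by hand from the thresholds above:

import torch

from compressed_tensors.quantization.quant_args import FP4_E2M1_DATA

# values chosen to fall into different rounding buckets of cast_to_fp4
x = torch.tensor([-7.2, -0.3, 0.2, 0.6, 1.3, 2.6, 4.9, 10.0])

print(FP4_E2M1_DATA.cast_to_fp4(x))
# expected: [-6.0, -0.5, 0.0, 0.5, 1.5, 3.0, 4.0, 6.0]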

@@ -234,7 +278,10 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
 
     def pytorch_dtype(self) -> torch.dtype:
         if self.type == QuantizationType.FLOAT:
-            return FP8_DTYPE
+            if self.num_bits == 8:
+                return FP8_E4M3_DATA.dtype
+            else:
+                raise NotImplementedError("Only num_bits in (8) are supported")
         elif self.type == QuantizationType.INT:
             if self.num_bits <= 8:
                 return torch.int8
@@ -263,7 +310,12 @@ def round_to_quantized_type(
     """
     original_dtype = tensor.dtype
    if args.type == QuantizationType.FLOAT:
-        rounded = tensor.to(FP8_DTYPE)
+        if args.num_bits == 8:
+            rounded = tensor.to(FP8_E4M3_DATA.dtype)
+        elif args.num_bits == 4:
+            rounded = FP4_E2M1_DATA.cast_to_fp4(tensor)
+        else:
+            raise NotImplementedError("Only num_bits in (4, 8) are supported")
     elif args.type == QuantizationType.INT:
         rounded = torch.round(tensor)
     else:
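Together, these two hunks make the float path width-aware: pytorch_dtype keeps mapping 8-bit float args to torch.float8_e4m3fn (now via FP8_E4M3_DATA.dtype), and round_to_quantized_type additionally routes 4-bit float args through cast_to_fp4, returning the result in the tensor's original dtype. A small usage sketch, not part of this commit, assuming QuantizationArgs can be built with just num_bits and type (the remaining fields taking their defaults):

import torch

from compressed_tensors.quantization.quant_args import (
    FP8_E4M3_DATA,
    QuantizationArgs,
    QuantizationType,
    round_to_quantized_type,
)

# 8-bit float args resolve to a native torch dtype
fp8_args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT)
assert fp8_args.pytorch_dtype() == FP8_E4M3_DATA.dtype  # torch.float8_e4m3fn

# 4-bit float args have no native torch dtype; rounding snaps to the E2M1 grid instead
fp4_args = QuantizationArgs(num_bits=4, type=QuantizationType.FLOAT)
weights = torch.tensor([0.9, -2.7, 4.2, 7.5])
print(round_to_quantized_type(weights, fp4_args))
# expected: tensor([ 1., -3.,  4.,  6.])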

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 12 additions & 0 deletions
@@ -100,6 +100,17 @@ def is_preset_scheme(name: str) -> bool:
 
 UNQUANTIZED = dict()
 
+NVFP4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    )
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -225,4 +236,5 @@ def is_preset_scheme(name: str) -> bool:
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "NVFP4A16": NVFP4A16,
 }
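The preset is both defined as a module-level dict and registered in the preset table, so it can be looked up by name like the existing FP8 schemes. A brief sketch, not part of this commit:

from compressed_tensors.quantization import quant_scheme

# the new preset resolves by name alongside FP8 and FP8_DYNAMIC
assert quant_scheme.is_preset_scheme("NVFP4A16")

# 4-bit float weights, symmetric, static, quantized in groups of 16
args = quant_scheme.NVFP4A16["weights"]
assert args.num_bits == 4 and args.group_size == 16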
