Commit d49830d
update NVFP4 data type; add scheme
1 parent 36204f0 · commit d49830d

File tree: 4 files changed, +22 -17 lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 4 additions & 6 deletions

@@ -22,7 +22,7 @@
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_NVFP4_DATA,
+    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
@@ -167,19 +167,17 @@ def _initialize_scale_zero_point(
 
     # NVFP4 support; use FP8 scales
     # For weight quant, attach global scales for NVFP4
-    # TODO: How do we know if we need a global scale?
     # TODO: NVFP4 Scheme
     if (
-        base_name == "weight"
-        and quantization_args.num_bits == 4
+        quantization_args.num_bits == 4
         and quantization_args.type == QuantizationType.FLOAT
     ):
         scale_dtype = FP8_E4M3_DATA.dtype
         # create and attach nvfp4 data
         tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
         # Setting data for now - could possibly be handled later in the pipeline
-        value = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax
-        # use the weight dtype (bfloat) maybe use float32 to start?
+        value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
+        # TODO: use model.weight.dtype
         value = value.to(torch.float32).to(device)
         # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
         init_global_scale = Parameter(value, requires_grad=False)
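For reference, the global-scale arithmetic in this hunk reduces to a short, self-contained sketch. The constants 448.0 (FP8_E4M3_DATA.max) and 6.0 (FP4_E2M1_DATA.max) come from the diff; the helper name compute_nvfp4_global_scale and the standalone constants are illustrative, not names from the library.

    import torch

    FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max
    FP4_E2M1_MAX = 6.0    # largest representable E2M1 magnitude

    def compute_nvfp4_global_scale(weight: torch.Tensor) -> torch.Tensor:
        """Sketch of the per-tensor global scale attached in the hunk above."""
        tensor_amax = weight.abs().max().to(torch.float32)
        # Chosen so that the product of the per-group FP8 scale and this global
        # scale maps the largest weight magnitude onto the FP4 E2M1 range.
        return (FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax).to(torch.float32)

    # hypothetical usage
    w = torch.randn(128, 256, dtype=torch.bfloat16)
    global_scale = compute_nvfp4_global_scale(w)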

src/compressed_tensors/quantization/quant_args.py

Lines changed: 2 additions & 3 deletions

@@ -26,7 +26,7 @@
 __all__ = [
     "FP8_DTYPE",
     "FP8_E4M3_DATA",
-    "FP4_NVFP4_DATA",
+    "FP4_E2M1_DATA",
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
@@ -56,8 +56,7 @@ class FloatArgs:
     min=torch.finfo(torch.float8_e4m3fn).min,
     dtype=torch.float8_e4m3fn,
 )
-# Don't call NVFP4; should be based on exponent and mantissa
-FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)
+FP4_E2M1_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)
 
 
 class QuantizationType(str, Enum):
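The rename reflects the encoding rather than the product name: E2M1 is a 4-bit float with 1 sign bit, 2 exponent bits, and 1 mantissa bit, so its largest magnitude is 6.0, matching max=6.0/min=-6.0 above. A minimal sketch enumerating the representable positive values, assuming the standard E2M1 interpretation with exponent bias 1 (not code from the repository):

    # Positive E2M1 values: sign(1) | exponent(2, bias 1) | mantissa(1)
    values = []
    for exp_bits in range(4):          # exponent field 0..3
        for man_bit in range(2):       # mantissa field 0 or 1
            if exp_bits == 0:          # subnormals: 0.0 and 0.5
                values.append(0.5 * man_bit)
            else:
                values.append((1 + 0.5 * man_bit) * 2 ** (exp_bits - 1))
    print(sorted(values))  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]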

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 12 additions & 0 deletions

@@ -100,6 +100,17 @@ def is_preset_scheme(name: str) -> bool:
 
 UNQUANTIZED = dict()
 
+NVFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    )
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -212,4 +223,5 @@ def is_preset_scheme(name: str) -> bool:
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "NVFP4": NVFP4,
 }
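With the preset registered, a config can refer to the scheme by name. A rough usage sketch, assuming the module-level NVFP4 dict and is_preset_scheme shown in this diff are importable as written (the exact preset-resolution helper is not shown in the diff, so the checks below read the preset dict directly):

    from compressed_tensors.quantization.quant_scheme import NVFP4, is_preset_scheme
    from compressed_tensors.quantization.quant_args import QuantizationStrategy

    assert is_preset_scheme("NVFP4")

    weight_args = NVFP4["weights"]
    # 4-bit float weights, grouped so that every 16 values share one scale
    assert weight_args.num_bits == 4
    assert weight_args.strategy == QuantizationStrategy.GROUP
    assert weight_args.group_size == 16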

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 4 additions & 8 deletions

@@ -17,7 +17,7 @@
 
 import torch
 from compressed_tensors.quantization.quant_args import (
-    FP4_NVFP4_DATA,
+    FP4_E2M1_DATA,
    FP8_E4M3_DATA,
     QuantizationArgs,
     QuantizationStrategy,
@@ -85,11 +85,7 @@ def calculate_qparams(
         and quantization_args.type == QuantizationType.FLOAT
     ):
         assert global_scale is not None
-        # TODO: how do we pass in the global scale?
-        # An observer is attached per module, we can conditionally pass in
-        # the global scale --> TODO: check for presence of the global when updating the scale
-        # TODO: maybe remove FP8 scale cast
-        scale = max_val_pos / FP4_NVFP4_DATA.max
+        scale = max_val_pos / FP4_E2M1_DATA.max  # Not needed
         scale = scale / global_scale
         scale = scale.to(FP8_E4M3_DATA.dtype)  # .to(torch.float32)
 
@@ -166,8 +162,8 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
         else:
             # nvfp4 ranges
             assert quantization_args.num_bits == 4
-            q_max = torch.tensor(FP4_NVFP4_DATA.max, device=device)
-            q_min = torch.tensor(FP4_NVFP4_DATA.min, device=device)
+            q_max = torch.tensor(FP4_E2M1_DATA.max, device=device)
+            q_min = torch.tensor(FP4_E2M1_DATA.min, device=device)
     else:
         raise ValueError(f"Invalid quantization type {quantization_args.type}")
 
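The calculate_qparams hunk implements the two-level NVFP4 scaling: a per-group scale is derived from the group's observed maximum, divided by the tensor-wide global scale, and stored in FP8 E4M3. A standalone sketch of that arithmetic with hypothetical inputs (the real function also handles zero points and other strategies, omitted here; nvfp4_group_scale is an illustrative name, not a library function):

    import torch

    FP4_E2M1_MAX = 6.0
    FP8_E4M3_DTYPE = torch.float8_e4m3fn  # requires a PyTorch build with float8 support

    def nvfp4_group_scale(max_val_pos: torch.Tensor,
                          global_scale: torch.Tensor) -> torch.Tensor:
        """Per-group scale as in the calculate_qparams hunk above (sketch)."""
        scale = max_val_pos / FP4_E2M1_MAX   # map each group's max onto the E2M1 range
        scale = scale / global_scale         # fold out the tensor-wide global scale
        return scale.to(FP8_E4M3_DTYPE)      # per-group scales are stored in FP8 E4M3

    # hypothetical usage: one scale per group of 16 weights in a [128, 256] tensor
    max_val_pos = torch.rand(128, 16)                # stand-in for observed |w| group maxima
    global_scale = torch.tensor(448.0 * 6.0 / 4.2)   # per the initialize.py hunk, with amax = 4.2
    scales = nvfp4_group_scale(max_val_pos, global_scale)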
