
Commit a6327c7

[NVFP4] Enable Fp4 Quantization; introduce / apply global_scales (#315)
* add nvfp4 args
* format
* dont use a dataclass
* remove dataclass
* update forward pass
* update helpers
* update docstring
* add init functionality
* clean-up
* add docstring
* add import
* test skeletons
* fix args
* fix condition
* add tests
* update
* fix compress_weight
* add docstring
* fix
* use approx
* fix illegal device access
* update docstring; fix typos
* use helper
* update condition
* remove TODO
1 parent b8a443a commit a6327c7

10 files changed: +318 additions, -35 deletions

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,7 @@ def compress(
             scale = model_state.get(prefix + "weight_scale", None)
             g_idx = model_state.get(prefix + "weight_g_idx", None)
             zp = model_state.get(prefix + "weight_zero_point", None)
+            global_scale = model_state.get(prefix + "weight_global_scale", None)

             # is scale does not exist, then weight cannot be compressed
             if scale is None:
@@ -112,6 +113,7 @@ def compress(
                     weight=value,
                     scale=scale,
                     zero_point=zp,
+                    global_scale=global_scale,
                     g_idx=g_idx,
                     quantization_args=quant_args,
                     device="cpu",

src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py

Lines changed: 6 additions & 0 deletions
@@ -78,6 +78,7 @@ def compress_weight(
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
+        global_scale: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
         """
         Compresses a single uncompressed weight
@@ -90,6 +91,11 @@ def compress_weight(
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
+        if global_scale is not None:
+            raise ValueError(
+                "global_scale is not supported for the NaiveQuantizationCompressor"
+            )
+
         if can_quantize(weight, quantization_args):
             quantized_weight = quantize(
                 x=weight,

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 6 additions & 0 deletions
@@ -94,6 +94,7 @@ def compress_weight(
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
+        global_scale: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
         """
         Compresses a single uncompressed weight
@@ -106,6 +107,11 @@ def compress_weight(
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
+        if global_scale is not None:
+            raise ValueError(
+                "global_scale is not supported for the PackQuantizationCompressor"
+            )
+
         compressed_dict = {}
         if can_quantize(weight, quantization_args):
             quantized_weight = quantize(

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 10 additions & 1 deletion
@@ -27,8 +27,14 @@
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
+    update_fused_layer_weight_global_scales,
+)
+from compressed_tensors.quantization.quant_args import (
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
+    QuantizationArgs,
+    QuantizationType,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -266,6 +272,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         )
     )

+    if status == QuantizationStatus.INITIALIZED:
+        update_fused_layer_weight_global_scales(model)
+
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)

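
The INITIALIZED hook above calls update_fused_layer_weight_global_scales, whose body is not part of this diff. As a rough sketch of the idea only: layers that get fused at runtime (for example q/k/v projections) should agree on a single weight global scale, and one conservative way to reconcile them is to take the minimum of the individual values. The helper name and the min-reduction below are assumptions for illustration, not the library's implementation.

from typing import List

import torch

# Hypothetical helper: pick one shared global scale for a set of fused layers by
# taking the most conservative (smallest) value. Illustration only.
def share_global_scale(global_scales: List[torch.Tensor]) -> torch.Tensor:
    return torch.stack([gs.reshape(()) for gs in global_scales]).min()

q_scale = torch.tensor(410.0)
k_scale = torch.tensor(380.0)
v_scale = torch.tensor(395.0)
print(share_global_scale([q_scale, k_scale, v_scale]))  # tensor(380.)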

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 50 additions & 14 deletions
@@ -20,6 +20,7 @@
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -49,6 +50,7 @@ def quantize(
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -63,6 +65,7 @@ def quantize(
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
     :param g_idx: optional mapping from column index to group index
+    :param global_scale: optional constant to scale the quantization scale during QDQ
     :return: fake quantized tensor
     """

@@ -75,6 +78,7 @@ def quantize(
         do_quantize=True,
         do_dequantize=False,
         g_idx=g_idx,
+        global_scale=global_scale,
     )


@@ -86,6 +90,7 @@ def dequantize(
     args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -97,6 +102,7 @@ def dequantize(
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
     :param g_idx: optional mapping from column index to group index
+    :param global_scale: optional constant to scale the quantization scale during QDQ
     :return: dequantized float tensor
     """
     if args is None:
@@ -128,6 +134,7 @@ def dequantize(
         do_dequantize=True,
         dtype=dtype,
         g_idx=g_idx,
+        global_scale=global_scale,
     )


@@ -138,6 +145,7 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -151,6 +159,7 @@ def fake_quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param g_idx: optional mapping from column index to group index
+    :param global_scale: optional constant to scale the quantization scale during QDQ
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -161,6 +170,7 @@ def fake_quantize(
         do_quantize=True,
         do_dequantize=True,
         g_idx=g_idx,
+        global_scale=global_scale,
     )


@@ -174,6 +184,7 @@ def _process_quantization(
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
@@ -221,35 +232,44 @@ def _process_quantization(
             end = start + group_count
             if do_quantize:
                 output[:, start:end] = _quantize(
-                    x[:, start:end],
-                    sc,
-                    zp,
-                    q_min,
-                    q_max,
-                    args,
+                    x=x[:, start:end],
+                    scale=sc,
+                    zero_point=zp,
+                    q_min=q_min,
+                    q_max=q_max,
+                    args=args,
                     dtype=dtype,
+                    global_scale=global_scale,
                 )

             if do_dequantize:
                 input = output[:, start:end] if do_quantize else x[:, start:end]
-                output[:, start:end] = _dequantize(input, sc, zp)
+                output[:, start:end] = _dequantize(
+                    x_q=input, scale=sc, zero_point=zp, global_scale=global_scale
+                )

         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)

     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
-                x,
-                scale,
-                zero_point,
-                q_min,
-                q_max,
-                args,
+                x=x,
+                scale=scale,
+                zero_point=zero_point,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
                 dtype=dtype,
+                global_scale=global_scale,
             )
         if do_dequantize:
-            output = _dequantize(output if do_quantize else x, scale, zero_point)
+            output = _dequantize(
+                output if do_quantize else x,
+                scale=scale,
+                zero_point=zero_point,
+                global_scale=global_scale,
+            )

     return output

@@ -330,6 +350,7 @@ def forward_quantize(
         return value

     g_idx = getattr(module, "weight_g_idx", None)
+    global_scale = getattr(module, f"{base_name}_global_scale", None)

     if args.dynamic:
         # dynamic quantization - determine the scale/zp on the fly
@@ -345,6 +366,7 @@ def forward_quantize(
         zero_point=zero_point,
         args=args,
         g_idx=g_idx,
+        global_scale=global_scale,
     )


@@ -357,11 +379,18 @@ def _quantize(
     q_max: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

+    # if a global scale is optionally provided, use it
+    # to further scale the local `scale` parameter
+    if global_scale:
+        scale = scale.to(global_scale.dtype) / global_scale
+
     scaled = x / scale
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
+
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -381,7 +410,14 @@ def _dequantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+
+    # if a global scale is optionally provided, use it
+    # to further scale the local `scale` parameter
+    if global_scale:
+        scale = scale.to(global_scale.dtype) / global_scale
+
     dequant_value = x_q.to(scale.dtype)

     if zero_point is not None:
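
To make the new _quantize/_dequantize behavior concrete, here is a self-contained sketch (a re-implementation for illustration, not the library functions): when a global_scale is provided, the stored local scale is first folded as scale / global_scale, and both halves of the quantize/dequantize round trip use that folded value. The uniform integer grid clamped to +/-6 is a simplification of the real FP4 E2M1 value set, and the example numbers are assumptions for the demo.

from typing import Optional

import torch

def qdq(x: torch.Tensor, scale: torch.Tensor,
        global_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Fake quantize/dequantize with an optional global scale (illustrative only)."""
    if global_scale is not None:
        # same folding as the diff: scale = scale.to(global_scale.dtype) / global_scale
        scale = scale.to(global_scale.dtype) / global_scale
    # integer grid clamped to the FP4 E2M1 max of +/-6 (a simplification;
    # real E2M1 has a non-uniform value grid)
    q = torch.clamp(torch.round(x / scale), -6, 6)
    return q * scale

x = torch.tensor([0.05, -0.12, 0.30])
global_scale = torch.tensor(448.0)               # e.g. chosen so local scales fit an FP8 range
scale = torch.tensor(0.30 / 6.0) * global_scale  # stored local scale, pre-multiplied by the global factor
print(qdq(x, scale, global_scale))               # ~tensor([0.0500, -0.1000, 0.3000])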
