
Commit 33a3f9d

[NVFP4] update global scale generation (#339)
* update
* fix conditions
* update
* remove global weight scale generation into llmcompressor
* fix test
1 parent b2b95b7 commit 33a3f9d

File tree

4 files changed: +19 -133 lines changed


src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 1 addition & 10 deletions
@@ -27,14 +27,8 @@
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
-    update_fused_layer_weight_global_scales,
-)
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    QuantizationArgs,
-    QuantizationType,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -272,9 +266,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )
 
-    if status == QuantizationStatus.INITIALIZED:
-        update_fused_layer_weight_global_scales(model)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)
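With the INITIALIZED-status hook gone from apply_quantization_status, fused q/k/v and gate_proj/up_proj global scales are no longer generated here; per the commit message, that step moves to llm-compressor. A minimal sketch (not part of this commit) of how a caller could still assign one shared weight_global_scale to a group of fused Linear layers via the renamed generate_gparam helper; set_shared_global_scale is an illustrative name, and the layers are assumed to already have a weight_global_scale parameter registered by initialization:

import torch

from compressed_tensors.quantization.utils import generate_gparam
from compressed_tensors.utils import update_parameter_data


def set_shared_global_scale(linears):
    # Treat the listed Linear layers as one tensor: take min/max over their
    # concatenated weights and derive a single global scale for all of them.
    fused = torch.cat([m.weight.data for m in linears], dim=0)
    min_val, max_val = torch.aminmax(fused)
    value = generate_gparam(min_val, max_val)
    for m in linears:
        # assumes initialization already registered weight_global_scale
        update_parameter_data(m, value, "weight_global_scale")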

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 7 additions & 113 deletions
@@ -23,26 +23,18 @@
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    generate_global_scale,
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    iter_named_quantizable_modules,
-)
+from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
-    update_parameter_data,
 )
 from torch.nn import Module, Parameter
 
@@ -51,7 +43,6 @@
     "initialize_module_for_quantization",
     "is_attention_module",
     "KVCacheScaleType",
-    "update_fused_layer_weight_global_scales",
 ]
 
 
@@ -162,22 +153,13 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    # 1. Create global_scales for tensor_group
+    # 1. Create global_scales for tensor_group - generates
+    # a per tensor scale
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        # TODO: should move to llmcompressor
-        if base_name == "weight":
-            # When applying weight-only FP4 quantization, generate a global_scale
-            # This scale is applied during runtime to ensure that the generated
-            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-            # which is the expected dtype of NVFP4A16 scales
-            value = generate_global_scale(input_tensor=module.weight)
-            value = value.to(device)
-            init_global_scale = Parameter(value, requires_grad=False)
-        else:
-            init_global_scale = Parameter(
-                torch.empty(1, dtype=torch.float32, device=device),
-                requires_grad=False,
-            )
+        init_global_scale = Parameter(
+            torch.empty(1, dtype=torch.float32, device=device),
+            requires_grad=False,
+        )
         register_offload_parameter(
             module, f"{base_name}_global_scale", init_global_scale
         )
@@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-# TODO: Potentially introduce an argument to turn this off
-# Only relevant for NVFP4A16 currently
-def update_fused_layer_weight_global_scales(model: torch.nn.Module):
-    """
-    When running NVFP4A16 quantization, update the global scale
-    such that q,k,v layers are treated as one tensor with the same
-    global_scale and gate_proj/up_proj layers are treated as one tensor
-    with the same global scale. This is requirement currently being set
-    by vLLM and may be removed in the future OR potentially make it
-    an optional step.
-
-    :param model: model to quantize
-    """
-
-    def _is_attention_module(module: Module):
-        return "attention" in module.__class__.__name__.lower() and (
-            hasattr(module, "k_proj")
-            or hasattr(module, "v_proj")
-            or hasattr(module, "qkv_proj")
-        )
-
-    def _is_mlp_module(module: Module):
-        return "mlp" in module.__class__.__name__.lower() and (
-            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
-        )
-
-    def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
-        """
-        Return True if all the linear layers in the layer_list are
-        NVFP4A16 quantized.
-        """
-        for layer in layer_list:
-            scheme = getattr(layer, "quantization_scheme", None)
-            if scheme is None:
-                return False
-
-            weight_quant_args = scheme.weights
-
-            if weight_quant_args is None:
-                return False
-
-            if not is_fp4(quantization_args=weight_quant_args):
-                return False
-        return True
-
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_attn=True,
-        include_mlp=True,
-    ):
-
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                continue
-
-            if not _valid_fp4_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                continue
-
-            q_weight = submodule.q_proj.weight.data
-            v_weight = submodule.v_proj.weight.data
-            k_weight = submodule.k_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
-            )
-
-            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
-
-        if _is_mlp_module(submodule):
-            if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
-                continue
-
-            gate_data = submodule.gate_proj.weight.data
-            up_data = submodule.up_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((gate_data, up_data), dim=0)
-            )
-
-            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
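After this change, initialization for TENSOR_GROUP (NVFP4) parameters only registers an empty one-element placeholder; the actual global scale is expected to be filled in later (per the commit message, by llm-compressor during calibration). A rough, self-contained illustration of what a module carries after initialization, using plain register_parameter instead of register_offload_parameter for brevity:

import torch
from torch.nn import Linear, Parameter

module = Linear(64, 64)
# Mirrors the simplified branch above: a one-element fp32 placeholder scale.
init_global_scale = Parameter(
    torch.empty(1, dtype=torch.float32, device=module.weight.device),
    requires_grad=False,
)
module.register_parameter("weight_global_scale", init_global_scale)
print(module.weight_global_scale.shape)  # torch.Size([1]), values set later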

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 8 additions & 5 deletions
@@ -47,7 +47,7 @@
     "compute_dynamic_scales_and_zp",
     "calculate_range",
     "calculate_qparams",
-    "generate_global_scale",
+    "generate_gparam",
     "is_fp4",
 ]
 
@@ -475,8 +475,9 @@ def parse_out_kv_cache_args(
     return kv_cache_args, quant_scheme_to_layers
 
 
-def generate_global_scale(
-    input_tensor: torch.Tensor,
+def generate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
     scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
     quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
     dtype: Optional[torch.dtype] = torch.float32,
@@ -490,6 +491,8 @@ def generate_global_scale(
     attempts to use the entire FP8 dtype range while mapping a per-group max
     to the FP4 max.
     """
-    tensor_amax = torch.abs(input_tensor.data).max().to(dtype)
-    global_scale = scale_data.max * quant_data.max / tensor_amax
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
    return global_scale.to(dtype)
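For reference, generate_gparam keeps the same scaling idea as the old generate_global_scale (map the observed absolute maximum onto the product of the FP8 and FP4 maxima) but now takes running min/max values instead of a raw tensor. A standalone sketch of the math, with the default constants (FP8_E4M3 max = 448.0, FP4_E2M1 max = 6.0) hard-coded for clarity:

import torch

FP8_MAX = 448.0  # FP8_E4M3_DATA.max
FP4_MAX = 6.0    # FP4_E2M1_DATA.max


def gparam_sketch(min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # Clamp so zero is always inside the observed range, then take the
    # largest absolute value, as the helper above does.
    min_val = torch.min(min_val, torch.zeros_like(min_val))
    max_val = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(min_val.abs(), max_val.abs())
    return (FP8_MAX * FP4_MAX / max_val_pos).to(torch.float32)


weight = torch.randn(8, 7)
min_val, max_val = torch.aminmax(weight)
# e.g. an absolute max around 2.5 yields a global scale around 1075
print(gparam_sketch(min_val, max_val))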

tests/test_quantization/test_utils/test_helpers.py

Lines changed: 3 additions & 5 deletions
@@ -20,10 +20,7 @@
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import (
-    calculate_qparams,
-    generate_global_scale,
-)
+from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
 
 
 @pytest.mark.parametrize(
@@ -70,7 +67,8 @@ def test_fused_global_scales():
     layer = torch.nn.Linear(7, 8)
     max_tensor_value = torch.abs(layer.weight.data).max()
     # use defaults
-    global_scale = generate_global_scale(layer.weight)
+    min_val, max_val = torch.aminmax(layer.weight)
+    global_scale = generate_gparam(min_val.data, max_val.data)
     # max value should be = (448 * 6) / global_scale
     assert max_tensor_value == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
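The updated assertion is simply the global-scale formula inverted: since global_scale = FP8_max * FP4_max / amax, the tensor's absolute maximum can be recovered as FP8_max * FP4_max / global_scale. A quick numeric check of that round trip, with the constants hard-coded:

amax = 0.5
global_scale = 448.0 * 6.0 / amax        # 5376.0
recovered = 448.0 * 6.0 / global_scale   # 0.5, equal to amax
assert abs(recovered - amax) < 1e-9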
