
Commit 559ad81

[NVFP4] Update global scale generation (#1508)
# SUMMARY:
- Reopening as #1504 merged through automerge with just one review
- Requires: neuralmagic/compressed-tensors#339
- Uses observers to generate global weight scales. These were previously generated during the init function in compressed-tensors; using observers is more consistent with our workflows and parameter lifecycle.
- Also moves the fused-layer update step into llmcompressor; this can be removed once we have an update from vLLM. For now it requires splitting the update_weight_global_scale and update_weight_zp_scale steps; these can be combined once the vLLM change is made.
- Updates examples to include sample generation; this is now very quick thanks to this PR: neuralmagic/compressed-tensors#336

Note: The mse observer is tightly tied to generating a scale and zero-point, so it can't be used for global scale generation at the moment. We will have to decouple this functionality in order to support general scale optimization.

# TEST PLAN:
- Tested e2e with nvfp4 and nvfp4a16
- Validated that existing workflows work e2e (w4a16, sparse2of4 + fp8, w8a8 int8/fp8)
1 parent 2caa1e7 commit 559ad81
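For readers skimming the diff, the new flow is sketched below: a minimal, hedged recreation of what QuantizationModifier.on_start does after this change (see src/llmcompressor/modifiers/quantization/quantization/base.py further down). It only uses functions introduced or touched by this commit; the standalone helper name calibrate_weights is illustrative and not part of the codebase.

```python
import tqdm

from llmcompressor.modifiers.quantization.calibration import (
    update_weight_global_scale,
    update_weight_zp_scale,
)
from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales


def calibrate_weights(model):
    """Illustrative sketch of the two-pass weight calibration introduced here."""
    modules = list(model.modules())

    # Pass 1: let each module's weight observer compute weight_global_scale
    # (only TENSOR_GROUP / NVFP4 schemes do anything in this step).
    for module in tqdm.tqdm(modules):
        update_weight_global_scale(module)

    # Pass 2: fuse q/k/v and gate/up global scales (currently required by vLLM),
    # then compute the per-group weight scale and zero point.
    for module in tqdm.tqdm(modules, desc="Calibrating weights"):
        update_fused_layer_weight_global_scales(module)
        update_weight_zp_scale(module)
```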

File tree

8 files changed: +179 -68 lines changed


examples/quantization_w4a16_fp4/llama3_example.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,14 @@
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
 
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4A16"
 model.save_pretrained(SAVE_DIR, save_compressed=True)

examples/quantization_w4a4_fp4/llama3_example.py

Lines changed: 8 additions & 0 deletions
@@ -67,6 +67,14 @@ def tokenize(sample):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4"
 model.save_pretrained(SAVE_DIR, save_compressed=True)

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 50 additions & 32 deletions
@@ -29,6 +29,7 @@
     "freeze_module_quantization",
     "apply_calibration_status",
     "reset_quantization_status",
+    "update_weight_global_scale",
 ]
 
 
@@ -66,7 +67,13 @@ def initialize_observer(
     module.register_module(f"{base_name}_observer", observer)
 
 
-def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor] = None):
+def call_observer(
+    module: Module,
+    base_name: str,
+    value: Optional[torch.Tensor] = None,
+    should_calculate_gparam: bool = False,
+    should_calculate_qparams: bool = True,
+):
     """
     Call a module's attached input/weight/output observer using a provided value.
     Update the module's scale and zp using the observer's return values.
@@ -80,54 +87,51 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
     if base_name == "weight":
         value = module.weight
         g_idx = getattr(module, "weight_g_idx", None)
-        global_scale = getattr(module, f"{base_name}_global_scale", None)
     elif value is not None:
         g_idx = None
-        global_scale = None
     else:
         raise ValueError(
             "Must provide a value to observe if not using weight observer"
         )
 
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-    quant_args = getattr(quantization_scheme, arg_name, None)
-
-    # We always calculate quantizaton parameters by default and no global parameters
-    should_calculate_gparam = False
-    should_calculate_qparams = True
-
-    # TODO: will update to be the case for both weight and input in a follow-up
-    # weight global calculate is currently done in ct right now;
-    # should be moved here to unify global scale calculations
-    if (
-        quant_args.strategy == QuantizationStrategy.TENSOR_GROUP
-        and base_name == "input"
-    ):
-        should_calculate_gparam = True
-        should_calculate_qparams = False
-
     observer = getattr(module, f"{base_name}_observer")
-    observer_outputs = observer(
-        value,
-        g_idx=g_idx,
-        global_scale=global_scale,
-        should_calculate_gparam=should_calculate_gparam,
-    )
 
     if should_calculate_gparam:
-        updated_global_scale = observer_outputs
-        update_parameter_data(
-            module, updated_global_scale, f"{base_name}_global_scale"
+        global_scale = observer(
+            value,
+            should_calculate_gparam=True,
         )
+        update_parameter_data(module, global_scale, f"{base_name}_global_scale")
+    else:
+        global_scale = getattr(module, f"{base_name}_global_scale", None)
 
     if should_calculate_qparams:
-        # update scale and zero point
-        updated_scale, updated_zero_point = observer_outputs
+        updated_scale, updated_zero_point = observer(
+            value, g_idx=g_idx, global_scale=global_scale
+        )
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
 
+def update_weight_global_scale(module: Module):
+    if getattr_chain(module, "quantization_scheme.weights", None) is None:
+        return
+
+    if (
+        getattr_chain(module, "quantization_scheme.weights.strategy", None)
+        != QuantizationStrategy.TENSOR_GROUP
+    ):
+        return
+
+    call_observer(
+        module,
+        base_name="weight",
+        should_calculate_gparam=True,
+        should_calculate_qparams=False,
+    )
+    module.weight_observer.reset()
+
+
 def update_weight_zp_scale(module: Module):
     """
     marks a layer as ready for calibration which activates observers
@@ -165,10 +169,24 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return
 
+    quantization_scheme = getattr(module, "quantization_scheme", None)
+    quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None)
+
+    calculate_qparams = True
+    calculate_gparam = False
+
+    if quantization_args is not None:
+        if quantization_args.dynamic in (True, DynamicType.LOCAL):
+            calculate_qparams = False
+        if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+            calculate_gparam = True
+
     call_observer(
         module=module,
         base_name=base_name,
         value=value,
+        should_calculate_gparam=calculate_gparam,
+        should_calculate_qparams=calculate_qparams,
     )
 
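As a quick reference for the new call_observer flags above, here is a hedged, illustrative pairing of flag settings with the two NVFP4 (TENSOR_GROUP) cases this file now handles. The helper name calibrate_nvfp4_layer is hypothetical, and the module passed in is assumed to already carry a TENSOR_GROUP quantization_scheme with attached observers.

```python
import torch
from torch.nn import Module

from llmcompressor.modifiers.quantization.calibration import call_observer


def calibrate_nvfp4_layer(module: Module, sample_activation: torch.Tensor) -> None:
    """Hedged sketch: how the new flags map onto the NVFP4 cases above."""
    # Weights: compute only the global scale here; the per-group scale and
    # zero point are produced later by update_weight_zp_scale.
    call_observer(
        module,
        base_name="weight",
        should_calculate_gparam=True,
        should_calculate_qparams=False,
    )

    # Dynamic-local activations: only the global scale is calibrated offline;
    # the local scale/zero-point are computed on the fly at inference time.
    call_observer(
        module=module,
        base_name="input",
        value=sample_activation,
        should_calculate_gparam=True,
        should_calculate_qparams=False,
    )
```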

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 12 additions & 1 deletion
@@ -2,8 +2,12 @@
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
+from llmcompressor.modifiers.quantization.calibration import (
+    update_weight_global_scale,
+    update_weight_zp_scale,
+)
 from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin
+from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
 
 __all__ = ["QuantizationModifier"]
 
@@ -66,7 +70,14 @@ def on_start(self, state: State, event: Event, **kwargs):
         QuantizationMixin.start_calibration(self, state.model)
 
         modules = list(state.model.modules())
+        # TODO: this step can be combined with update_weight_zp_scale
+        # once update_fused_layer_weight_global_scales is removed
+        # and not required by vLLM
+        for module in tqdm.tqdm(modules):
+            update_weight_global_scale(module)
+
         for module in tqdm.tqdm(modules, desc="Calibrating weights"):
+            update_fused_layer_weight_global_scales(module)
             update_weight_zp_scale(module)
 
     def on_event(self, state: State, event: Event, **kwargs):

src/llmcompressor/modifiers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 # flake8: noqa
 
 from .constants import *
+from .helpers import *

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+from typing import List
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from compressed_tensors.utils import align_module_device, update_parameter_data
+from torch.nn import Linear, Module
+
+__all__ = ["update_fused_layer_weight_global_scales"]
+
+
+def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
+    """
+    When running NVFP4 quantization, update the global scale
+    such that q,k,v layers are treated as one tensor with the same
+    global_scale and gate_proj/up_proj layers are treated as one tensor
+    with the same global scale. This is a requirement currently being set
+    by vLLM and may be removed in the future, or potentially made
+    an optional step.
+
+    :param submodule: submodule to update
+    """
+
+    def _is_attention_module(module: Module):
+        return "attention" in module.__class__.__name__.lower() and (
+            hasattr(module, "k_proj")
+            or hasattr(module, "v_proj")
+            or hasattr(module, "qkv_proj")
+        )
+
+    def _is_mlp_module(module: Module):
+        return "mlp" in module.__class__.__name__.lower() and (
+            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
+        )
+
+    def _valid_tensor_group_quant(layer_list: List[Linear]):
+        """
+        Return True if all the linear layers in the layer_list are
+        TENSOR_GROUP quantized.
+        """
+        for layer in layer_list:
+            scheme = getattr(layer, "quantization_scheme", None)
+            if scheme is None:
+                return False
+
+            weight_quant_args = scheme.weights
+
+            if weight_quant_args is None:
+                return False
+
+            if weight_quant_args.strategy != QuantizationStrategy.TENSOR_GROUP:
+                return False
+        return True
+
+    with align_module_device(submodule):
+        if _is_attention_module(submodule):
+            # already fused/treated as one layer
+            if hasattr(submodule, "qkv_proj"):
+                return
+
+            if not _valid_tensor_group_quant(
+                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
+            ):
+                return
+
+            global_scale = torch.min(
+                torch.cat(
+                    (
+                        submodule.q_proj.weight_global_scale.data,
+                        submodule.k_proj.weight_global_scale.data,
+                        submodule.v_proj.weight_global_scale.data,
+                    )
+                )
+            )
+
+            update_parameter_data(submodule.q_proj, global_scale, "weight_global_scale")
+            update_parameter_data(submodule.k_proj, global_scale, "weight_global_scale")
+            update_parameter_data(submodule.v_proj, global_scale, "weight_global_scale")
+
+    with align_module_device(submodule):
+        if _is_mlp_module(submodule):
+            if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]):
+                return
+
+            global_scale = torch.min(
+                torch.cat(
+                    (
+                        submodule.gate_proj.weight_global_scale.data,
+                        submodule.up_proj.weight_global_scale.data,
+                    )
+                )
+            )
+
+            update_parameter_data(
+                submodule.gate_proj, global_scale, "weight_global_scale"
+            )
+            update_parameter_data(
+                submodule.up_proj, global_scale, "weight_global_scale"
+            )
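A toy numeric illustration (not from the commit) of the fusion above: the shared global scale is simply the minimum of the per-projection weight_global_scale values, which, since the global scale is inversely proportional to a tensor's absolute max (see the formula noted after the observers/helpers.py diff below), corresponds to the projection with the largest weight magnitude, keeping all fused projections in range.

```python
import torch

# Hypothetical per-projection global scales for q/k/v
q_scale = torch.tensor([448.0])
k_scale = torch.tensor([512.0])
v_scale = torch.tensor([384.0])

shared = torch.min(torch.cat((q_scale, k_scale, v_scale)))
print(shared)  # tensor(384.) -> written back to q_proj, k_proj, and v_proj
```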

src/llmcompressor/observers/helpers.py

Lines changed: 1 addition & 33 deletions
@@ -1,14 +1,8 @@
 from collections import Counter
-from typing import Optional
 
 import torch
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    FloatArgs,
-)
 
-__all__ = ["get_observer_token_count", "generate_gparam"]
+__all__ = ["get_observer_token_count"]
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
@@ -26,29 +20,3 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
             module._num_observed_tokens
         )
     return token_counts
-
-
-# TODO: we have a similar function in ct already
-# consolidate when adding weight global scale
-# generation
-def generate_gparam(
-    updated_min_val: torch.Tensor,
-    updated_max_val: torch.Tensor,
-    scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
-    quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
-    dtype: Optional[torch.dtype] = torch.float32,
-):
-    """
-    Generate a global scale for an entire tensor (input_tensor).
-    Goal of the scale is to ensure that the quantization (local) scale
-    falls into the appropriate dtype range.
-
-    E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
-    attempts to use the entire FP8 dtype range while mapping a per-group max
-    to the FP4 max.
-    """
-    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
-    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
-    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-    global_scale = scale_data.max * quant_data.max / max_val_pos
-    return global_scale.to(dtype)
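For context on what moved into compressed-tensors: the removed generate_gparam above computes the NVFP4 global scale from the observer's running min/max, and the compressed-tensors replacement referenced by neuralmagic/compressed-tensors#339 is assumed to follow the same formula:

$$\text{global\_scale} \;=\; \frac{\text{FP8\_E4M3}_{\max}\,\cdot\,\text{FP4\_E2M1}_{\max}}{\max\bigl(\lvert W \rvert\bigr)}$$

where the denominator is the tensor's absolute max derived from the observer's min/max values, so that after dividing by the global scale the per-group (local) FP8 scales can use the full FP8 range while each group's maximum maps onto the FP4 maximum.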

src/llmcompressor/observers/min_max.py

Lines changed: 1 addition & 2 deletions
@@ -2,11 +2,10 @@
 
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.utils import calculate_qparams
+from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
 from compressed_tensors.utils import deprecated
 
 from llmcompressor.observers.base import Observer
-from llmcompressor.observers.helpers import generate_gparam
 
 __all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"]
 