
Commit 2caa1e7

Revert "[NVFP4] Use observers to generate global weight scales " (#1507)
Reverts #1504. For some reason, automerge merged this PR with just one approval.
1 parent 4b2b172 commit 2caa1e7

File tree

8 files changed: +68 / -179 lines


examples/quantization_w4a16_fp4/llama3_example.py
Lines changed: 0 additions & 8 deletions

@@ -19,14 +19,6 @@
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
 
-print("\n\n")
-print("========== SAMPLE GENERATION ==============")
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
-print(tokenizer.decode(output[0]))
-print("==========================================\n\n")
-
-
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4A16"
 model.save_pretrained(SAVE_DIR, save_compressed=True)

examples/quantization_w4a4_fp4/llama3_example.py
Lines changed: 0 additions & 8 deletions

@@ -67,14 +67,6 @@ def tokenize(sample):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
-print("\n\n")
-print("========== SAMPLE GENERATION ==============")
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
-print(tokenizer.decode(output[0]))
-print("==========================================\n\n")
-
-
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4"
 model.save_pretrained(SAVE_DIR, save_compressed=True)

src/llmcompressor/modifiers/quantization/calibration.py
Lines changed: 32 additions & 50 deletions

@@ -29,7 +29,6 @@
     "freeze_module_quantization",
     "apply_calibration_status",
     "reset_quantization_status",
-    "update_weight_global_scale",
 ]
 
 
@@ -67,13 +66,7 @@ def initialize_observer(
     module.register_module(f"{base_name}_observer", observer)
 
 
-def call_observer(
-    module: Module,
-    base_name: str,
-    value: Optional[torch.Tensor] = None,
-    should_calculate_gparam: bool = False,
-    should_calculate_qparams: bool = True,
-):
+def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor] = None):
     """
     Call a module's attached input/weight/output observer using a provided value.
     Update the module's scale and zp using the observer's return values.
@@ -87,51 +80,54 @@ def call_observer(
     if base_name == "weight":
         value = module.weight
         g_idx = getattr(module, "weight_g_idx", None)
+        global_scale = getattr(module, f"{base_name}_global_scale", None)
     elif value is not None:
         g_idx = None
+        global_scale = None
     else:
         raise ValueError(
             "Must provide a value to observe if not using weight observer"
         )
 
+    quantization_scheme = getattr(module, "quantization_scheme", None)
+    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
+    quant_args = getattr(quantization_scheme, arg_name, None)
+
+    # We always calculate quantizaton parameters by default and no global parameters
+    should_calculate_gparam = False
+    should_calculate_qparams = True
+
+    # TODO: will update to be the case for both weight and input in a follow-up
+    # weight global calculate is currently done in ct right now;
+    # should be moved here to unify global scale calculations
+    if (
+        quant_args.strategy == QuantizationStrategy.TENSOR_GROUP
+        and base_name == "input"
+    ):
+        should_calculate_gparam = True
+        should_calculate_qparams = False
+
     observer = getattr(module, f"{base_name}_observer")
+    observer_outputs = observer(
+        value,
+        g_idx=g_idx,
+        global_scale=global_scale,
+        should_calculate_gparam=should_calculate_gparam,
+    )
 
     if should_calculate_gparam:
-        global_scale = observer(
-            value,
-            should_calculate_gparam=True,
+        updated_global_scale = observer_outputs
+        update_parameter_data(
+            module, updated_global_scale, f"{base_name}_global_scale"
         )
-        update_parameter_data(module, global_scale, f"{base_name}_global_scale")
-    else:
-        global_scale = getattr(module, f"{base_name}_global_scale", None)
 
     if should_calculate_qparams:
-        updated_scale, updated_zero_point = observer(
-            value, g_idx=g_idx, global_scale=global_scale
-        )
+        # update scale and zero point
+        updated_scale, updated_zero_point = observer_outputs
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
 
-def update_weight_global_scale(module: Module):
-    if getattr_chain(module, "quantization_scheme.weights", None) is None:
-        return
-
-    if (
-        getattr_chain(module, "quantization_scheme.weights.strategy", None)
-        != QuantizationStrategy.TENSOR_GROUP
-    ):
-        return
-
-    call_observer(
-        module,
-        base_name="weight",
-        should_calculate_gparam=True,
-        should_calculate_qparams=False,
-    )
-    module.weight_observer.reset()
-
-
 def update_weight_zp_scale(module: Module):
     """
     marks a layer as ready for calibration which activates observers
@@ -169,24 +165,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return
 
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None)
-
-    calculate_qparams = True
-    calculate_gparam = False
-
-    if quantization_args is not None:
-        if quantization_args.dynamic in (True, DynamicType.LOCAL):
-            calculate_qparams = False
-        if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-            calculate_gparam = True
-
     call_observer(
         module=module,
        base_name=base_name,
         value=value,
-        should_calculate_gparam=calculate_gparam,
-        should_calculate_qparams=calculate_qparams,
     )
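The practical effect of the restored call_observer is that callers no longer pass should_calculate_gparam / should_calculate_qparams; the function derives them from the module's quantization scheme. A minimal, standalone sketch of that decision (the stand-in enum and helper name are illustrative, not part of this commit):

    from enum import Enum
    from typing import Tuple

    class QuantizationStrategy(str, Enum):
        # stand-in for compressed-tensors' QuantizationStrategy
        TENSOR_GROUP = "tensor_group"
        GROUP = "group"

    def observer_flags(strategy: QuantizationStrategy, base_name: str) -> Tuple[bool, bool]:
        # Mirrors the branch in the restored call_observer: only TENSOR_GROUP
        # input activations get a global scale; everything else gets
        # scale / zero-point from the observer.
        if strategy == QuantizationStrategy.TENSOR_GROUP and base_name == "input":
            return True, False  # (should_calculate_gparam, should_calculate_qparams)
        return False, True

    # NVFP4-style input activations -> global scale only
    assert observer_flags(QuantizationStrategy.TENSOR_GROUP, "input") == (True, False)
    # Weights still get scale / zero-point here; their global scale is generated elsewhere (in ct)
    assert observer_flags(QuantizationStrategy.TENSOR_GROUP, "weight") == (False, True)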

src/llmcompressor/modifiers/quantization/quantization/base.py
Lines changed: 1 addition & 12 deletions

@@ -2,12 +2,8 @@
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.quantization.calibration import (
-    update_weight_global_scale,
-    update_weight_zp_scale,
-)
+from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin
-from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
 
 __all__ = ["QuantizationModifier"]
 
@@ -70,14 +66,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         QuantizationMixin.start_calibration(self, state.model)
 
         modules = list(state.model.modules())
-        # TODO: this step can be combined with update_weight_zp_scale
-        # once update_fused_layer_weight_global_scales is removed
-        # and not required by vLLM
-        for module in tqdm.tqdm(modules):
-            update_weight_global_scale(module)
-
         for module in tqdm.tqdm(modules, desc="Calibrating weights"):
-            update_fused_layer_weight_global_scales(module)
             update_weight_zp_scale(module)
 
     def on_event(self, state: State, event: Event, **kwargs):

src/llmcompressor/modifiers/utils/__init__.py
Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
 # flake8: noqa
 
 from .constants import *
-from .helpers import *

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

src/llmcompressor/observers/helpers.py
Lines changed: 33 additions & 1 deletion

@@ -1,8 +1,14 @@
 from collections import Counter
+from typing import Optional
 
 import torch
+from compressed_tensors.quantization.quant_args import (
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
+    FloatArgs,
+)
 
-__all__ = ["get_observer_token_count"]
+__all__ = ["get_observer_token_count", "generate_gparam"]
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
@@ -20,3 +26,29 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
             module._num_observed_tokens
         )
     return token_counts
+
+
+# TODO: we have a similar function in ct already
+# consolidate when adding weight global scale
+# generation
+def generate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
+    scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
+    quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
+    dtype: Optional[torch.dtype] = torch.float32,
+):
+    """
+    Generate a global scale for an entire tensor (input_tensor).
+    Goal of the scale is to ensure that the quantization (local) scale
+    falls into the approproiate dtype range.
+
+    E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
+    attempts to use the entire FP8 dtype range while mapping a per-group max
+    to the FP4 max.
+    """
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
+    return global_scale.to(dtype)
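A quick, illustrative exercise of the restored generate_gparam arithmetic (not part of the commit): the global scale is chosen so that per-group FP8 scales can use their full range while each group maximum maps to the FP4 maximum. It assumes compressed-tensors exposes FP4_E2M1_DATA and FP8_E4M3_DATA as imported above, with .max values of 6.0 and 448.0; the tensor below is hypothetical.

    import torch
    from compressed_tensors.quantization.quant_args import FP4_E2M1_DATA, FP8_E4M3_DATA

    weight = torch.randn(4096, 4096) * 0.02  # hypothetical weight tensor
    min_val = torch.amin(weight)
    max_val = torch.amax(weight)

    # Same steps as generate_gparam: clamp the observed range around zero,
    # take the absolute maximum, then scale so FP8 group scales map the
    # per-group max onto the FP4 max.
    min_val = torch.min(min_val, torch.zeros_like(min_val))
    max_val = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(min_val.abs(), max_val.abs())
    global_scale = (FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / max_val_pos).to(torch.float32)

    # e.g. an absolute max of 0.1 would give global_scale = 448 * 6 / 0.1 = 26880
    print(global_scale)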

src/llmcompressor/observers/min_max.py
Lines changed: 2 additions & 1 deletion

@@ -2,10 +2,11 @@
 
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
+from compressed_tensors.quantization.utils import calculate_qparams
 from compressed_tensors.utils import deprecated
 
 from llmcompressor.observers.base import Observer
+from llmcompressor.observers.helpers import generate_gparam
 
 __all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"]
