
Commit 61e5c00

dsikka and rahul-tuli authored
[NVFP4] Expand observers to calculate gparam, support NVFP4 Activations (#1487)
# SUMMARY:
- Add NVFP4 example
- Update the compression condition to no longer be weight-only
- Support NVFP4 activations:
  - Update observers to also provide the option to calculate gparam (global parameter), not just qparams
  - Update dynamic activation condition checks to consider DynamicType.LOCAL

# Testing
- All test cases pass

# Next Steps:
We now have the framework to also calculate the weight global scale in llmcompressor. We will remove it from compressed-tensors and add it here once this lands.

Co-authored-by: Rahul Tuli <rtuli@redhat.com>
1 parent a0adf83 commit 61e5c00

File tree: 8 files changed, +270 -31 lines changed
(new file: the NVFP4 example script added by this commit)

Lines changed: 73 additions & 0 deletions
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select the number of calibration samples. 512 samples is a good place to
# start; increasing the number of samples can improve accuracy. A small
# count is used here to keep the example fast.
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp4 with a group size of 16 via PTQ
#   * calibrate a global_scale for activations, which will be used to
#     quantize activations to fp4 on the fly
recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

# Apply quantization.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
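After saving, a quick sanity check is to reload the checkpoint and generate. A minimal sketch, assuming the installed transformers/compressed-tensors versions can decompress NVFP4 checkpoints on load:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# SAVE_DIR matches the directory written by the example above.
SAVE_DIR = "Meta-Llama-3-8B-Instruct-NVFP4"

model = AutoModelForCausalLM.from_pretrained(
    SAVE_DIR, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR)

# Generate a short completion to confirm the model loaded and runs.
inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0]))
```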

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 40 additions & 6 deletions
```diff
@@ -2,9 +2,11 @@
 
 import torch
 from compressed_tensors.quantization import (
+    DynamicType,
     KVCacheScaleType,
     QuantizationScheme,
     QuantizationStatus,
+    QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -53,7 +55,10 @@ def initialize_observer(
 
     quantization_args = getattr(quantization_scheme, arg_name, None)
     # don't need observers for fully dynamic quantization
-    if quantization_args is not None and not quantization_args.dynamic:
+    if quantization_args is not None and quantization_args.dynamic in (
+        False,
+        DynamicType.LOCAL,
+    ):
         observer = Observer.load_from_registry(
             quantization_args.observer,
             quantization_args=quantization_args,
@@ -84,14 +89,43 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
             "Must provide a value to observe if not using weight observer"
         )
 
+    quantization_scheme = getattr(module, "quantization_scheme", None)
+    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
+    quant_args = getattr(quantization_scheme, arg_name, None)
+
+    # We always calculate quantization parameters by default, and no global parameters
+    should_calculate_gparam = False
+    should_calculate_qparams = True
+
+    # TODO: will update to be the case for both weight and input in a follow-up;
+    # the weight global-scale calculation is currently done in ct and
+    # should be moved here to unify global scale calculations
+    if (
+        quant_args.strategy == QuantizationStrategy.TENSOR_GROUP
+        and base_name == "input"
+    ):
+        should_calculate_gparam = True
+        should_calculate_qparams = False
+
     observer = getattr(module, f"{base_name}_observer")
-    updated_scale, updated_zero_point = observer(
-        value, g_idx=g_idx, global_scale=global_scale
+    observer_outputs = observer(
+        value,
+        g_idx=g_idx,
+        global_scale=global_scale,
+        should_calculate_gparam=should_calculate_gparam,
     )
 
-    # update scale and zero point
-    update_parameter_data(module, updated_scale, f"{base_name}_scale")
-    update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+    if should_calculate_gparam:
+        updated_global_scale = observer_outputs
+        update_parameter_data(
+            module, updated_global_scale, f"{base_name}_global_scale"
+        )
+
+    if should_calculate_qparams:
+        # update scale and zero point
+        updated_scale, updated_zero_point = observer_outputs
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
 
 def update_weight_zp_scale(module: Module):
```
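Note on the new calling convention: the observer now returns a single global-scale tensor when `should_calculate_gparam` is set, and the usual `(scale, zero_point)` tuple otherwise. A minimal standalone sketch of that contract (a toy function, not the llm-compressor `Observer` class):

```python
import torch


def toy_observer(value: torch.Tensor, should_calculate_gparam: bool = False):
    # Toy stand-in for Observer.forward: one tensor out on the gparam path,
    # a (scale, zero_point) tuple out on the qparam path.
    if should_calculate_gparam:
        return value.abs().max()  # placeholder global scale
    scale = value.abs().max() / 6.0  # placeholder local scale (FP4 max = 6)
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
    return scale, zero_point


x = torch.randn(4, 16)
gscale = toy_observer(x, should_calculate_gparam=True)  # single tensor
scale, zp = toy_observer(x)                             # tuple
```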

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -2,6 +2,7 @@
 
 import torch
 from compressed_tensors.quantization import (
+    DynamicType,
     QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
@@ -212,7 +213,10 @@ def _initialize_observers(self, module: torch.nn.Module):
             return
 
         scheme: QuantizationScheme = module.quantization_scheme
-        input = scheme.input_activations and not scheme.input_activations.dynamic
+        input = scheme.input_activations and scheme.input_activations.dynamic in (
+            False,
+            DynamicType.LOCAL,
+        )
         weight = scheme.weights is not None
         output = scheme.output_activations and not scheme.output_activations.dynamic
         is_attention = is_attention_module(module)
@@ -241,7 +245,10 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                 continue
 
             scheme: QuantizationScheme = module.quantization_scheme
-            input = scheme.input_activations and not scheme.input_activations.dynamic
+            input = scheme.input_activations and scheme.input_activations.dynamic in (
+                False,
+                DynamicType.LOCAL,
+            )
             output = scheme.output_activations and not scheme.output_activations.dynamic
             is_attention = is_attention_module(module)
```
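Why the membership test instead of `not ... dynamic`: `DynamicType.LOCAL` is a truthy enum member, so the old check would have skipped observer setup for locally-dynamic NVFP4 activations. A standalone illustration (toy enum standing in for compressed-tensors' `DynamicType`):

```python
from enum import Enum


class DynamicType(Enum):  # toy stand-in for the compressed-tensors enum
    LOCAL = "local"


for dynamic in (False, True, DynamicType.LOCAL):
    old = not dynamic                             # old check: observer needed?
    new = dynamic in (False, DynamicType.LOCAL)   # new check
    print(dynamic, old, new)
# False             -> old True,  new True   (static: observer needed)
# True              -> old False, new False  (fully dynamic: no observer)
# DynamicType.LOCAL -> old False, new True   (local: observer now initialized)
```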

src/llmcompressor/observers/base.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -40,6 +40,7 @@ def forward(
         observed: Tensor,
         g_idx: Optional[Tensor] = None,
         global_scale: Optional[Tensor] = None,
+        should_calculate_gparam: bool = False,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -50,8 +51,12 @@ def forward(
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
+        if should_calculate_gparam:
+            return self.get_gparam(observed=observed)
         return self.get_qparams(
-            observed=observed, g_idx=g_idx, global_scale=global_scale
+            observed=observed,
+            g_idx=g_idx,
+            global_scale=global_scale,
         )
 
     def calculate_qparams(
@@ -68,11 +73,34 @@ def calculate_qparams(
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
 
+    def calculate_gparam(
+        self,
+        observed: Tensor,
+    ) -> torch.Tensor:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: global scale derived from the observed tensor
+        """
+        raise NotImplementedError(f"{self.__class__} must implement calculate_gparam")
+
     def post_calculate_qparams(self) -> None:
         """
         Run any logic specific to its observers after running calculate_qparams
         """
 
+    def get_gparam(self, observed: Tensor):
+        """
+        Function to derive a global scale parameter
+        :param observed: observed tensor to calculate global parameters from
+        :return: derived global scale
+        """
+        if self.quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+            return self.calculate_gparam(observed)
+        raise NotImplementedError(
+            "global parameter generation is only supported for TENSOR_GROUP"
+        )
+
     def get_qparams(
         self,
         observed: Optional[Tensor] = None,
```
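A concrete observer only needs to supply `calculate_gparam`; `get_gparam` handles the TENSOR_GROUP gating. A self-contained sketch of what such an override might look like (illustrative abs-max policy; in llm-compressor this would live on an `Observer` subclass):

```python
import torch


# Standalone sketch: what a concrete observer supplies for the gparam path.
# 448 and 6 are the FP8_E4M3 / FP4_E2M1 maxima used by the NVFP4 defaults.
class AbsMaxGlobalScale:  # would subclass llmcompressor.observers.base.Observer
    def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
        # Map the tensor's abs-max into the FP8 range while mapping group
        # maxima to the FP4 max (the MinMaxObserver below delegates this
        # arithmetic to generate_gparam instead).
        return (448.0 * 6.0 / observed.abs().max()).to(torch.float32)


print(AbsMaxGlobalScale().calculate_gparam(torch.randn(8, 32)))
```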

src/llmcompressor/observers/helpers.py

Lines changed: 33 additions & 1 deletion
```diff
@@ -1,8 +1,14 @@
 from collections import Counter
+from typing import Optional
 
 import torch
+from compressed_tensors.quantization.quant_args import (
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
+    FloatArgs,
+)
 
-__all__ = ["get_observer_token_count"]
+__all__ = ["get_observer_token_count", "generate_gparam"]
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
@@ -20,3 +26,29 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
         module._num_observed_tokens
     )
     return token_counts
+
+
+# TODO: we have a similar function in ct already;
+# consolidate when adding weight global scale generation
+def generate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
+    scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
+    quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
+    dtype: Optional[torch.dtype] = torch.float32,
+):
+    """
+    Generate a global scale for an entire tensor (input_tensor).
+    The goal of the scale is to ensure that the quantization (local) scale
+    falls into the appropriate dtype range.
+
+    E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
+    attempts to use the entire FP8 dtype range while mapping a per-group max
+    to the FP4 max.
+    """
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
+    return global_scale.to(dtype)
```
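Worked example of the formula `global_scale = scale_data.max * quant_data.max / max_val_pos`: with the NVFP4 defaults (FP8_E4M3 max = 448, FP4_E2M1 max = 6), a tensor whose largest magnitude is 10.0 yields 448 * 6 / 10 = 268.8. A self-contained check, with the two dtype maxima inlined rather than imported from compressed-tensors:

```python
import torch

FP8_E4M3_MAX = 448.0  # assumed max representable value of FP8 E4M3
FP4_E2M1_MAX = 6.0    # assumed max representable value of FP4 E2M1


def generate_gparam_demo(min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # Mirrors generate_gparam above with the default dtype args inlined.
    min_vals = torch.min(min_val, torch.zeros_like(min_val))
    max_vals = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(min_vals.abs(), max_vals.abs())
    return (FP8_E4M3_MAX * FP4_E2M1_MAX / max_val_pos).to(torch.float32)


g = generate_gparam_demo(torch.tensor(-10.0), torch.tensor(7.5))
print(g)  # tensor(268.8000): 448 * 6 / 10
```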

src/llmcompressor/observers/min_max.py

Lines changed: 44 additions & 11 deletions
```diff
@@ -6,6 +6,7 @@
 from compressed_tensors.utils import deprecated
 
 from llmcompressor.observers.base import Observer
+from llmcompressor.observers.helpers import generate_gparam
 
 __all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"]
 
@@ -29,13 +30,12 @@ def __init__(
         self.max_val = {}
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(
+    def calculate_updated_min_max(
         self,
         observed: torch.Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
-        global_scale: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
+    ):
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant. Set the averaging_constant to 1.0 to disable averaging.
@@ -46,8 +46,7 @@ def calculate_qparams(
             reduced dimensions
         :param tensor_id: Optional id if different ranges of observed tensors are
             passed, useful for sharding tensors by group_size
-        :param global_scale: optional scale to further scale local quantization scales
-        :return: tuple of scale and zero point derived from the observed tensor
+        :return: updated min and max values
         """
         tensor_id = tensor_id or "default"
 
@@ -59,12 +58,7 @@ def calculate_qparams(
 
         # early stopping, save some computation and memory
         if self.averaging_constant == 1.0:
-            return calculate_qparams(
-                min_vals=min_val,
-                max_vals=max_val,
-                quantization_args=self.quantization_args,
-                global_scale=global_scale,
-            )
+            return min_val, max_val
 
         running_min_val = self.min_val.get(tensor_id, None)
         running_max_val = self.max_val.get(tensor_id, None)
@@ -82,7 +76,46 @@ def calculate_qparams(
 
         self.min_val[tensor_id] = updated_min_val
         self.max_val[tensor_id] = updated_max_val
+        return updated_min_val, updated_max_val
+
+    def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
+        """
+        Generate a global scale using the observed min and max.
 
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: updated global scale derived from the observed tensor
+        """
+        updated_min_val, updated_max_val = self.calculate_updated_min_max(
+            observed=observed
+        )
+        return generate_gparam(
+            updated_min_val=updated_min_val, updated_max_val=updated_max_val
+        )
+
+    def calculate_qparams(
+        self,
+        observed: torch.Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+        global_scale: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
+        """
+        Generate a scale and zero-point using the observed min and max.
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
+        :param global_scale: optional scale to further scale local quantization scales
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        updated_min_val, updated_max_val = self.calculate_updated_min_max(
+            observed=observed, tensor_id=tensor_id, reduce_dims=reduce_dims
+        )
         return calculate_qparams(
             min_vals=updated_min_val,
             max_vals=updated_max_val,
```