Commit eddd2a1

[NVFP4] Add tensor_group strategy; enable NVFP4 Activations (#317)
* add nvfp4 args
* format
* dont use a dataclass
* remove dataclass
* test skeletons
* add tests
* update
* introduce new tensor_group strategy; expand dynamic forward pass to do per group; add input_global_scale
* update
* update dynamic scale generation
* Update test_initialize.py
* clean up
* more clean-up
* add additional checks
* Update forward.py
* Update forward.py
* edit
* remove unused import
* use ceil
1 parent 8a116f7 commit eddd2a1
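
For reference, a minimal sketch of what this commit enables: activations can now be described with the new `tensor_group` strategy, mirroring the `input_activations` entry of the `NVFP4` preset added in `quant_scheme.py`. The import path follows the files touched below; treat the snippet as illustrative rather than an official usage example.

```python
# Sketch: constructing activation quantization args with the new TENSOR_GROUP
# strategy (fields copied from the NVFP4 preset introduced in this commit).
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

input_args = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR_GROUP,  # new enum member
    symmetric=True,
    dynamic=False,   # global scale stays static; local group scales are computed at runtime
    group_size=16,   # 16-element groups, as in NVFP4
)
print(input_args.strategy.value)  # "tensor_group"
```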

5 files changed: +83, -12 lines changed
src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 15 additions & 3 deletions
@@ -189,7 +189,11 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy == QuantizationStrategy.GROUP:
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        if args.strategy == QuantizationStrategy.TENSOR_GROUP:
+            # only valid for activation; remove dim 0
+            x = x.squeeze(0)
+
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
         columns = output.shape[1]
@@ -251,6 +255,9 @@ def _process_quantization(
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
+        if args.strategy == QuantizationStrategy.TENSOR_GROUP:
+            output = output.unsqueeze(0)
+
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
@@ -352,9 +359,11 @@ def forward_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
     global_scale = getattr(module, f"{base_name}_global_scale", None)
 
-    if args.dynamic:
+    if args.dynamic or args.strategy == QuantizationStrategy.TENSOR_GROUP:
         # dynamic quantization - determine the scale/zp on the fly
-        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
+        scale, zero_point = compute_dynamic_scales_and_zp(
+            value=value, args=args, module=module, global_scale=global_scale
+        )
     else:
         # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
@@ -388,6 +397,7 @@ def _quantize(
         scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
+
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
@@ -398,6 +408,7 @@ def _quantize(
         q_max,
     )
     quantized_value = round_to_quantized_type(clamped_value, args)
+
    if dtype is not None:
        quantized_value = quantized_value.to(dtype)
 
@@ -422,6 +433,7 @@ def _dequantize(
 
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
+
     dequant_value = dequant_value * scale
 
     if dtype is not None:

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 10 additions & 2 deletions
@@ -181,6 +181,7 @@ def _initialize_scale_zero_point(
     # there is likely bug
 
     if is_fp4(quantization_args=quantization_args) and base_name == "weight":
+        assert quantization_args.strategy == QuantizationStrategy.GROUP
         scale_dtype = FP8_E4M3_DATA.dtype
         # When applying weight-only FP4 quantization, generate a global_scale
         # This scale is applied during runtime to ensure that the generated
@@ -193,19 +194,26 @@ def _initialize_scale_zero_point(
             module, f"{base_name}_global_scale", init_global_scale
         )
 
+    # initializes empty scale, zero point, and g_idx parameters for the module
+    if is_fp4(quantization_args=quantization_args) and base_name == "input":
+        assert quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP
+        scale_dtype = torch.float32
+        scale_name = f"{base_name}_global_scale"
+    else:
+        scale_name = f"{base_name}_scale"
+
     if scale_dtype not in [
         torch.float16,
         torch.bfloat16,
         torch.float32,
     ] and not is_fp4(quantization_args=quantization_args):
         scale_dtype = torch.float16
 
-    # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    register_offload_parameter(module, scale_name, init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         if is_fp4(quantization_args=quantization_args):

src/compressed_tensors/quantization/quant_args.py

Lines changed: 3 additions & 1 deletion
@@ -98,6 +98,7 @@ class QuantizationStrategy(str, Enum):
     GROUP = "group"
     BLOCK = "block"
     TOKEN = "token"
+    TENSOR_GROUP = "tensor_group"
 
 
 class ActivationOrdering(Aliasable, str, Enum):
@@ -239,7 +240,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         if (
             group_size is not None
             and group_size > 0
-            and strategy != QuantizationStrategy.GROUP
+            and strategy
+            not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP)
         ):
             raise ValueError("group_size requires strategy to be set to 'group'")
 
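
With the validator relaxed, a positive `group_size` is now accepted for both `group` and `tensor_group` (the error message still only mentions `'group'`). A quick sketch of the old versus new check, written as plain boolean logic:

```python
# Before, any positive group_size with a strategy other than GROUP raised;
# now TENSOR_GROUP is also allowed.
from compressed_tensors.quantization.quant_args import QuantizationStrategy

group_size = 16
strategy = QuantizationStrategy.TENSOR_GROUP

old_rejects = group_size is not None and group_size > 0 and strategy != QuantizationStrategy.GROUP
new_rejects = group_size is not None and group_size > 0 and strategy not in (
    QuantizationStrategy.GROUP,
    QuantizationStrategy.TENSOR_GROUP,
)
print(old_rejects, new_rejects)  # True False
```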

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 24 additions & 0 deletions
@@ -111,6 +111,29 @@ def is_preset_scheme(name: str) -> bool:
     )
 )
 
+# TODO: the local scales are dynamic, the global scale is static/calibrated
+# We could potentially extend the dynamic kwarg so that is goes
+# beyond being just a boolean - however we may also want a dynamically
+# generated global scale, so we could use that to separate between the two
+NVFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    ),
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -237,4 +260,5 @@ def is_preset_scheme(name: str) -> bool:
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
     "NVFP4A16": NVFP4A16,
+    "NVFP4": NVFP4,
 }
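
Back-of-the-envelope for the new `NVFP4` preset: with `group_size=16`, each token's hidden vector is split into ceil(hidden / 16) groups, each carrying one dynamically computed local scale, while a single calibrated global scale is shared per tensor. The hidden size below is a placeholder:

```python
# Scale count per token under the NVFP4 preset (group_size=16).
import math

hidden = 4096
group_size = 16
groups_per_token = math.ceil(hidden / group_size)
print(groups_per_token)  # 256 local scales per token, plus 1 global scale per tensor
```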

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 31 additions & 6 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 from typing import Generator, List, Optional, Tuple
 
 import torch
@@ -103,7 +104,9 @@ def calculate_qparams(
     if is_fp4(quantization_args=quantization_args) and global_scale is not None:
         # Conditionally scale the generated local scale by a global_scale
         scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+        scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
         scales = scales.to(FP8_E4M3_DATA.dtype)
+
     else:
         scales = max_val_pos / (float(bit_range) / 2)
 
@@ -143,7 +146,12 @@ def calculate_qparams(
     return scales, zero_points
 
 
-def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+def compute_dynamic_scales_and_zp(
+    value: Tensor,
+    args: QuantizationArgs,
+    module: torch.nn.Module,
+    global_scale: Optional[Tensor] = None,
+):
     """
     Returns the computed scales and zero points for dynamic activation
     quantization.
@@ -155,24 +163,41 @@ def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
         reduced dimensions
     :return: tuple of scale and zero point derived from the observed tensor
     """
+
+    keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
         dim = {1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
+    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        # per group dynamic quantization - only valid for
+        # activations
+        dim = {0, 1}
+        value = value.squeeze(0)
+        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        keep_dims = False
+        value = torch.reshape(
+            value,
+            (
+                value.shape[0],
+                math.ceil(value.shape[1] / args.group_size),
+                args.group_size,
+            ),
+        )
     else:
         raise ValueError(
-            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
-            "must be used for dynamic quantization",
+            "Dynamic quantization is only supported for ",
+            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
         )
 
     if not reduce_dims:
         min_val, max_val = torch.aminmax(value)
     else:
-        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
-        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)
 
-    return calculate_qparams(min_val, max_val, args)
+    return calculate_qparams(min_val, max_val, args, global_scale=global_scale)
 
 
 def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
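
The `TENSOR_GROUP` path in `compute_dynamic_scales_and_zp` squeezes the `(1, tokens, hidden)` activation, regroups it into `(tokens, n_groups, group_size)`, and reduces over the last dimension, yielding one min/max pair per 16-element group. A standalone sketch of that reduction with illustrative shapes (divisible by the group size, so the reshape is exact):

```python
# Sketch of the per-group min/max reduction added for TENSOR_GROUP.
import math
import torch

value = torch.randn(1, 4, 64)                 # (batch=1, tokens=4, hidden=64)
group_size = 16

value = value.squeeze(0)                      # (4, 64)
value = value.reshape(
    value.shape[0],
    math.ceil(value.shape[1] / group_size),   # 4 groups per token
    group_size,
)                                             # (4, 4, 16)

reduce_dims = tuple(idx for idx in range(3) if idx not in {0, 1})  # (2,)
min_val = torch.amin(value, dim=reduce_dims, keepdim=False)
max_val = torch.amax(value, dim=reduce_dims, keepdim=False)
print(min_val.shape, max_val.shape)           # torch.Size([4, 4]) torch.Size([4, 4])
```

These per-group min/max values then flow into `calculate_qparams`, where (per the hunk above) the FP4 local scale becomes `global_scale * (max_val_pos / FP4_E2M1_DATA.max)`, clamped to the FP8 E4M3 range before being cast to the FP8 dtype.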
