
Commit 3f5705d

[NVFP4] Expand dynamic types, clean-up conditions (#325)
* add DynamicType
* update to use tensor_group
* more condition clean-up
* update global scale creation
* fix conditions, fix tests
* add validation
* update/fix conditions
* Update src/compressed_tensors/quantization/lifecycle/initialize.py
* Update src/compressed_tensors/quantization/quant_args.py
* use explicit condition

---------

Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 8367985 commit 3f5705d

File tree (7 files changed: +135 / −77 lines):

* src/compressed_tensors/quantization/lifecycle/forward.py
* src/compressed_tensors/quantization/lifecycle/initialize.py
* src/compressed_tensors/quantization/quant_args.py
* src/compressed_tensors/quantization/quant_config.py
* src/compressed_tensors/quantization/quant_scheme.py
* src/compressed_tensors/quantization/utils/helpers.py
* tests/test_quantization/lifecycle/test_initialize.py


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 5 additions & 4 deletions
@@ -18,6 +18,7 @@

 import torch
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -190,8 +191,8 @@ def _process_quantization(
     group_size = args.group_size

     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-        if args.strategy == QuantizationStrategy.TENSOR_GROUP:
-            # only valid for activation; remove dim 0
+        n_dims = x.shape
+        if len(n_dims) > 2:
             x = x.squeeze(0)

         output_dtype = dtype if dtype is not None else x.dtype
@@ -255,7 +256,7 @@ def _process_quantization(
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)

-        if args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        if len(n_dims) > 2:
             output = output.unsqueeze(0)

     else:  # covers channel, token and tensor strategies
@@ -359,7 +360,7 @@ def forward_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
     global_scale = getattr(module, f"{base_name}_global_scale", None)

-    if args.dynamic or args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    if args.dynamic in (True, DynamicType.LOCAL):
         # dynamic quantization - determine the scale/zp on the fly
         scale, zero_point = compute_dynamic_scales_and_zp(
             value=value, args=args, module=module, global_scale=global_scale
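The forward-pass change above swaps the strategy check for a membership test on `args.dynamic`. A minimal standalone sketch (using a stand-in enum, not the library class) of why that single test covers both fully dynamic and locally dynamic configs; because `DynamicType` is a `str` enum, the plain string "local" stored via `use_enum_values=True` also matches:

```python
from enum import Enum


class DynamicType(str, Enum):  # stand-in mirroring the enum added in quant_args.py
    LOCAL = "local"


def computes_scales_on_the_fly(dynamic) -> bool:
    # mirrors the new check in forward_quantize: scales/zero-points are computed
    # at runtime when dynamic is True (fully dynamic) or "local" (locally dynamic)
    return dynamic in (True, DynamicType.LOCAL)


assert computes_scales_on_the_fly(True)
assert computes_scales_on_the_fly(DynamicType.LOCAL)
assert computes_scales_on_the_fly("local")    # str enum compares equal to its value
assert not computes_scales_on_the_fly(False)  # static quantization keeps stored scales
```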

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 43 additions & 43 deletions
@@ -156,13 +156,33 @@ def _initialize_scale_zero_point(
     force_zero_point: bool = True,
     scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic:
+    if quantization_args.dynamic is True:
         return

     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)

-    # infer expected scale/zero point shape
+    # 1. Create global_scales for tensor_group
+    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        # TODO: should move to llmcompressor
+        if base_name == "weight":
+            # When applying weight-only FP4 quantization, generate a global_scale
+            # This scale is applied during runtime to ensure that the generated
+            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
+            # which is the expected dtype of NVFP4A16 scales
+            value = generate_global_scale(input_tensor=module.weight)
+            value = value.to(device)
+            init_global_scale = Parameter(value, requires_grad=False)
+        else:
+            init_global_scale = Parameter(
+                torch.empty(1, dtype=torch.float32, device=device),
+                requires_grad=False,
+            )
+        register_offload_parameter(
+            module, f"{base_name}_global_scale", init_global_scale
+        )
+
+    # 2. Infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
     else:
@@ -172,55 +192,35 @@ def _initialize_scale_zero_point(
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+        elif quantization_args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        ):
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))

+    # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
-    # TODO: consider erroring out in the future as if the dtype if not one fo these,
-    # there is likely bug
-
-    if is_fp4(quantization_args=quantization_args) and base_name == "weight":
-        assert quantization_args.strategy == QuantizationStrategy.GROUP
-        scale_dtype = FP8_E4M3_DATA.dtype
-        # When applying weight-only FP4 quantization, generate a global_scale
-        # This scale is applied during runtime to ensure that the generated
-        # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-        # which is the expected dtype of NVFP4A16 scales
-        value = generate_global_scale(input_tensor=module.weight)
-        value = value.to(device)
-        init_global_scale = Parameter(value, requires_grad=False)
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )

-    # initializes empty scale, zero point, and g_idx parameters for the module
-    if is_fp4(quantization_args=quantization_args) and base_name == "input":
-        assert quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP
-        scale_dtype = torch.float32
-        scale_name = f"{base_name}_global_scale"
+    if is_fp4(quantization_args=quantization_args):
+        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
     else:
-        scale_name = f"{base_name}_scale"
-
-    if scale_dtype not in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ] and not is_fp4(quantization_args=quantization_args):
-        scale_dtype = torch.float16
-
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    register_offload_parameter(module, scale_name, init_scale)
+        # TODO: consider erroring out in the future as if the dtype if not one of these,
+        # there is likely bug
+        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+            scale_dtype = torch.float16
+        zp_dtype = quantization_args.pytorch_dtype()
+
+    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
+    # do not init scales for quantzation_args.dynamic == DynamicType.local
+    if not quantization_args.dynamic:
+        init_scale = Parameter(
+            torch.empty(expected_shape, dtype=scale_dtype, device=device),
+            requires_grad=False,
+        )
+        register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
-        if is_fp4(quantization_args=quantization_args):
-            zp_dtype = FP8_E4M3_DATA.dtype
-        else:
-            zp_dtype = quantization_args.pytorch_dtype()
-
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
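Step 2 above now uses the same shape inference for GROUP and TENSOR_GROUP. A hypothetical helper (not part of compressed-tensors) mirroring just that shape logic:

```python
import math
from typing import Optional, Tuple


def expected_scale_shape(
    weight_shape: Tuple[int, int],
    strategy: str,
    group_size: Optional[int] = None,
) -> Tuple[int, int]:
    if strategy == "channel":
        return (weight_shape[0], 1)                   # one scale per output channel
    if strategy in ("group", "tensor_group"):
        num_groups = math.ceil(weight_shape[1] / group_size)
        return (weight_shape[0], max(num_groups, 1))  # one scale per (row, group)
    return (1, 1)                                     # token / tensor style fallback


# e.g. a 256x128 weight quantized with group_size=16 gets a (256, 8) scale tensor
assert expected_scale_shape((256, 128), "tensor_group", group_size=16) == (256, 8)
```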

src/compressed_tensors/quantization/quant_args.py

Lines changed: 44 additions & 9 deletions
@@ -32,6 +32,7 @@
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
+    "DynamicType",
 ]


@@ -101,6 +102,21 @@ class QuantizationStrategy(str, Enum):
     TENSOR_GROUP = "tensor_group"


+class DynamicType(str, Enum):
+    """
+    Enum storing potential dynamic types.
+
+    1. If dynamic is True, all quantization parameters are generated on the fly.
+    2. If dynamic is False, all quantization parameters generated are static.
+    3. If "local" is provided, only local quantization parameters are dynamic.
+
+    Note: "local" is only currently supported for NVFP4.
+
+    """
+
+    LOCAL = "local"
+
+
 class ActivationOrdering(Aliasable, str, Enum):
     """
     Enum storing strategies for activation ordering
@@ -153,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
-    dynamic: bool = False
+    dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
         default=None,
@@ -207,6 +223,12 @@ def validate_actorder(cls, value) -> Optional[ActivationOrdering]:

         return value

+    @field_validator("dynamic", mode="before")
+    def validate_dynamic(cls, value) -> Union[DynamicType, bool]:
+        if isinstance(value, str):
+            return DynamicType(value.lower())
+        return value
+
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         # extract user-passed values from dictionary
@@ -257,18 +279,31 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
             if strategy not in (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
+                QuantizationStrategy.TENSOR_GROUP,
             ):
                 raise ValueError(
-                    f"One of {QuantizationStrategy.TOKEN} or "
-                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
-                    "quantization",
+                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
+                    "must be used for dynamic quantization",
                 )
+
+            if (
+                dynamic == DynamicType.LOCAL
+                and strategy != QuantizationStrategy.TENSOR_GROUP
+            ):
+                raise ValueError("local is only supported for strategy tensor_group")
+
             if observer is not None:
-                if observer != "memoryless":  # avoid annoying users with old configs
-                    warnings.warn(
-                        "No observer is used for dynamic quantization, setting to None"
-                    )
-                    observer = None
+                if dynamic is True:  # checking if dynamic is True, not "local"
+                    if (
+                        observer != "memoryless"
+                    ):  # avoid annoying users with old configs
+                        warnings.warn(
+                            "No observer is used for dynamic quantization, setting to None"
+                        )
+                        observer = None
+            else:
+                if dynamic == DynamicType.LOCAL:
+                    observer = "minmax"

         elif observer is None:
             # default to minmax for non-dynamic cases
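A usage sketch based on the test cases added in this commit, exercising the new validators; imports follow the module path shown in the diff above:

```python
from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs

# The new "dynamic" field validator coerces the string "local" to DynamicType.LOCAL
args = QuantizationArgs(
    num_bits=4,
    type="float",
    strategy="tensor_group",
    group_size=16,
    dynamic="local",
)
assert args.dynamic == DynamicType.LOCAL  # str-based enum compares equal to "local"

# The model validator rejects "local" for strategies other than tensor_group
try:
    QuantizationArgs(num_bits=8, type="int", strategy="token", dynamic="local")
except ValueError as err:
    print(err)  # expected: validation error saying local requires tensor_group
```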

src/compressed_tensors/quantization/quant_config.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 from typing import Dict, List, Optional, Union

 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -251,7 +251,7 @@ def requires_calibration_data(self):

         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
-                if not scheme.input_activations.dynamic:
+                if scheme.input_activations.dynamic in (False, DynamicType.LOCAL):
                     return True
             if scheme.output_activations is not None:
                 if not scheme.output_activations.dynamic:
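The effect of this change: input-activation schemes that are static or locally dynamic still require calibration data (the global scale is calibrated), and only fully dynamic activations skip it. A quick check of the condition, assuming the import path shown in the diff:

```python
from compressed_tensors.quantization.quant_args import DynamicType

# (dynamic setting, should require calibration data)
for dynamic, needs_calibration in [(False, True), (DynamicType.LOCAL, True), (True, False)]:
    assert (dynamic in (False, DynamicType.LOCAL)) is needs_calibration
```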

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 5 additions & 7 deletions
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional

 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -104,22 +105,19 @@ def is_preset_scheme(name: str) -> bool:
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
-        strategy=QuantizationStrategy.GROUP,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
         symmetric=True,
         dynamic=False,
         group_size=16,
     )
 )

-# TODO: the local scales are dynamic, the global scale is static/calibrated
-# We could potentially extend the dynamic kwarg so that is goes
-# beyond being just a boolean - however we may also want a dynamically
-# generated global scale, so we could use that to separate between the two
+
 NVFP4 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
-        strategy=QuantizationStrategy.GROUP,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
         symmetric=True,
         dynamic=False,
         group_size=16,
@@ -129,7 +127,7 @@ def is_preset_scheme(name: str) -> bool:
         type=QuantizationType.FLOAT,
         strategy=QuantizationStrategy.TENSOR_GROUP,
         symmetric=True,
-        dynamic=False,
+        dynamic=DynamicType.LOCAL,
         group_size=16,
     ),
 )
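After this change the NVFP4 preset pairs static tensor_group weights with locally dynamic tensor_group activations. A hand-built sketch of an equivalent scheme dict using the field values visible in the diff; the `input_activations` key name is an assumption based on the test parametrization, since this hunk does not show it:

```python
from compressed_tensors.quantization.quant_args import (
    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

nvfp4_like = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=False,              # static local scales + calibrated global scale
        group_size=16,
    ),
    input_activations=QuantizationArgs(  # assumed key name
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=DynamicType.LOCAL,  # local scales computed on the fly at runtime
        group_size=16,
    ),
)
```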

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 3 additions & 3 deletions
@@ -171,10 +171,10 @@ def compute_dynamic_scales_and_zp(
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
     elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        # per group dynamic quantization - only valid for
-        # activations
+        if len(value.shape) > 2:
+            value = value.squeeze(0)
+
         dim = {0, 1}
-        value = value.squeeze(0)
         reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
         keep_dims = False
         value = torch.reshape(
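The helpers.py change makes the leading-dim squeeze conditional, so the same tensor_group path handles both 3-D activations and 2-D tensors. A standalone torch sketch of just that reshaping step (not the library function itself):

```python
import torch


def _maybe_squeeze_leading_dim(value: torch.Tensor) -> torch.Tensor:
    # tensor_group dynamic quantization: drop a leading batch dim only if present
    if len(value.shape) > 2:
        value = value.squeeze(0)
    return value


activation = torch.randn(1, 128, 64)  # (batch, seq_len, hidden)
weight_like = torch.randn(128, 64)    # already 2-D; left untouched

assert _maybe_squeeze_leading_dim(activation).shape == (128, 64)
assert _maybe_squeeze_leading_dim(weight_like).shape == (128, 64)

# dims {0, 1} are kept; the remaining (group) dim is reduced over
dim = {0, 1}
reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
assert reduce_dims == (2,)
```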

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 33 additions & 9 deletions
@@ -156,9 +156,23 @@ def test_initialize_module_for_quantization_offloaded(
             None,
         ),
         (
-            QuantizationArgs(strategy="group", group_size=16, type="float", num_bits=4),
+            QuantizationArgs(
+                strategy="tensor_group", group_size=16, type="float", num_bits=4
+            ),
             None,
         ),
+        (
+            QuantizationArgs(
+                strategy="tensor_group", group_size=16, type="float", num_bits=4
+            ),
+            QuantizationArgs(
+                strategy="tensor_group",
+                group_size=16,
+                type="float",
+                num_bits=4,
+                dynamic="local",
+            ),
+        ),
         (
             QuantizationArgs(strategy="block"),
             QuantizationArgs(strategy="block"),
@@ -184,13 +198,19 @@ def test_initialize_quantization_parameters(weights, input_activations):
             continue
         q_param_name = Q_PARAM_NAMES[q_type]

-        if args.num_bits == 4 and args.type == QuantizationType.FLOAT:
-            assert hasattr(layer, "weight_global_scale")
-            assert layer.weight_global_scale.dtype == torch.float32
-            assert layer.weight_global_scale.numel() == 1
-            assert layer.weight_scale.dtype == FP8_E4M3_DATA.dtype
+        if args.strategy == QuantizationStrategy.TENSOR_GROUP:
+            if q_type == "weights":
+                assert hasattr(layer, "weight_global_scale")
+                assert layer.weight_global_scale.dtype == torch.float32
+                assert layer.weight_global_scale.numel() == 1
+                assert layer.weight_scale.dtype == FP8_E4M3_DATA.dtype
+            elif q_type == "input_activations":
+                assert hasattr(layer, "input_global_scale")
+                assert layer.input_global_scale.dtype == torch.float32
+                assert layer.input_global_scale.numel() == 1
         else:
             assert not hasattr(layer, "weight_global_scale")
+            assert not hasattr(layer, "input_global_scale")

         # scale and zero point
         if args.strategy == QuantizationStrategy.TENSOR:
@@ -199,7 +219,10 @@
         elif args.strategy == QuantizationStrategy.CHANNEL:  # only weight
             expected_shape = (layer.weight.shape[0], 1)

-        elif args.strategy == QuantizationStrategy.GROUP:  # only weight
+        elif args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        ):
             num_groups = math.ceil(layer.weight.shape[1] / args.group_size)
             expected_shape = (layer.weight.shape[0], max(num_groups, 1))

@@ -209,8 +232,9 @@
         elif args.strategy == QuantizationStrategy.TOKEN:
            expected_shape = (1, 1)

-        assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
-        assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape
+        if not args.dynamic:
+            assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
+            assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape

         # g_idx
         if args.actorder == ActivationOrdering.GROUP:
