
Commit d22a137

initial commit
1 parent 5742998 commit d22a137

4 files changed: +101 -28 lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 28 additions & 18 deletions
@@ -20,6 +20,7 @@
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -359,18 +360,22 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:

-    scaled = x / scale
-    if zero_point is not None:
-        scaled += zero_point.to(x.dtype)
-    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
-    clamped_value = torch.clamp(
-        scaled,
-        q_min,
-        q_max,
-    )
-    quantized_value = round_to_quantized_type(clamped_value, args)
-    if dtype is not None:
-        quantized_value = quantized_value.to(dtype)
+    if args.num_bits == 4 and args.type == QuantizationType.FLOAT:
+        # apply fp4 quant
+        return quantized_value
+    else:
+        scaled = x / scale
+        if zero_point is not None:
+            scaled += zero_point.to(x.dtype)
+        # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+        clamped_value = torch.clamp(
+            scaled,
+            q_min,
+            q_max,
+        )
+        quantized_value = round_to_quantized_type(clamped_value, args)
+        if dtype is not None:
+            quantized_value = quantized_value.to(dtype)

     return quantized_value
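The fp4 branch above is still a stub: `quantized_value` is returned before it is ever assigned, with only the `# apply fp4 quant` comment marking the intent. Purely as an illustration of what that step could look like, and not what this commit implements, quantization onto NVFP4's E2M1 value grid might be sketched as follows; the helper name `fake_quantize_fp4` and the explicit value table are assumptions:

import torch

# Representable E2M1 (NVFP4) magnitudes -- assumed here for illustration only
FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quantize_fp4(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale into the fp4 range and clamp to the largest representable magnitude
    scaled = torch.clamp(x / scale, -6.0, 6.0)
    grid = FP4_E2M1_VALUES.to(device=scaled.device, dtype=scaled.dtype)
    # snap each magnitude to the nearest grid point, then restore the sign
    idx = torch.argmin((scaled.abs().unsqueeze(-1) - grid).abs(), dim=-1)
    return torch.sign(scaled) * grid[idx]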

@@ -382,13 +387,18 @@ def _dequantize(
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    dequant_value = x_q.to(scale.dtype)

-    if zero_point is not None:
-        dequant_value = dequant_value - zero_point.to(scale.dtype)
-    dequant_value = dequant_value * scale
+    if args.num_bits == 4 and args.type == QuantizationType.FLOAT:
+        # apply fp4 dequant
+        dequant_value = None
+    else:
+        dequant_value = x_q.to(scale.dtype)
+
+        if zero_point is not None:
+            dequant_value = dequant_value - zero_point.to(scale.dtype)
+        dequant_value = dequant_value * scale

-    if dtype is not None:
-        dequant_value = dequant_value.to(dtype)
+        if dtype is not None:
+            dequant_value = dequant_value.to(dtype)

     return dequant_value
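The dequantize branch is likewise a placeholder (`dequant_value = None`). If the quantized values are kept as real numbers on the E2M1 grid, as in the sketch above, the corresponding dequantize step would presumably just undo the scaling; again a hedged sketch, not the commit's implementation:

def fake_dequantize_fp4(
    x_q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype = None
) -> torch.Tensor:
    # x_q holds values on the +/-6 E2M1 grid; multiply the scale back in
    dequant_value = x_q.to(scale.dtype) * scale
    if dtype is not None:
        dequant_value = dequant_value.to(dtype)
    return dequant_value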

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 29 additions & 1 deletion
@@ -30,6 +30,8 @@
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
+    FP4_NVFP4_DATA,
+    FP8_E4M3_DATA,
     disable_hf_hook,
     has_offloaded_params,
     register_offload_parameter,
@@ -161,7 +163,33 @@ def _initialize_scale_zero_point(
         expected_shape = (weight_shape[0], max(num_groups, 1))

     scale_dtype = module.weight.dtype
-    if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+
+    # NVFP4 support; use FP8 scales
+    # For weight quant, attach global scales for NVFP4
+    if (
+        base_name == "weight"
+        and quantization_args.num_bits == 4
+        and quantization_args.type == QuantizationType.FLOAT
+    ):
+        scale_dtype = FP8_E4M3_DATA.dtype
+        # create and attach nvfp4 data
+        tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
+        # Setting data for now - could possibly be handled later in the pipeline
+        values = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax
+        # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
+        init_global_scale = Parameter(
+            values.to(dtype=torch.float32, device=device), requires_grad=False
+        )
+        register_offload_parameter(
+            module, f"{base_name}_global_scale", init_global_scale
+        )
+
+    if scale_dtype not in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        FP8_E4M3_DATA.dtype,
+    ]:
         scale_dtype = torch.float16

     # initializes empty scale, zero point, and g_idx parameters for the module
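The global scale attached here fits NVFP4's two-level scaling: per-group scales are stored in FP8 (E4M3), and a single FP32 global scale keeps them inside the FP8 range, sized so the tensor's absolute max maps onto FP8_max * FP4_max. A small numeric sketch of the expression above, with made-up weights:

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0
fp4_max = 6.0                                    # largest E2M1 magnitude

weight = torch.randn(4096, 4096) * 0.02          # stand-in for module.weight.data
tensor_amax = weight.abs().max().to(torch.float32)

# same expression as `FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax`
global_scale = fp8_max * fp4_max / tensor_amax
# e.g. tensor_amax ~= 0.1  ->  global_scale ~= 448 * 6 / 0.1 = 26880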

src/compressed_tensors/quantization/quant_args.py

Lines changed: 26 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import warnings
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Optional, Union

@@ -24,15 +25,38 @@

 __all__ = [
     "FP8_DTYPE",
+    "FP8_E4M3_DATA",
+    "FP4_NVFP4_DATA",
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
 ]

+# TODO: Remove soon in favour of a more descriptive FloatArgs
 FP8_DTYPE = torch.float8_e4m3fn

+
+@dataclass
+class FloatArgs:
+    exponent: int
+    mantissa: int
+    bits: int
+    max: float
+    min: float
+    dtype: Optional[torch.dtype] = None
+
+
+FP8_E4M3_DATA = FloatArgs(
+    exponent=4,
+    mantissa=3,
+    bits=8,
+    max=torch.finfo(torch.float8_e4m3fn).max,
+    min=torch.finfo(torch.float8_e4m3fn).min,
+    dtype=torch.float8_e4m3fn,
+)
+FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)
+

 class QuantizationType(str, Enum):
     """
@@ -233,6 +257,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
         return model

     def pytorch_dtype(self) -> torch.dtype:
+        # TODO: required for the compressor
+        # Add FP4_nvfp4 type when updating naive_compressor
         if self.type == QuantizationType.FLOAT:
             return FP8_DTYPE
         elif self.type == QuantizationType.INT:
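With FloatArgs in place, format limits can be read from data rather than re-derived from torch.finfo at each call site (the compressor update the TODO above mentions would presumably do the same). A small usage sketch, assuming the definitions above are importable as shown:

from compressed_tensors.quantization.quant_args import FP4_NVFP4_DATA, FP8_E4M3_DATA

# FP8 E4M3 limits come straight from torch.finfo
print(FP8_E4M3_DATA.bits, FP8_E4M3_DATA.max, FP8_E4M3_DATA.min)       # 8 448.0 -448.0
# no torch dtype is assigned for NVFP4 (E2M1), so its limits are spelled out and dtype stays None
print(FP4_NVFP4_DATA.bits, FP4_NVFP4_DATA.max, FP4_NVFP4_DATA.dtype)  # 4 6.0 None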

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 18 additions & 9 deletions
@@ -17,7 +17,8 @@

 import torch
 from compressed_tensors.quantization.quant_args import (
-    FP8_DTYPE,
+    FP4_NVFP4_DATA,
+    FP8_E4M3_DATA,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -73,6 +74,7 @@ def calculate_qparams(
     zp_dtype = quantization_args.pytorch_dtype()

     if quantization_args.symmetric:
+        # TODO: update for NVFP4 when applying observers
        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        scales = max_val_pos / (float(bit_range) / 2)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
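The TODO flags that the symmetric-scale formula above (max_val_pos / (bit_range / 2)) does not yet account for NVFP4's grouped, two-level scales. One plausible shape for that observer update, offered only as a hedged sketch; the group handling and the multiply-by-global-scale convention are assumptions, not this commit's code:

import torch


def nvfp4_group_scales(
    weight: torch.Tensor, group_size: int, global_scale: torch.Tensor
) -> torch.Tensor:
    # per-group absolute max over contiguous groups of the last dimension
    # (assumes the last dimension divides evenly by group_size)
    grouped = weight.reshape(weight.shape[0], -1, group_size)
    amax = grouped.abs().amax(dim=-1)
    # local scale maps the largest element of each group onto the E2M1 max (6.0)
    scales = amax / 6.0
    # store the local scales in FP8 E4M3, rescaled by the FP32 global scale
    return (scales * global_scale).to(torch.float8_e4m3fn)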
@@ -138,14 +140,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
         q_max = torch.tensor(bit_range / 2 - 1, device=device)
         q_min = torch.tensor(-bit_range / 2, device=device)
     elif quantization_args.type == QuantizationType.FLOAT:
-        if quantization_args.num_bits != 8:
-            raise ValueError(
-                "Floating point quantization is only supported for 8 bits,"
-                f"got {quantization_args.num_bits}"
-            )
-        fp_range_info = torch.finfo(FP8_DTYPE)
-        q_max = torch.tensor(fp_range_info.max, device=device)
-        q_min = torch.tensor(fp_range_info.min, device=device)
+        if quantization_args.num_bits == 8:
+            """
+            if quantization_args.num_bits != 8:
+                raise ValueError(
+                    "Floating point quantization is only supported for 8 bits,"
+                    f"got {quantization_args.num_bits}"
+                )
+            """
+            q_max = torch.tensor(FP8_E4M3_DATA.max, device=device)
+            q_min = torch.tensor(FP8_E4M3_DATA.min, device=device)
+        else:
+            # nvfp4 ranges
+            assert quantization_args.num_bits == 4
+            q_max = torch.tensor(FP4_NVFP4_DATA.max, device=device)
+            q_min = torch.tensor(FP4_NVFP4_DATA.min, device=device)
     else:
         raise ValueError(f"Invalid quantization type {quantization_args.type}")
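With the new branch, calculate_range hands back the E4M3 limits for 8-bit float and the E2M1 limits for 4-bit float. A usage sketch; the (q_min, q_max) return order is assumed from the surrounding code:

from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import calculate_range

fp8_args = QuantizationArgs(num_bits=8, type="float")
fp4_args = QuantizationArgs(num_bits=4, type="float")

q_min, q_max = calculate_range(fp8_args, device="cpu")   # -448.0, 448.0
q_min, q_max = calculate_range(fp4_args, device="cpu")   # -6.0, 6.0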
