Commit be02849

update
1 parent d22a137 commit be02849

File tree: 4 files changed (+42, −37 lines)


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 18 additions & 26 deletions
@@ -360,22 +360,18 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
 
-    if args.num_bits == 4 and args.type == QuantizationType.FLOAT:
-        # apply fp4 quant
-        return quantized_value
-    else:
-        scaled = x / scale
-        if zero_point is not None:
-            scaled += zero_point.to(x.dtype)
-        # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
-        clamped_value = torch.clamp(
-            scaled,
-            q_min,
-            q_max,
-        )
-        quantized_value = round_to_quantized_type(clamped_value, args)
-        if dtype is not None:
-            quantized_value = quantized_value.to(dtype)
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
+        q_min,
+        q_max,
+    )
+    quantized_value = round_to_quantized_type(clamped_value, args)
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
 
     return quantized_value
 
@@ -388,17 +384,13 @@ def _dequantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
 
-    if args.num_bits == 4 and args.type == QuantizationType.FLOAT:
-        # apply fp4 deqquant
-        dequant_value = None
-    else:
-        dequant_value = x_q.to(scale.dtype)
+    dequant_value = x_q.to(scale.dtype)
 
-        if zero_point is not None:
-            dequant_value = dequant_value - zero_point.to(scale.dtype)
-        dequant_value = dequant_value * scale
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value * scale
 
-        if dtype is not None:
-            dequant_value = dequant_value.to(dtype)
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
 
     return dequant_value
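
As a quick reference, here is a minimal standalone sketch of the symmetric round trip that the simplified _quantize/_dequantize now perform. The hand-picked scale, q_min, and q_max and the int8 round/cast are stand-ins for values that calculate_qparams, calculate_range, and round_to_quantized_type would normally supply:

import torch

# Toy symmetric int8 round trip mirroring the simplified helpers above.
# scale, q_min and q_max are hand-picked here; in the library they come from
# calculate_qparams and calculate_range.
x = torch.tensor([0.05, -1.30, 2.40])
scale = torch.tensor(2.40 / 127)               # map max |x| onto the int8 max
q_min, q_max = -128.0, 127.0

scaled = x / scale                              # no zero_point in the symmetric case
clamped = torch.clamp(scaled, q_min, q_max)     # clamp before the cast/round
x_q = torch.round(clamped).to(torch.int8)       # stand-in for round_to_quantized_type

x_dq = x_q.to(scale.dtype) * scale              # dequantize: cast back, then rescale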

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 3 additions & 2 deletions
@@ -25,6 +25,7 @@
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
+    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -169,7 +170,7 @@ def _initialize_scale_zero_point(
     if (
         base_name == "weight"
         and quantization_args.num_bits == 4
-        and quantization_args.strategy == QuantizationStrategy.FLOAT
+        and quantization_args.type == QuantizationType.FLOAT
     ):
         scale_dtype = FP8_E4M3_DATA.dtype
         # create and attach nvfp4 data
@@ -188,7 +189,7 @@ def _initialize_scale_zero_point(
         torch.float16,
         torch.bfloat16,
         torch.float32,
-        FP8_DATA.dtype,
+        FP8_E4M3_DATA.dtype,
     ]:
         scale_dtype = torch.float16
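
Note the fix in the weight check: strategy describes how scales are laid out (tensor, channel, group, ...), while type distinguishes integer from float formats, so the FP4 guard has to key off the type. A small sketch of the corrected predicate, assuming the enums are importable from quant_args as the diff suggests:

from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationType,
)

def is_fp4_weight(base_name: str, args: QuantizationArgs) -> bool:
    # 4-bit float weight quantization (NVFP4), independent of the scale strategy
    return (
        base_name == "weight"
        and args.num_bits == 4
        and args.type == QuantizationType.FLOAT
    )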

src/compressed_tensors/quantization/quant_args.py

Lines changed: 5 additions & 1 deletion
@@ -289,7 +289,11 @@ def round_to_quantized_type(
     """
     original_dtype = tensor.dtype
     if args.type == QuantizationType.FLOAT:
-        rounded = tensor.to(FP8_DTYPE)
+        if args.num_bits == 8:
+            rounded = tensor.to(FP8_E4M3_DATA.dtype)
+        elif args.num_bits == 4:
+            # TODO: cast to whatever value we want fp4 to be post quantization/clamping
+            rounded = tensor.to()
 elif args.type == QuantizationType.INT:
     rounded = torch.round(tensor)
 else:
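
For 8-bit floats the rounding is just a cast to FP8_E4M3_DATA.dtype, but torch has no native 4-bit float dtype, hence the TODO on the FP4 branch. One possible stand-in, purely a sketch (the helper name and the nearest-value tie handling are assumptions, not this library's method), is to snap values to the nearest representable E2M1 magnitude while keeping the original dtype:

import torch

# Positive magnitudes representable in FP4 E2M1 (the NVFP4 element format).
_FP4_E2M1_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def round_to_fp4_e2m1(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: snap each value to the nearest E2M1 level.
    levels = _FP4_E2M1_LEVELS.to(device=tensor.device, dtype=tensor.dtype)
    sign = torch.sign(tensor)
    idx = (tensor.abs().unsqueeze(-1) - levels).abs().argmin(dim=-1)
    return sign * levels[idx]

x = torch.tensor([0.2, -1.3, 5.1, 7.0])
round_to_fp4_e2m1(x)   # tensor([0.0000, -1.5000, 6.0000, 6.0000])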

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 16 additions & 8 deletions
@@ -76,7 +76,22 @@ def calculate_qparams(
     if quantization_args.symmetric:
         # TODO: update for NVFP4 when applying observers
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
+
+        if (
+            quantization_args.num_bits == 4
+            and quantization_args.type == QuantizationType.FLOAT
+        ):
+            # TODO: how do we pass in the global scale?
+            # An observer is attached per module, we can conditionally pass in
+            # the global scale
+            scale = global_scale * (max_val_pos / FP4_NVFP4_DATA.max)
+            scale = scale.to(FP8_E4M3_DATA.dtype).to(torch.float32)
+            scale = scale / global_scale
+        else:
+            # Divide over bit range over max value?
+            scales = max_val_pos / (float(bit_range) / 2)
+
+        # needed for fp4?
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
@@ -141,13 +156,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
         q_min = torch.tensor(-bit_range / 2, device=device)
     elif quantization_args.type == QuantizationType.FLOAT:
         if quantization_args.num_bits == 8:
-            """
-            if quantization_args.num_bits != 8:
-                raise ValueError(
-                    "Floating point quantization is only supported for 8 bits,"
-                    f"got {quantization_args.num_bits}"
-                )
-            """
             q_max = torch.tensor(FP8_E4M3_DATA.max, device=device)
             q_min = torch.tensor(FP8_E4M3_DATA.min, device=device)
         else:
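
The new branch in calculate_qparams is the two-level NVFP4 scheme: the per-group scale is computed against the FP4 maximum, lifted into FP8 range by a global scale, quantized to E4M3, then divided back out. A worked sketch with hand-picked numbers (6.0 as the E2M1 max and torch.float8_e4m3fn as the equivalent of FP8_E4M3_DATA.dtype are assumptions based on the names in the diff and need a torch build with float8 support; how global_scale gets passed in is still the open TODO above):

import torch

max_val_pos = torch.tensor(0.75)     # per-group max |w|
global_scale = torch.tensor(256.0)   # assumed per-tensor global scale

scale = global_scale * (max_val_pos / 6.0)                # 32.0, within FP8 range
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)   # quantize the scale itself
scale = scale / global_scale                              # 0.125: the local FP4 scale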
