
Commit 682c110

fix condition

1 parent: 271a936

3 files changed: +13 additions, -10 deletions


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 3 additions & 3 deletions
@@ -377,9 +377,9 @@ def compress(
 
         compressed_state_dict = state_dict
 
-        quantized_modules_to_args: Dict[str, QuantizationArgs] = (
-            map_modules_to_quant_args(model)
-        )
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
 
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
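For orientation, below is a rough sketch of the kind of mapping the reformatted annotation above describes: a dict from submodule names to their QuantizationArgs. It is an assumption for illustration only, not the implementation of map_modules_to_quant_args in this repository; the quantization_scheme attribute and its weights field are hypothetical names.

# Hypothetical sketch, not the library's implementation: builds a
# Dict[str, QuantizationArgs] by walking the model's submodules and collecting
# weight-quantization args from an assumed `quantization_scheme` attribute.
from typing import Dict

import torch.nn as nn

from compressed_tensors.quantization import QuantizationArgs


def sketch_map_modules_to_quant_args(model: nn.Module) -> Dict[str, QuantizationArgs]:
    mapping: Dict[str, QuantizationArgs] = {}
    for name, submodule in model.named_modules():
        scheme = getattr(submodule, "quantization_scheme", None)
        if scheme is not None and getattr(scheme, "weights", None) is not None:
            mapping[name] = scheme.weights
    return mapping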

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 2 additions & 3 deletions
@@ -193,17 +193,16 @@ def _initialize_scale_zero_point(
             module, f"{base_name}_global_scale", init_global_scale
         )
 
-
     # TODO: consider erroring out in the future as if the dtype if not one fo these,
     # there is likely bug
-
+
     if scale_dtype not in [
         torch.float16,
         torch.bfloat16,
         torch.float32,
         FP8_E4M3_DATA.dtype,
     ]:
-        scale_dtype = torch.float16
+        scale_dtype = torch.float16
 
     # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
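A minimal, self-contained illustration of the dtype fallback in this hunk: scale dtypes outside the supported float set default to torch.float16. The SUPPORTED_SCALE_DTYPES set and resolve_scale_dtype helper below are invented for the example, and torch.float8_e4m3fn is used directly in place of FP8_E4M3_DATA.dtype to keep the snippet runnable.

import torch

# Supported scale dtypes, mirroring the list in the hunk above
# (torch.float8_e4m3fn stands in for FP8_E4M3_DATA.dtype).
SUPPORTED_SCALE_DTYPES = {
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float8_e4m3fn,
}


def resolve_scale_dtype(scale_dtype: torch.dtype) -> torch.dtype:
    # Fall back to float16 for anything outside the supported set.
    return scale_dtype if scale_dtype in SUPPORTED_SCALE_DTYPES else torch.float16


print(resolve_scale_dtype(torch.int8))      # torch.float16
print(resolve_scale_dtype(torch.bfloat16))  # torch.bfloat16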

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 8 additions & 4 deletions
@@ -82,7 +82,6 @@ def calculate_qparams(
         zp_dtype = FP8_E4M3_DATA.dtype
 
     if quantization_args.symmetric:
-        # TODO: update for NVFP4 when applying observers
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
         if (
@@ -96,10 +95,15 @@
             # Divide over bit range over max value?
             scales = max_val_pos / (float(bit_range) / 2)
 
-        # TODO: clamp not implemented for FP8 - we shouldn't need to clamp this anyway as we're
-        # casting to FP8 on line 92?
-        if scales.dtype != FP8_E4M3_DATA.dtype:
+        if scales.dtype == FP8_E4M3_DATA.dtype:
+            # use the next largest fp8 value from 0
+            # Optionally, we swap to use the reciporcal
+            scales = torch.where(
+                scales == 0, torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype), scales
+            )
+        else:
             scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
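A minimal sketch of the zero-scale guard this hunk introduces, shown in float32 so it runs on any recent torch build; the commit applies the same torch.where replacement directly to FP8 (float8_e4m3fn) scale tensors, where torch.clamp is not implemented. The guard_zero_scales helper is illustrative, not part of the library.

import torch


def guard_zero_scales(scales: torch.Tensor, is_fp8: bool) -> torch.Tensor:
    if is_fp8:
        # Replace exact zeros with a small representable value (0.125 in the
        # commit) so later divisions by the scale cannot produce inf/nan.
        return torch.where(scales == 0, torch.full_like(scales, 0.125), scales)
    # Non-FP8 scales can simply be clamped away from zero.
    return torch.clamp(scales, min=torch.finfo(torch.float32).eps)


scales = torch.tensor([0.0, 0.5, 2.0])
print(guard_zero_scales(scales, is_fp8=True))   # tensor([0.1250, 0.5000, 2.0000])
print(guard_zero_scales(scales, is_fp8=False))  # zero clamped to float32 eps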
