Commit 35d98d5

per tensor input scales are never good???
1 parent 682c110 commit 35d98d5


3 files changed: +26 −15 lines

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 16 additions & 13 deletions
@@ -175,23 +175,26 @@ def _initialize_scale_zero_point(
 
     # NVFP4 support; use FP8 scales
     # For weight quant, attach global scales for NVFP4
-    # TODO: NVFP4 Scheme
     if (
         quantization_args.num_bits == 4
         and quantization_args.type == QuantizationType.FLOAT
     ):
-        scale_dtype = FP8_E4M3_DATA.dtype
-        # create and attach nvfp4 data
-        tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
-        # Setting data for now - could possibly be handled later in the pipeline
-        value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
-        # TODO: use model.weight.dtype after checking
-        value = value.to(torch.float32).to(device)
-        # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
-        init_global_scale = Parameter(value, requires_grad=False)
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )
+        if base_name == "weight":
+            scale_dtype = FP8_E4M3_DATA.dtype
+            # create and attach nvfp4 data
+            tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
+            # Setting data for now - could possibly be handled later in the pipeline
+            value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
+            # TODO: use model.weight.dtype after checking
+            value = value.to(torch.float32).to(device)
+            # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
+            init_global_scale = Parameter(value, requires_grad=False)
+            register_offload_parameter(
+                module, f"{base_name}_global_scale", init_global_scale
+            )
+        else:
+            # input scales should be float32
+            scale_dtype = torch.float32
 
     # TODO: consider erroring out in the future as if the dtype if not one fo these,
     # there is likely bug
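
The weight branch keeps NVFP4's two-level scaling: a float32 global scale sized from the tensor amax, with per-group scales later stored in FP8 relative to it. Below is a minimal numeric sketch of that arithmetic, not the library code, assuming FP8_E4M3_DATA.max is 448.0 and FP4_E2M1_DATA.max is 6.0, with a made-up 4096x4096 weight and group_size=16.

import torch

FP8_E4M3_MAX = 448.0   # largest finite float8 e4m3 value (assumed)
FP4_E2M1_MAX = 6.0     # largest 4-bit e2m1 value (assumed)

weight = torch.randn(4096, 4096, dtype=torch.bfloat16)

# Per-tensor absolute maximum of the weight, computed in float32 for stability
tensor_amax = weight.abs().max().to(torch.float32)

# Global scale chosen so that (FP8 group scale) * (FP4 value) can span the
# weight's dynamic range; it is kept as a float32 parameter in the commit above
global_scale = FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax

# Per-group (group_size=16) scales, normalized by the global scale and stored in FP8
group_amax = weight.abs().reshape(-1, 16).amax(dim=1).to(torch.float32)
group_scales = (global_scale * (group_amax / FP4_E2M1_MAX)).to(torch.float8_e4m3fn)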

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 9 additions & 1 deletion
@@ -108,7 +108,15 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
-    )
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+        observer=None,
+    ),
 )
 
 # 8 bit integer weights and 8 bit activations quantization
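
This preset change adds static (dynamic=False), symmetric, per-tensor FP4 input activation quantization alongside the group-16 FP4 weights. As a toy illustration of the concern in the commit title, not library code: a single static per-tensor scale is set by the largest activation, so one outlier makes the FP4 step coarse for everything else. The tensor shape and values below are invented, and the E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} with max 6.0 are assumed.

import torch

FP4_E2M1_MAX = 6.0
# Nonnegative values representable by FP4 e2m1 (sign handled separately)
fp4_grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

acts = torch.full((1, 4096), 0.05)
acts[0, 0] = 20.0                                   # single outlier activation

per_tensor_scale = acts.abs().max() / FP4_E2M1_MAX  # ~3.33, set by the outlier

# Round each scaled magnitude to the nearest representable FP4 value
scaled = (acts / per_tensor_scale).abs()
nearest = fp4_grid[(scaled.unsqueeze(-1) - fp4_grid).abs().argmin(dim=-1)]

print((nearest == 0).float().mean())  # ~0.9998: nearly all small activations flush to zero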

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -87,8 +87,8 @@ def calculate_qparams(
     if (
         quantization_args.num_bits == 4
         and quantization_args.type == QuantizationType.FLOAT
+        and global_scale is not None
     ):
-        assert global_scale is not None
         scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) # Not needed
         scales = scales.to(FP8_E4M3_DATA.dtype)
     else:
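
With the assert replaced by a condition, calculate_qparams only takes the FP8-scale path when a global scale is actually attached; FP4 tensors without one (such as the new per-tensor input activations) fall through to the regular scale computation instead of raising. A minimal sketch of that branch follows, under the same assumed constants; sketch_fp4_scales is a hypothetical stand-in, not the library function.

import torch
from typing import Optional

FP4_E2M1_MAX = 6.0  # assumed FP4_E2M1_DATA.max

def sketch_fp4_scales(max_val_pos: torch.Tensor,
                      global_scale: Optional[torch.Tensor]) -> torch.Tensor:
    if global_scale is not None:
        # NVFP4 weight path: per-group scales normalized by the float32
        # global scale, then stored in FP8 e4m3
        scales = global_scale * (max_val_pos / FP4_E2M1_MAX)
        return scales.to(torch.float8_e4m3fn)
    # No global scale (e.g. per-tensor input activations): plain float32 scale
    return (max_val_pos / FP4_E2M1_MAX).to(torch.float32)

# Example: per-group weight amaxes with a global scale vs. a per-tensor input amax
group_amax = torch.tensor([0.8, 1.2, 0.3])
weight_scales = sketch_fp4_scales(group_amax, global_scale=torch.tensor(2240.0))
input_scale = sketch_fp4_scales(torch.tensor(4.5), global_scale=None)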
