@@ -175,23 +175,26 @@ def _initialize_scale_zero_point(
 
     # NVFP4 support; use FP8 scales
     # For weight quant, attach global scales for NVFP4
-    # TODO: NVFP4 Scheme
     if (
         quantization_args.num_bits == 4
         and quantization_args.type == QuantizationType.FLOAT
     ):
-        scale_dtype = FP8_E4M3_DATA.dtype
-        # create and attach nvfp4 data
-        tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
-        # Setting data for now - could possibly be handled later in the pipeline
-        value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
-        # TODO: use model.weight.dtype after checking
-        value = value.to(torch.float32).to(device)
-        # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
-        init_global_scale = Parameter(value, requires_grad=False)
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )
+        if base_name == "weight":
+            scale_dtype = FP8_E4M3_DATA.dtype
+            # create and attach nvfp4 data
+            tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
+            # Setting data for now - could possibly be handled later in the pipeline
+            value = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / tensor_amax
+            # TODO: use model.weight.dtype after checking
+            value = value.to(torch.float32).to(device)
+            # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
+            init_global_scale = Parameter(value, requires_grad=False)
+            register_offload_parameter(
+                module, f"{base_name}_global_scale", init_global_scale
+            )
+        else:
+            # input scales should be float32
+            scale_dtype = torch.float32
 
     # TODO: consider erroring out in the future as if the dtype is not one of these,
     # there is likely a bug
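For reference, the `_global_scale` value computed in the hunk above follows the usual NVFP4 two-level scaling recipe: the per-tensor global scale is sized so that the largest per-block decode scale lands at the top of the FP8 E4M3 range. Below is a minimal, self-contained sketch of that arithmetic (not the library code itself); the 448.0 and 6.0 constants stand in for `FP8_E4M3_DATA.max` and `FP4_E2M1_DATA.max`, and `weight` is a hypothetical stand-in tensor rather than the module's actual parameter.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn
FP4_E2M1_MAX = 6.0    # largest magnitude representable in 4-bit E2M1

# Hypothetical stand-in for module.weight.data
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)

# Per-tensor amax, computed in float32 as in the diff above
tensor_amax = weight.abs().max().to(torch.float32)

# Global scale: with this choice, the block containing the per-tensor max has a
# decode scale of (tensor_amax / FP4_E2M1_MAX), and multiplying it by
# global_scale yields exactly 448.0, so the FP8-stored block scales saturate
# the E4M3 range without overflowing it.
global_scale = FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax
```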