File tree Expand file tree Collapse file tree 1 file changed +9
-3
lines changed
src/compressed_tensors/quantization Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -260,7 +260,12 @@ def pytorch_dtype(self) -> torch.dtype:
260
260
# TODO: required for the compressor
261
261
# Add FP4_nvfp4 type when updating naive_compressor
262
262
if self .type == QuantizationType .FLOAT :
263
- return FP8_DTYPE
263
+ if self .num_bits == 8 :
264
+ return FP8_E4M3_DATA .dtype
265
+ else :
266
+ assert self .num_bits == 4
267
+ # TODO: FP4_NVFP4_DATA.dtype is None for now, so this returns None until FloatArgs is updated
268
+ return FP4_NVFP4_DATA .dtype
264
269
elif self .type == QuantizationType .INT :
265
270
if self .num_bits <= 8 :
266
271
return torch .int8
@@ -291,9 +296,10 @@ def round_to_quantized_type(
291
296
if args .type == QuantizationType .FLOAT :
292
297
if args .num_bits == 8 :
293
298
rounded = tensor .to (FP8_E4M3_DATA .dtype )
294
- elif args .num_bits == 4 :
299
+ else :
300
+ assert args .num_bits == 4
295
301
# TODO: cast to whichever dtype we choose to represent fp4 values after quantization/clamping
296
- rounded = tensor .to ()
302
+ rounded = tensor .to (FP4_NVFP4_DATA . dtype )
297
303
elif args .type == QuantizationType .INT :
298
304
rounded = torch .round (tensor )
299
305
else :
You can’t perform that action at this time.
0 commit comments