Skip to content

Commit 974953c

Browse files
committed
update
1 parent be02849 commit 974953c

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,12 @@ def pytorch_dtype(self) -> torch.dtype:
260260
# TODO: required for the compressor
261261
# Add FP4_nvfp4 type when updating naive_compressor
262262
if self.type == QuantizationType.FLOAT:
263-
return FP8_DTYPE
263+
if self.num_bits == 8:
264+
return FP8_E4M3_DATA.dtype
265+
else:
266+
assert self.num_bits == 4
267+
# TODO: will return None for now until updated in FloatArgs
268+
return FP4_NVFP4_DATA.dtype
264269
elif self.type == QuantizationType.INT:
265270
if self.num_bits <= 8:
266271
return torch.int8
@@ -291,9 +296,10 @@ def round_to_quantized_type(
291296
if args.type == QuantizationType.FLOAT:
292297
if args.num_bits == 8:
293298
rounded = tensor.to(FP8_E4M3_DATA.dtype)
294-
elif args.num_bits == 4:
299+
else:
300+
assert args.num_bits == 4
295301
# TODO: cast to whatever value we want fp4 to be post quantization/clamping
296-
rounded = tensor.to()
302+
rounded = tensor.to(FP4_NVFP4_DATA.dtype)
297303
elif args.type == QuantizationType.INT:
298304
rounded = torch.round(tensor)
299305
else:

0 commit comments

Comments (0)