Description
compressed-tensors/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
Lines 140 to 150 in 09b7ed4
```python
# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
    indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)
# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
return packed.reshape(m, n // 2)
```
Given a tensor of shape `(m, n)`, this function packs pairs of values into `uint8`s. If there is an odd number of values, the current logic handles this by concatenating a single zero onto the values. That padding works on its own, but the packed `uint8` tensor is then reshaped to `(m, n // 2)` before returning, which fails when `n` is odd.
e.g. `m=5, n=3`:
- 15 values
- 15 + 1 = 16 values after padding
- 8 `uint8`s after packing
- reshaping 8 values into `(5, 1)` fails, because that shape needs exactly 5 elements
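A minimal reproduction sketch, wrapping the excerpt in a hypothetical standalone function `pack_fp4_indices_flat`. It assumes, as the surrounding code not shown in the excerpt appears to do, that `indices` is flattened before padding and that `m`, `n` and `device` are taken from the `(m, n)` input. Besides the `(5, 3)` case above, it also tries `(4, 3)`, where `m * n = 12` is even and no padding happens, yet the 6 packed bytes still cannot be reshaped to `(4, 1)`.

```python
import torch


def pack_fp4_indices_flat(indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone copy of the excerpt's packing logic.

    `indices` holds 4-bit codes (0..15) with shape (m, n).
    """
    m, n = indices.shape
    device = indices.device
    # Assumption: the original code flattens before this point.
    indices = indices.flatten()
    # Handle odd length by padding if necessary (as in the excerpt)
    if indices.numel() % 2 != 0:
        indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
    # Pair consecutive elements; with odd n, pairs straddle row boundaries
    indices = indices.reshape(-1, 2)
    # Low nibble from the first element, high nibble from the second
    packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
    # Only valid when n is even
    return packed.reshape(m, n // 2)


for shape in [(5, 3), (4, 3), (5, 4)]:
    x = torch.randint(0, 16, shape, dtype=torch.long)
    try:
        out = pack_fp4_indices_flat(x)
        print(shape, "->", tuple(out.shape))
    except RuntimeError as err:
        print(shape, "-> reshape failed:", err)
```

For every odd `n`, the packed length is `ceil(m * n / 2)`, which is always larger than the `m * (n - 1) / 2` elements that `(m, n // 2)` requires, so the final reshape raises regardless of whether padding was applied.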
The options for fixes are:
- Return the flattened `uint8` tensor without reshaping.
- Pad with `m` zeros instead of one (one zero appended to each row) and reshape to `(m, (n + 1) // 2)`.
- Explicitly remove support for this case (odd `n`).
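A rough sketch of the second option, under the assumption that the input holds 4-bit codes with shape `(m, n)`. The function name `pack_fp4_indices_rowwise` is hypothetical and not part of compressed-tensors. Appending one zero per row (so `m` zeros in total) gives every row an even length, so pairing happens within rows and the result reshapes cleanly to `(m, (n + 1) // 2)`.

```python
import torch


def pack_fp4_indices_rowwise(indices: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit codes of shape (m, n) into uint8 of shape (m, (n + 1) // 2).

    Hypothetical sketch of option 2: when n is odd, append one zero column
    (m zeros in total) so that no packed byte straddles two rows.
    """
    m, n = indices.shape
    if n % 2 != 0:
        # One zero per row, same dtype and device as the input
        indices = torch.cat([indices, indices.new_zeros(m, 1)], dim=1)
    # Group consecutive elements within each row: (m, padded_n // 2, 2)
    pairs = indices.reshape(m, -1, 2)
    # Low nibble from the first element, high nibble from the second
    packed = (pairs[..., 0] | (pairs[..., 1] << 4)).to(torch.uint8)
    return packed  # shape (m, (n + 1) // 2)


codes = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(pack_fp4_indices_rowwise(codes))
```

The trade-off is one unused nibble per row when `n` is odd, and a matching unpack step would need the original `n` to drop it. The first option avoids that waste but loses the `(m, ...)` layout, which callers may rely on.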
I'm happy to submit a patch fixing this.