Bug: Incorrect handling of odd sized dims in nvfp4 packing/unpacking #401

@fynnsu

# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
    indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)
# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
return packed.reshape(m, n // 2)

Given a tensor of shape (m, n), this function packs pairs of 4-bit values into uint8s. When the total number of values is odd, the current logic pads by concatenating a single zero onto the flattened values. That would be fine on its own, but the packed uint8 tensor is then reshaped to (m, n // 2) before returning, which fails whenever n is odd because the packed tensor does not contain m * (n // 2) elements.

e.g.
m=5, n=3
-> 15 values
-> 15 + 1 = 16 values after padding
-> 8 uint8s after packing
-> reshape 8 values into (5, 1), which fails because 5 * 1 = 5 != 8
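
The walkthrough above can be reproduced with a short script. This is only a sketch: `pack_current` is a hypothetical stand-in that copies the quoted logic and assumes the input is flattened before padding; it is not the library's actual function name or signature.

```python
import torch


def pack_current(indices: torch.Tensor, m: int, n: int) -> torch.Tensor:
    """Hypothetical stand-in mirroring the packing logic quoted in this issue."""
    # Assumption: the values are flattened before the padding step
    indices = indices.reshape(-1).to(torch.long)
    # Handle odd length by padding if necessary
    if indices.numel() % 2 != 0:
        indices = torch.cat(
            [indices, torch.zeros(1, dtype=torch.long, device=indices.device)]
        )
    # Reshape to pair consecutive elements
    indices = indices.reshape(-1, 2)
    # Pack pairs of 4-bit values into 8-bit values
    packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
    return packed.reshape(m, n // 2)


m, n = 5, 3
values = torch.randint(0, 16, (m, n))  # 15 random 4-bit values

try:
    pack_current(values, m, n)
    reshape_failed = False
except RuntimeError as err:
    # 8 packed bytes cannot be viewed as (5, 1)
    reshape_failed = True
    print(f"reshape failed: {err}")
```

The same mismatch shows up when m is even and n is odd (for example m=4, n=3 gives 6 packed bytes against a target shape of (4, 1)), so the failure does not depend on the padding branch being taken.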

The options for fixes are:

  1. Return the flattened uint8 tensor without reshaping
  2. Pad with m zeros instead of one and reshape to (m, (n + 1) // 2)
  3. Explicitly remove support for this case
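
As an illustration of option 2, here is one possible sketch, assuming the m zeros are added as one trailing zero per row so that pairs never cross row boundaries. The names `pack_rows_padded` and `unpack_rows_padded` are hypothetical and not part of the existing code.

```python
import torch
import torch.nn.functional as F


def pack_rows_padded(indices: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit values of shape (m, n) into uint8 of shape (m, (n + 1) // 2)."""
    m, n = indices.shape
    indices = indices.to(torch.long)
    if n % 2 != 0:
        # Append one zero to every row (m zeros in total)
        indices = F.pad(indices, (0, 1), value=0)
    # Pair consecutive elements within each row
    pairs = indices.reshape(m, -1, 2)
    # Low nibble holds the first value, high nibble the second
    return (pairs[..., 0] | (pairs[..., 1] << 4)).to(torch.uint8)


def unpack_rows_padded(packed: torch.Tensor, n: int) -> torch.Tensor:
    """Inverse of pack_rows_padded; n is the original (possibly odd) row length."""
    packed = packed.to(torch.long)
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave low/high nibbles back into their original order
    unpacked = torch.stack([low, high], dim=-1).reshape(packed.shape[0], -1)
    # Drop the padding column if one was added
    return unpacked[:, :n]


x = torch.randint(0, 16, (5, 3))
p = pack_rows_padded(x)
restored = unpack_rows_padded(p, n=3)
```

One consequence of this approach is that the packed shape alone no longer identifies n (both n=3 and n=4 pack to two bytes per row), so the unpacking side would need the original row length from somewhere else.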

I'm happy to submit a patch fixing this.
