Skip to content

Commit 994a4ba

Browse files
authored
Store NVFP4 block scales in swizzled layout on tensor (#2438)
1 parent b1163dc commit 994a4ba

File tree

6 files changed

+543
-43
lines changed

6 files changed

+543
-43
lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,11 +558,13 @@ def test_nvfp4_matmul_with_amax(
558558
A,
559559
per_tensor_scale=a_scale,
560560
mm_config=mm_config,
561+
is_swizzled_scales=True,
561562
)
562563
B_nvfp4 = NVFP4Tensor.to_nvfp4(
563564
B,
564565
per_tensor_scale=b_scale,
565566
mm_config=mm_config,
567+
is_swizzled_scales=True,
566568
)
567569

568570
func = torch.compile(F.linear, fullgraph=True) if compile else F.linear

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,301 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
657657
assert x.t().dtype == x_reconstructed_t.dtype, (
658658
f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
659659
)
660+
661+
662+
@pytest.mark.parametrize(
663+
"shape",
664+
[
665+
(128, 4),
666+
(256, 8),
667+
(100, 3),
668+
(4, 4),
669+
(50, 10),
670+
(384, 12),
671+
],
672+
)
673+
@pytest.mark.parametrize(
674+
"use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
675+
)
676+
@pytest.mark.skipif(
677+
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
678+
)
679+
def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
680+
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
681+
682+
rows, cols = shape
683+
device = "cuda" if torch.cuda.is_available() else "cpu"
684+
685+
original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8)
686+
687+
blocked = to_blocked(original, use_triton_kernel=use_triton_kernel)
688+
reconstructed = from_blocked(blocked, rows, cols)
689+
690+
torch.testing.assert_close(
691+
original,
692+
reconstructed,
693+
atol=0.0,
694+
rtol=0.0,
695+
msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
696+
)
697+
698+
699+
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
700+
@pytest.mark.parametrize(
701+
"shape",
702+
[
703+
(32, 64),
704+
(16, 32),
705+
(64, 128),
706+
(384, 128),
707+
],
708+
)
709+
@pytest.mark.skipif(
710+
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
711+
)
712+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
713+
def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
714+
"""
715+
Test that NVFP4Tensor can be constructed with swizzled scales and
716+
that the _is_swizzled_scales flag is set correctly.
717+
"""
718+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
719+
720+
M, K = shape
721+
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
722+
723+
tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
724+
assert tensor._is_swizzled_scales == is_swizzled_scales
725+
reconstructed = tensor.to_dtype(torch.bfloat16)
726+
assert reconstructed.shape == data.shape
727+
728+
729+
@pytest.mark.parametrize(
730+
"slice_dim,slice_spec",
731+
[
732+
# Row slicing - must align with 128-row boundaries
733+
pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
734+
pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
735+
# Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
736+
pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
737+
pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
738+
pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
739+
# Test tensor parallelism patterns (half splits)
740+
pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
741+
pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
742+
# Test quarter splits
743+
pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
744+
pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
745+
],
746+
)
747+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
748+
@pytest.mark.skipif(
749+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
750+
)
751+
def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
752+
"""
753+
Test that slicing works correctly with swizzled scales and maintains
754+
the swizzled state in the output tensor.
755+
"""
756+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
757+
758+
# Use larger tensor sizes that align with swizzled requirements
759+
if slice_dim == 0:
760+
# For row slicing, need at least 256 rows to test 128-row boundaries
761+
M, K = 256, 4096
762+
else:
763+
# For column slicing, need multiples of 64 columns for alignment
764+
M, K = 128, 4096
765+
766+
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
767+
768+
tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
769+
assert tensor._is_swizzled_scales == True
770+
771+
if slice_dim == 0:
772+
sliced_tensor = tensor[slice_spec, :]
773+
else:
774+
sliced_tensor = tensor[:, slice_spec]
775+
776+
# Verify sliced tensor maintains swizzled state
777+
assert sliced_tensor._is_swizzled_scales == True
778+
779+
# Verify sliced tensor can be dequantized
780+
sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)
781+
782+
# Compare with direct slicing of original data
783+
original_reconstructed = tensor.to_dtype(torch.bfloat16)
784+
if slice_dim == 0:
785+
expected = original_reconstructed[slice_spec, :]
786+
else:
787+
expected = original_reconstructed[:, slice_spec]
788+
789+
torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)
790+
791+
792+
@pytest.mark.parametrize(
793+
"slice_dim,slice_spec,expected_error",
794+
[
795+
# Row slicing with misaligned boundaries
796+
pytest.param(
797+
0,
798+
slice(0, 100),
799+
"Row slicing of NVFP4Tensor with swizzled scales requires",
800+
id="misaligned_row_end",
801+
),
802+
pytest.param(
803+
0,
804+
slice(50, 150),
805+
"Row slicing of NVFP4Tensor with swizzled scales requires",
806+
id="misaligned_row_start",
807+
),
808+
# Column slicing with misaligned boundaries
809+
pytest.param(
810+
1,
811+
slice(0, 32),
812+
"Column slicing of NVFP4Tensor with swizzled scales requires",
813+
id="misaligned_col_32",
814+
),
815+
pytest.param(
816+
1,
817+
slice(16, 80),
818+
"Column slicing of NVFP4Tensor with swizzled scales requires",
819+
id="misaligned_col_start",
820+
),
821+
pytest.param(
822+
1,
823+
slice(0, 100),
824+
"Column slicing of NVFP4Tensor with swizzled scales requires",
825+
id="misaligned_col_end",
826+
),
827+
# Odd column boundaries (FP4 packing requirement)
828+
pytest.param(
829+
1,
830+
slice(1, 65),
831+
"start index to be a multiple of 64, got 1",
832+
id="odd_start",
833+
),
834+
pytest.param(
835+
1,
836+
slice(0, 65),
837+
" multiple of 64 or equal to tensor size 4096, got 65",
838+
id="odd_end",
839+
),
840+
],
841+
)
842+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
843+
@pytest.mark.skipif(
844+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
845+
)
846+
def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error):
847+
"""
848+
Test that slicing raises appropriate errors for misaligned boundaries.
849+
"""
850+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
851+
852+
M, K = 256, 4096
853+
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
854+
tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
855+
856+
with pytest.raises(RuntimeError, match=expected_error):
857+
if slice_dim == 0:
858+
_ = tensor[slice_spec, :]
859+
else:
860+
_ = tensor[:, slice_spec]
861+
862+
863+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
864+
@pytest.mark.skipif(
865+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
866+
)
867+
def test_nvfp4_swizzled_scales_view_semantics():
868+
"""
869+
Test that slicing maintains proper view semantics where possible.
870+
"""
871+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
872+
873+
M, K = 256, 4096
874+
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
875+
tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
876+
877+
# Test row slicing (should maintain views)
878+
sliced_tensor = tensor[0:128, :]
879+
880+
# Test that the sliced tensor shares storage with original for data
881+
# (Note: scales might not share storage due to swizzled layout complexity)
882+
assert sliced_tensor._data.data_ptr() == tensor._data.data_ptr()
883+
884+
# Test full-width column slicing (should maintain views)
885+
full_width_slice = tensor[:, 0:K]
886+
assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr()
887+
assert full_width_slice._data.data_ptr() == tensor._data.data_ptr()
888+
889+
890+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
891+
@pytest.mark.skipif(
892+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
893+
)
894+
def test_nvfp4_swizzled_scales_serialization():
895+
"""
896+
Test that tensor flatten/unflatten preserves the swizzled scales state.
897+
"""
898+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
899+
900+
M, K = 32, 64
901+
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
902+
903+
# Create tensor with swizzled scales
904+
original_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
905+
906+
# Test serialization
907+
tensor_list, ctx = original_tensor.__tensor_flatten__()
908+
909+
# Verify swizzled flag is preserved in context
910+
assert "_is_swizzled_scales" in ctx
911+
assert ctx["_is_swizzled_scales"] == True
912+
913+
# Test deserialization
914+
inner_tensors = {}
915+
for name in tensor_list:
916+
inner_tensors[name] = getattr(original_tensor, name)
917+
918+
reconstructed_tensor = NVFP4Tensor.__tensor_unflatten__(
919+
inner_tensors, ctx, None, None
920+
)
921+
922+
# Verify the swizzled state is preserved
923+
assert reconstructed_tensor._is_swizzled_scales == True
924+
925+
# Verify functionality is preserved
926+
original_dq = original_tensor.to_dtype(torch.bfloat16)
927+
reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16)
928+
929+
torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)
930+
931+
932+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
933+
@pytest.mark.skipif(
934+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
935+
)
936+
def test_nvfp4_swizzled_scales_get_scales_method():
937+
"""
938+
Test that the get_hp_scales() method correctly unswizzles scales when needed.
939+
"""
940+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
941+
942+
M, K = 32, 64
943+
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
944+
945+
# Create tensors with both storage methods
946+
regular_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=False)
947+
swizzled_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
948+
949+
# Get scales from both tensors and verify they are equal
950+
regular_scales = regular_tensor.get_hp_scales()
951+
swizzled_scales = swizzled_tensor.get_hp_scales()
952+
torch.testing.assert_close(regular_scales, swizzled_scales, atol=0.0, rtol=0.0)
953+
954+
# Verify scales have the expected shape
955+
expected_shape = (M, K // 16)
956+
assert regular_scales.shape == expected_shape
957+
assert swizzled_scales.shape == expected_shape

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ def _nvfp4_inference_linear_transform(
184184

185185
weight = module.weight
186186

187+
if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
188+
raise RuntimeError(
189+
f"NVFP4 only supports weight shape divisible by 16, got {weight.shape}"
190+
)
191+
187192
if module.bias is not None and weight.dtype == torch.float32:
188193
raise RuntimeError(
189194
"Bias is not supported when module weight is in fp32 (out_dtype=Float32). "
@@ -193,8 +198,8 @@ def _nvfp4_inference_linear_transform(
193198
quantized_weight = NVFP4Tensor.to_nvfp4(
194199
weight,
195200
mm_config=config.mm_config,
201+
is_swizzled_scales=True,
196202
)
197-
198203
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
199204
module.extra_repr = types.MethodType(_linear_extra_repr, module)
200205
return module

0 commit comments

Comments
 (0)