diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 60a889c36b..6fe91a379f 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -70,6 +70,14 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     else:
         assert_sqnr_gt_threshold(data_hp, data_mx_dq, 13.0)
 
+    # verify that if data.shape is (M, K) then scale.shape is (M, K // block_size)
+    prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
+    if elem_dtype is torch.float4_e2m1fn_x2:
+        assert data_mx._data.shape == (*prev_dims, K // 2)
+    else:
+        assert data_mx._data.shape == (*prev_dims, K)
+    assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index afce0313b7..793acaf536 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -331,6 +331,7 @@ def to_mx(
         raise AssertionError("unsupported")
 
     scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
 
     return scale_e8m0_biased, data_lp
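
To illustrate the shape contract this patch enforces: block-wise scale computation typically keeps a trailing singleton dim (one scale per block, computed with `keepdim=True`), and the added `squeeze(-1)` drops it so the scale tensor has shape `(..., K // block_size)` rather than `(..., K // block_size, 1)`. The sketch below is not the torchao implementation; `fake_to_mx_shapes` is a hypothetical stand-in that only mimics the shape arithmetic:

```python
import torch

def fake_to_mx_shapes(data_hp: torch.Tensor, block_size: int) -> torch.Tensor:
    # Hypothetical sketch, not torchao's to_mx. Reshape (..., K) into
    # (..., K // block_size, block_size) blocks and compute one scale
    # per block; amax with keepdim=True leaves a trailing singleton dim.
    *prev_dims, K = data_hp.shape
    blocks = data_hp.reshape(*prev_dims, K // block_size, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True)  # (..., K // bs, 1)
    # The squeeze(-1) added in this patch removes the singleton dim,
    # matching the new test assertion on _scale_e8m0.shape.
    return scale.squeeze(-1)                         # (..., K // bs)

x = torch.randn(4, 64)
assert fake_to_mx_shapes(x, block_size=32).shape == (4, 2)
```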