Commit 5318604
enforce that MXTensor scale dimensions are consistent with data
Summary:

Ensures that if the data dims are (M, K) then the scale dims are (M, K // block_size). Previously the scale dims were (M, K // block_size, 1). No logic change in surrounding code, but this is definitely more correct.

Test Plan:

```
pytest test/prototype/mx_formats
./test/prototype/mx_formats/test_mx_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 50e45ce
ghstack-comment-id: 3049035188
Pull Request resolved: #2506
1 parent dfcfa56 commit 5318604
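For reference, a minimal sketch of the shape invariant this commit enforces, written in plain PyTorch rather than the torchao `MXTensor` API (the names `M`, `K`, `block_size`, and the amax-based scale below are illustrative assumptions, not the actual `to_mx` logic):

```python
import torch

# Minimal sketch (plain PyTorch, not the torchao MXTensor API) of the shape
# invariant this commit enforces: for data of shape (M, K), a per-block scale
# over blocks of size `block_size` along the last dim should have shape
# (M, K // block_size), with no trailing singleton dimension.
M, K, block_size = 4, 64, 32
data_hp = torch.randn(M, K)

# amax over each block; the reduction drops the block dim, so no trailing
# dim of size 1 is left -- analogous to the squeeze(-1) added in to_mx below.
scale = data_hp.reshape(M, K // block_size, block_size).abs().amax(dim=-1)

assert scale.shape == (M, K // block_size)                    # enforced shape
assert scale.unsqueeze(-1).shape == (M, K // block_size, 1)   # previous shape
```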

File tree

2 files changed: +9 −0 lines changed
test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 8 additions & 0 deletions
@@ -70,6 +70,14 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     else:
         assert_sqnr_gt_threshold(data_hp, data_mx_dq, 13.0)
 
+    # verify that if data.shape is (M, K) then scale.shape is (M, K // block_size)
+    prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
+    if elem_dtype is torch.float4_e2m1fn_x2:
+        assert data_mx._data.shape == (*prev_dims, K // 2)
+    else:
+        assert data_mx._data.shape == (*prev_dims, K)
+    assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 1 addition & 0 deletions
@@ -331,6 +331,7 @@ def to_mx(
         raise AssertionError("unsupported")
 
     scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
     return scale_e8m0_biased, data_lp
 
 