Skip to content

Commit c1e84cc

Browse files
authored
enforce that MXTensor scale dimensions are consistent with data (#2506)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent e675ffd commit c1e84cc

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     else:
         assert_sqnr_gt_threshold(data_hp, data_mx_dq, 13.0)

+    # verify that if data.shape is (M, K) then scale.shape is (M, K // block_size)
+    prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
+    if elem_dtype is torch.float4_e2m1fn_x2:
+        assert data_mx._data.shape == (*prev_dims, K // 2)
+    else:
+        assert data_mx._data.shape == (*prev_dims, K)
+    assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def to_mx(
         raise AssertionError("unsupported")

     scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
     return scale_e8m0_biased, data_lp

0 commit comments

Comments
 (0)