Commit 1d4a2f7

Commit message:
Update

[ghstack-poisoned]

2 parents: be5a9bb + 6a3684b

File tree: 1 file changed (+46, -0 lines)

test/quantization/test_quant_primitives.py

Lines changed: 46 additions & 0 deletions
@@ -14,9 +14,11 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_tinygemm,
+    _choose_scale_float8,
     _fake_quantize_affine,
     _fake_quantize_affine_cachemask,
     _maybe_expand_scale_to_tensor_shape,
+    _quantize_affine_float8,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs):
         return output1
 
 
+# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100
+# with scale modified to be the inverse of the version in PT core
+def _tensor_to_scale_block(
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    block_outer: int,
+    block_inner: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+    scale = amax / torch.finfo(float8_dtype).max
+    x = x.div(scale).to(float8_dtype)
+    x = x.flatten(2, 3).flatten(0, 1)
+    scale = scale.flatten(2, 3).flatten(0, 1)
+    return x, scale
+
+
 # Legacy tinygemm ops
 def _get_groupwise_affine_qparams(
     w,
@@ -798,6 +817,33 @@ def test_maybe_expand_scale_to_tensor_shape(self):
         self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
         self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))
 
+    def test_float8_blockwise_scaling(self):
+        M, K = 512, 1024
+        hp_tensor = torch.randn(M, K, dtype=torch.float)
+        # make the scales from some of the blocks obviously different
+        hp_tensor[0:128, 0:128] *= 3.0
+        hp_tensor[0:128, 128:256] *= 7.0
+        hp_tensor[128:256, 0:128] *= 2.0
+        hp_tensor[128:256, 128:256] *= 100.0
+
+        block_size = (128, 128)
+
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        ref_data, ref_scale = _tensor_to_scale_block(
+            hp_tensor, torch.float8_e4m3fn, 128, 128
+        )
+
+        torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
+        torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
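
For context, here is a minimal standalone sketch (not part of the commit) of the blockwise float8 scaling scheme that the reference helper _tensor_to_scale_block computes and that the new test checks against _choose_scale_float8 / _quantize_affine_float8: split the tensor into 128x128 blocks, take each block's amax, derive scale = amax / finfo(float8_e4m3fn).max, and quantize by dividing. It uses only plain torch ops; the dequantization check at the end is an added illustration, not something the commit does, and this is not the torchao implementation itself.

# Illustration only: blockwise float8 (e4m3) scaling with plain torch ops.
import torch

M, K = 512, 1024
block_outer, block_inner = 128, 128
float8_dtype = torch.float8_e4m3fn

x = torch.randn(M, K, dtype=torch.float)

# View as (M // block_outer, block_outer, K // block_inner, block_inner) so
# each 128x128 block occupies dims 1 and 3.
blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))

# One scale per block: amax over the block dims, divided by the fp8 max value.
amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
scale = amax / torch.finfo(float8_dtype).max

# Quantize by dividing each block by its scale, then restore the (M, K) layout.
data = blocks.div(scale).to(float8_dtype).flatten(2, 3).flatten(0, 1)

# Sanity-check the round trip: broadcast each block's scale back over its
# 128x128 region and dequantize; the error should be small relative to each
# block's magnitude (fp8 e4m3 rounding).
scale_2d = scale.squeeze(3).squeeze(1)  # (M // block_outer, K // block_inner)
scale_full = scale_2d.repeat_interleave(block_outer, dim=0).repeat_interleave(
    block_inner, dim=1
)
dequant = data.float() * scale_full
print((dequant - x).abs().max())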
