@@ -14,9 +14,11 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_tinygemm,
+    _choose_scale_float8,
     _fake_quantize_affine,
     _fake_quantize_affine_cachemask,
     _maybe_expand_scale_to_tensor_shape,
+    _quantize_affine_float8,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs):
     return output1
 
 
+# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100
+# with scale modified to be the inverse of the version in PT core
+def _tensor_to_scale_block(
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    block_outer: int,
+    block_inner: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+    scale = amax / torch.finfo(float8_dtype).max
+    x = x.div(scale).to(float8_dtype)
+    x = x.flatten(2, 3).flatten(0, 1)
+    scale = scale.flatten(2, 3).flatten(0, 1)
+    return x, scale
+
+
 # Legacy tinygemm ops
 def _get_groupwise_affine_qparams(
     w,
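A minimal sketch of what the reference helper produces, assuming _tensor_to_scale_block from the hunk above is in scope and the local PyTorch build supports casting to float8_e4m3fn on CPU; this round trip is illustrative only and not part of the change. The quantized data keeps the input shape, the scale has one entry per (block_outer, block_inner) block, and broadcasting the scale back over each block approximately reconstructs the input:

import torch

x = torch.randn(4, 6)
data, scale = _tensor_to_scale_block(x, torch.float8_e4m3fn, block_outer=2, block_inner=3)
assert data.shape == (4, 6)   # quantized data keeps the original shape
assert scale.shape == (2, 2)  # one scale per 2x3 block
# broadcast each block scale back over its block and dequantize by hand
x_hat = data.float() * scale.repeat_interleave(2, dim=0).repeat_interleave(3, dim=1)
print((x - x_hat).abs().max())  # small, bounded by float8_e4m3fn rounding error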
@@ -798,6 +817,33 @@ def test_maybe_expand_scale_to_tensor_shape(self):
         self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
         self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))
 
+    def test_float8_blockwise_scaling(self):
+        M, K = 512, 1024
+        hp_tensor = torch.randn(M, K, dtype=torch.float)
+        # make the scales from some of the blocks obviously different
+        hp_tensor[0:128, 0:128] *= 3.0
+        hp_tensor[0:128, 128:256] *= 7.0
+        hp_tensor[128:256, 0:128] *= 2.0
+        hp_tensor[128:256, 128:256] *= 100.0
+
+        block_size = (128, 128)
+
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        ref_data, ref_scale = _tensor_to_scale_block(
+            hp_tensor, torch.float8_e4m3fn, 128, 128
+        )
+
+        torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
+        torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
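Outside the test suite, the new blockwise path can be sanity-checked with a sketch like the one below. Assumptions: the two new helpers are importable from torchao.quantization.quant_primitives, the returned scale has one entry per (128, 128) block (which the test's exact comparison against the reference implies), and manual broadcasting stands in for a dedicated dequantize helper.

import torch
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _quantize_affine_float8,
)

hp = torch.randn(512, 1024)
scale = _choose_scale_float8(
    hp,
    float8_dtype=torch.float8_e4m3fn,
    block_size=(128, 128),
    hp_value_lb=None,
    hp_value_ub=None,
)
data = _quantize_affine_float8(hp, scale, torch.float8_e4m3fn)
# expand the per-block scale over each 128x128 block and dequantize by hand
scale_full = scale.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
hp_hat = data.float() * scale_full
print(((hp - hp_hat).abs().max() / hp.abs().max()).item())  # small relative error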