         gemm_fp8_fp8_bf16_nt,
         get_col_major_tma_aligned_tensor,
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     )

     DEEPGEMM_ENABLED = True
@@ -871,6 +872,72 @@ def cuda(self) -> bool:
         return DEEPGEMM_ENABLED


+@register_quantize_op
+class DeepGemmMaskedStacked(DeepGemmStacked):
+    def preprocess(self, x, w):
+        # Quantize weights.
+        wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+
+        # Record per-group row counts and pad the inputs into one stacked tensor.
+        m_values = [i.shape[0] for i in x]
+        expected_m = max(m_values)
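+        # Round the shared M dimension up to a multiple of 128 (block alignment
+        # assumed by the masked grouped GEMM path).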
+        padded_m_max = ((max(m_values) + 127) // 128) * 128
+        masked_m = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+
+        num_groups = len(m_values)
+        k = x[0].shape[1]
+        x_padded = torch.zeros(
+            [num_groups, padded_m_max, k], device=x[0].device, dtype=x[0].dtype
+        )
+        for g in range(num_groups):
+            x_padded[g, : m_values[g], :] = x[g]
+
+        # Return processed tensors.
+        return x_padded, wq, w_scale, masked_m, expected_m, m_values
+
+    def quantize(self, x, wq, w_scale, masked_m, expected_m, m_values):
+        g, m_max, k = x.shape
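+        # Quantize activations row-wise (1 x 128 blocks) over the flattened groups.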
+        xq, x_scale = quantize_fp8_block(x.view(-1, k), block_m=1, block_k=128)
+        # Pretranspose scales to deepgemm format.
+        x_scale = get_col_major_tma_aligned_tensor(x_scale)
+        return (
+            xq.view(g, m_max, -1),
+            wq,
+            x_scale.view(g, m_max, -1),
+            w_scale,
+            masked_m,
+            expected_m,
+            m_values,
+        )
+
+    def compute(self, xq, wq, x_scale, w_scale, masked_m, expected_m, m_values):
+        # Preallocate output.
+        out = torch.empty(
+            [xq.shape[0], xq.shape[1], wq.shape[1]],
+            device=xq.device,
+            dtype=torch.bfloat16,
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            (xq, x_scale), (wq, w_scale), out, masked_m, expected_m
+        )
+        num_groups = xq.shape[0]
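+        # Only the first masked_m[g] rows of each group are valid; strip the padding.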
+        out_list = [out[g, : m_values[g], :] for g in range(num_groups)]
+        return out_list
+
+    def quantize_and_compute(self, x, wq, w_scale, masked_m, expected_m, m_values):
+        xq, wq, x_scale, w_scale, masked_m, expected_m, m_values = self.quantize(
+            x, wq, w_scale, masked_m, expected_m, m_values
+        )
+        return self.compute(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)
+
+    @property
+    def name(self) -> str:
+        return "deepgemm_masked_stacked"
+
+
 @register_quantize_op
 class DeepGemmBlockwise(QuantizeOpBase):
     """