Commit ffcc7da

levendlee authored and facebook-github-bot committed
Add DEEPGEMM Masked API. (#3949)
Summary:
Pull Request resolved: #3949
X-link: facebookresearch/FBGEMM#1033

The DeepGEMM GroupedGEMM contiguous API requires M to be 128-byte aligned and will produce buggy results if the requirement isn't met: [DeepGEMM API Requirement](deepseek-ai/DeepGEMM#15)

Reviewed By: Alkaid-Benetnash

Differential Revision: D72688238

fbshipit-source-id: 886daee2957700574f1dbebedc8887569c07c1cb
1 parent 8fb5e75 commit ffcc7da
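
For context, a minimal sketch (not part of the commit) of the padding and per-group row-count bookkeeping that the masked path performs; the group sizes below are made up for illustration and mirror the `preprocess()` logic in the diff further down.

```python
import torch

# Illustrative values only: ragged per-group row counts for a grouped GEMM.
m_values = [3, 70, 130]                              # valid rows per group
padded_m_max = ((max(m_values) + 127) // 128) * 128  # round up to a multiple of 128 -> 256
masked_m = torch.tensor(m_values, dtype=torch.int32) # per-group valid-row counts for the masked kernel
expected_m = max(m_values)                           # expected-M hint passed alongside masked_m
```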

File tree

1 file changed: +67 -0 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 67 additions & 0 deletions
@@ -45,6 +45,7 @@
     gemm_fp8_fp8_bf16_nt,
     get_col_major_tma_aligned_tensor,
     m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
+    m_grouped_gemm_fp8_fp8_bf16_nt_masked,
 )

 DEEPGEMM_ENABLED = True
@@ -871,6 +872,72 @@ def cuda(self) -> bool:
         return DEEPGEMM_ENABLED


+@register_quantize_op
+class DeepGemmMaskedStacked(DeepGemmStacked):
+    def preprocess(self, x, w):
+        # Quantize weights.
+        wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+
+        # Also view input as flattened.
+        m_values = [i.shape[0] for i in x]
+        expected_m = max(m_values)
+        padded_m_max = ((max(m_values) + 127) // 128) * 128
+        masked_m = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+
+        num_groups = len(m_values)
+        k = x[0].shape[1]
+        x_padded = torch.zeros(
+            [num_groups, padded_m_max, k], device=x[0].device, dtype=x[0].dtype
+        )
+        for g in range(num_groups):
+            x_padded[g, : m_values[g], :] = x[g]
+
+        # Return processed tensors.
+        return x_padded, wq, w_scale, masked_m, expected_m, m_values
+
+    def quantize(self, x, wq, w_scale, masked_m, expected_m, m_values):
+        g, m_max, k = x.shape
+        xq, x_scale = quantize_fp8_block(x.view(-1, k), block_m=1, block_k=128)
+        # Pretranspose scales to deepgemm format.
+        x_scale = get_col_major_tma_aligned_tensor(x_scale)
+        return (
+            xq.view(g, m_max, -1),
+            wq,
+            x_scale.view(g, m_max, -1),
+            w_scale,
+            masked_m,
+            expected_m,
+            m_values,
+        )
+
+    def compute(self, xq, wq, x_scale, w_scale, masked_m, expected_m, m_values):
+        # Preallocate output.
+        out = torch.empty(
+            [xq.shape[0], xq.shape[1], wq.shape[1]],
+            device=xq.device,
+            dtype=torch.bfloat16,
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            (xq, x_scale), (wq, w_scale), out, masked_m, expected_m
+        )
+        num_groups = xq.shape[0]
+        out_list = [out[g, : m_values[g], :] for g in range(num_groups)]
+        return out_list
+
+    def quantize_and_compute(self, x, wq, w_scale, masked_m, expected_m, m_values):
+        xq, wq, x_scale, w_scale, masked_m, expected_m, m_values = self.quantize(
+            x, wq, w_scale, masked_m, expected_m, m_values
+        )
+        return self.compute(xq, wq, x_scale, w_scale, masked_m, expected_m, m_values)
+
+    @property
+    def name(self) -> str:
+        return "deepgemm_masked_stacked"
+
+
 @register_quantize_op
 class DeepGemmBlockwise(QuantizeOpBase):
     """
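For orientation, a minimal usage sketch (not part of the commit) of the new benchmark op. It assumes the module is importable at the path below, that an FP8-capable CUDA device is available, and that `quantize_fp8_block` accepts bfloat16 inputs; the shapes are arbitrary illustrative choices.

```python
# Illustrative sketch only; not the commit's own test or benchmark driver.
import torch

from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import DeepGemmMaskedStacked

op = DeepGemmMaskedStacked()

# Ragged per-group inputs: each group has its own M, a shared K, and its own weight.
x = [torch.randn(m, 512, device="cuda", dtype=torch.bfloat16) for m in (3, 70, 130)]
w = [torch.randn(256, 512, device="cuda", dtype=torch.bfloat16) for _ in range(3)]

# preprocess() pads every group to a common M (a multiple of 128) and builds masked_m;
# quantize() block-quantizes activations to FP8; compute() runs the masked grouped GEMM
# and slices each group's output back to its true row count.
args = op.preprocess(x, w)
outputs = op.compute(*op.quantize(*args))
print([tuple(o.shape) for o in outputs])  # expected: [(3, 256), (70, 256), (130, 256)]
```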