
Commit 474f66c

Tianyu Liang authored and facebook-github-bot committed
Kernel optimization for stacked group-gemm E2E (#4468)
Summary:
Pull Request resolved: #4468
X-link: facebookresearch/FBGEMM#1526

Fused preprocessing kernels for NVFP4StackedGroupedGemm to reduce preprocessing overhead.

Reviewed By: jiawenliu64
Differential Revision: D78062155
fbshipit-source-id: e43f6a189dc66a34924d87d8b95225fe3b27b4a7
1 parent 3571258 commit 474f66c
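
At a high level, the diff renames FP16_MIN_NORMAL to BF16_MIN_NORMAL in the quantize kernels (the value 2 ** (-126) is also the bf16 minimum normal, since bf16 uses the fp32 exponent range), switches several output and scale allocations from torch.zeros to torch.empty (which skips the zero-fill pass), and adds two fused preprocessing helpers for the stacked NVFP4 path. A minimal usage sketch of those helpers, with illustrative sizes (requires a CUDA device; everything other than the two helper names is an assumption for the example):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    nvfp4_fused_padding_cumsum_and_segmented_arange,
    nvfp4_triton_cumsum,
)

# Illustrative inputs: 4 groups with these row counts, stacked into total_M rows.
m_sizes = torch.tensor([3, 130, 0, 257], dtype=torch.int32, device="cuda")
total_M = int(m_sizes.sum())  # 390

# Plain cumulative sizes with a leading zero: [0, 3, 133, 133, 390].
size_cumulative = nvfp4_triton_cumsum(m_sizes)

# Fused preprocessing: per-group start rows after padding each group up to a
# multiple of 128, plus, for every stacked row, the group it belongs to and
# its offset within that group.
starting_row_after_padding, belong_indices, row_within_tensor = (
    nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, total_M)
)
# starting_row_after_padding -> [0, 128, 384, 384, 768]
# belong_indices[:3]         -> [0, 0, 0]   (rows of group 0)
# row_within_tensor[3]       -> 0           (first row of group 1)

Fusing the padded cumsum and the per-row segmented arange into a single launch (plus one small single-block cumsum) is what cuts the preprocessing overhead mentioned in the summary.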

2 files changed: +170 -76 lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py

Lines changed: 167 additions & 22 deletions
@@ -74,7 +74,7 @@ def _kernel_quantize_mx4_unpack(
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
     MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
@@ -145,7 +145,7 @@ def _kernel_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
@@ -513,7 +513,7 @@ def _kernel_silu_quantize_mx4_unpack(
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
     MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
@@ -597,7 +597,7 @@ def _kernel_silu_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
@@ -928,7 +928,7 @@ def _kernel_rms_quantize_mx4_unpack(
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
     MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
@@ -1021,7 +1021,7 @@ def _kernel_rms_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
@@ -1346,7 +1346,7 @@ def _kernel_nvfp4_quantize(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
 
     # Get the current thread number.
     pid = tl.program_id(0)
@@ -1411,7 +1411,7 @@ def _kernel_nvfp4_quantize(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
 
     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -1546,7 +1546,7 @@ def triton_scale_nvfp4_quant(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
 
     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)
 
     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -1559,7 +1559,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)
 
     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
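
To make the allocation shapes above concrete, here is a small worked example (a sketch that assumes the NVFP4 group size block_size = 16; round_up is restated here with the obvious definition):

def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

M, K, block_size = 300, 4096, 16

# Each uint32 element packs 8 fp4 values (two per byte), so a row of K values
# needs K // 8 uint32 words.
out_shape = (M, K // 8)               # (300, 512), dtype=torch.uint32

# One scale entry (stored as int8) per block_size-element group, with rows
# padded to 128 and scale columns padded to 4 for the 128x4 swizzled tile.
rounded_M = round_up(M, 128)          # 384
scale_K = K // block_size             # 256
rounded_K = round_up(scale_K, 4)      # 256
scale_shape = (rounded_M, rounded_K)  # (384, 256), dtype=torch.int8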
@@ -1679,7 +1679,7 @@ def _kernel_nvfp4_quantize_silu(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
 
     # Get the current thread number.
     pid = tl.program_id(0)
@@ -1758,7 +1758,7 @@ def _kernel_nvfp4_quantize_silu(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
 
     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -1896,7 +1896,7 @@ def triton_scale_nvfp4_quant_silu(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
 
     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)
 
     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -1909,7 +1909,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)
 
     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
@@ -2029,7 +2029,7 @@ def _kernel_nvfp4_quantize_rms(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
 
     # Get the current thread number.
     pid = tl.program_id(0)
@@ -2117,7 +2117,7 @@ def _kernel_nvfp4_quantize_rms(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
 
     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -2258,7 +2258,7 @@ def triton_scale_nvfp4_quant_rms(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
 
     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)
 
     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -2271,7 +2271,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)
 
     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
@@ -2395,7 +2395,7 @@ def _kernel_nvfp4_quantize_stacked(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
 
     # Get the current thread number.
     pid = tl.program_id(0)
@@ -2479,7 +2479,7 @@ def _kernel_nvfp4_quantize_stacked(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
 
     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -2642,7 +2642,7 @@ def triton_nvfp4_quant_stacked(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
 
     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)
 
     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -2655,7 +2655,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M + (starting_row_after_padding.numel() - 1) * 128, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)
 
     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
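
The stacked variant above reserves extra scale rows: each group may gain up to 128 padding rows, and starting_row_after_padding holds one entry per group plus a leading zero. A quick worked example with assumed sizes:

def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

M, num_groups = 390, 4                   # assumed: 390 stacked rows in 4 groups
num_offsets = num_groups + 1             # starting_row_after_padding.numel()
rounded_M = round_up(M + (num_offsets - 1) * 128, 128)  # round_up(902, 128) -> 1024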
@@ -2730,3 +2730,148 @@ def round_up(x: int, y: int) -> int:
 
     scale = scale.flatten()
     return out.view(list(orig_shape[:-1]) + [-1]).view(torch.uint8), scale
+
+
+@triton.jit
+def fused_padding_cumsum_and_segmented_arange_kernel(
+    m_sizes_ptr,  # [num_segments] input sizes
+    starting_row_after_padding_ptr,  # [num_segments + 1] output: padded cumsum
+    size_cumulative_ptr,  # [num_segments + 1] input: regular cumsum
+    belong_indices_ptr,  # [N] output: segment index
+    row_within_tensor_ptr,  # [N] output: position within segment
+    num_segments: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    # Part 1: Compute padded cumsum (only first block does this)
+    if pid == 0:
+        offs = tl.arange(0, BLOCK_SIZE)
+        mask = offs < num_segments
+
+        # Load m_sizes
+        m_sizes = tl.load(m_sizes_ptr + offs, mask=mask, other=0)
+
+        # Compute padded sizes
+        padded_sizes = ((m_sizes + 128 - 1) // 128) * 128
+
+        # Compute inclusive cumsum
+        cumsum = tl.cumsum(padded_sizes, axis=0)
+
+        # Store at indices 1 through num_segments
+        tl.store(starting_row_after_padding_ptr + offs + 1, cumsum, mask=mask)
+
+        # Set first element to zero
+        first_elem_mask = offs == 0
+        tl.store(
+            starting_row_after_padding_ptr + offs,
+            tl.zeros([BLOCK_SIZE], dtype=cumsum.dtype),
+            mask=first_elem_mask,
+        )
+
+    # Part 2: Segmented arange (all blocks do this)
+    offs = tl.arange(0, BLOCK_SIZE)
+    row_idx = pid * BLOCK_SIZE + offs
+    mask = row_idx < N
+
+    # Binary search using the regular cumsum
+    left = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
+    right = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + num_segments
+
+    for _ in range(32):  # 32 iterations for binary search
+        mid = (left + right) // 2
+        mid_val = tl.load(size_cumulative_ptr + mid, mask=mask, other=0)
+        cond = mid_val <= row_idx
+        left = tl.where(cond, mid + 1, left)
+        right = tl.where(cond, right, mid)
+
+    belong_idx = left - 1
+    tl.store(belong_indices_ptr + row_idx, belong_idx, mask=mask)
+
+    # Compute row_within_tensor
+    segment_start = tl.load(size_cumulative_ptr + belong_idx, mask=mask, other=0)
+    row_within = row_idx - segment_start
+    tl.store(row_within_tensor_ptr + row_idx, row_within, mask=mask)
+
+
+@triton.jit
+def cumsum_kernel(
+    m_sizes_ptr,
+    size_cumulative_ptr,
+    N: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < N
+
+    # Load m_sizes
+    m_sizes = tl.load(m_sizes_ptr + offs, mask=mask, other=0)
+
+    # Compute inclusive cumsum
+    cumsum = tl.cumsum(m_sizes, axis=0)
+
+    # Store cumsum at indices 1 through N
+    tl.store(size_cumulative_ptr + offs + 1, cumsum, mask=mask)
+
+    # Set first element to zero
+    first_elem_mask = offs == 0
+    tl.store(
+        size_cumulative_ptr + offs,
+        tl.zeros([BLOCK_SIZE], dtype=cumsum.dtype),
+        mask=first_elem_mask,
+    )
+
+
+def nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, N):
+    device = m_sizes.device
+    dtype = m_sizes.dtype
+    num_segments = m_sizes.shape[0]
+
+    # First compute regular cumsum (needed for segmented arange)
+    size_cumulative = nvfp4_triton_cumsum(m_sizes)
+
+    # Allocate outputs
+    starting_row_after_padding = torch.empty(
+        num_segments + 1, dtype=dtype, device=device
+    )
+    belong_indices = torch.empty(N, dtype=dtype, device=device)
+    row_within_tensor = torch.empty(N, dtype=dtype, device=device)
+
+    BLOCK_SIZE = 256
+    # Need enough blocks to cover N, but at least 1 for the padding cumsum
+    grid = (max(1, triton.cdiv(N, BLOCK_SIZE)),)
+
+    fused_padding_cumsum_and_segmented_arange_kernel[grid](
+        m_sizes,
+        starting_row_after_padding,
+        size_cumulative,
+        belong_indices,
+        row_within_tensor,
+        num_segments=num_segments,
+        N=N,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=4,
+    )
+
+    return starting_row_after_padding, belong_indices, row_within_tensor
+
+
+def nvfp4_triton_cumsum(m_sizes):
+    device = m_sizes.device
+    dtype = m_sizes.dtype
+    N = m_sizes.shape[0]
+
+    size_cumulative = torch.empty(N + 1, dtype=dtype, device=device)
+
+    BLOCK_SIZE = triton.next_power_of_2(N)
+    grid = (1,)  # single-block kernel
+
+    cumsum_kernel[grid](
+        m_sizes,
+        size_cumulative,
+        N=N,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=4,
+    )
+    return size_cumulative
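
For cross-checking the semantics of the new kernels, this is roughly the eager-mode computation they implement (a reference sketch of our own, not part of the commit; the function name and output dtypes are illustrative):

import torch

def reference_fused_preprocessing(m_sizes: torch.Tensor, N: int):
    # Inclusive cumsum with a leading zero, as produced by cumsum_kernel.
    cs = torch.cumsum(m_sizes, dim=0)
    zero = torch.zeros(1, dtype=cs.dtype, device=cs.device)
    size_cumulative = torch.cat([zero, cs])

    # Cumsum of sizes rounded up to multiples of 128 (padded group start rows).
    padded = ((m_sizes + 127) // 128) * 128
    pcs = torch.cumsum(padded, dim=0)
    starting_row_after_padding = torch.cat([zero.to(pcs.dtype), pcs])

    # For each stacked row, the group it falls in (last cumsum entry <= row)
    # and its offset within that group; mirrors the kernel's binary search.
    row_idx = torch.arange(N, device=m_sizes.device)
    belong_indices = torch.searchsorted(size_cumulative.long(), row_idx, right=True) - 1
    row_within_tensor = row_idx - size_cumulative.long()[belong_indices]
    return starting_row_after_padding, belong_indices, row_within_tensor

Compared with running these steps as separate eager ops or separate kernels, the fused Triton launch avoids extra kernel launches and intermediate round trips through memory on the preprocessing path.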
