@@ -74,7 +74,7 @@ def _kernel_quantize_mx4_unpack(
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
     MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
@@ -145,7 +145,7 @@ def _kernel_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
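
For illustration only (not part of the patch): the guard in this hunk exists because the shared exponent is later derived via log2(group_max), which would be -inf for an all-zero group. A minimal eager-mode sketch of the same idea, assuming plain PyTorch and the 2**-126 constant used by the renamed BF16_MIN_NORMAL:

    import torch

    BF16_MIN_NORMAL = 2.0 ** (-126)  # smallest normal bf16/fp32 magnitude, as in the diff

    a_groups = torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.5, -3.0, 0.25, 2.0]])
    group_max = a_groups.abs().amax(dim=1)
    print(torch.log2(group_max))  # tensor([-inf, 1.5850]) -- the all-zero group blows up

    # With the guard, the zero group falls back to the smallest normal value instead.
    group_max = torch.where(
        group_max == 0, torch.full_like(group_max, BF16_MIN_NORMAL), group_max
    )
    print(torch.floor(torch.log2(group_max)))  # tensor([-126., 1.]) -- finite shared exponents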
@@ -513,7 +513,7 @@ def _kernel_silu_quantize_mx4_unpack(
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
     MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
    IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
@@ -597,7 +597,7 @@ def _kernel_silu_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
@@ -928,7 +928,7 @@ def _kernel_rms_quantize_mx4_unpack(
     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
     MAX_FP16_MANTISSA_BITS: tl.constexpr = 8  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 7  # type: ignore[Incompatible variable type]
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
@@ -1021,7 +1021,7 @@ def _kernel_rms_quantize_mx4_unpack(
     # Compute the shared exponent of each group.
     group_max = tl.max(tl.abs(a_groups), axis=1)
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
     # Load relevant random values if doing stochastic rounding
     # or stochastic casting.
     group_rand_bits = None
@@ -1346,7 +1346,7 @@ def _kernel_nvfp4_quantize(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]

     # Get the current thread number.
     pid = tl.program_id(0)
@@ -1411,7 +1411,7 @@ def _kernel_nvfp4_quantize(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -1546,7 +1546,7 @@ def triton_scale_nvfp4_quant(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)

     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -1559,7 +1559,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)

     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
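
For illustration only (not part of the patch): the scale tensor allocated in this hunk is sized for the swizzled layout described in the earlier comment (minimum 128x4 tiles per the Tensor Core requirement), so both dimensions are rounded up before allocation. A small arithmetic sketch under assumed example shapes; block_size = 16 is the usual nvfp4 group size, and round_up is a standard round-to-multiple helper matching the signature in the hunk header:

    def round_up(x: int, y: int) -> int:
        # Round x up to the nearest multiple of y.
        return ((x + y - 1) // y) * y

    M, K, block_size = 300, 4096, 16  # example shapes (assumed); 16-element nvfp4 groups
    scale_K = K // block_size         # one scale per group -> 256
    rounded_M = round_up(M, 128)      # 384: padded to the 128-row tile height
    rounded_K = round_up(scale_K, 4)  # 256: already a multiple of 4
    print(rounded_M, rounded_K)       # 384 256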
@@ -1679,7 +1679,7 @@ def _kernel_nvfp4_quantize_silu(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]

     # Get the current thread number.
     pid = tl.program_id(0)
@@ -1758,7 +1758,7 @@ def _kernel_nvfp4_quantize_silu(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -1896,7 +1896,7 @@ def triton_scale_nvfp4_quant_silu(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)

     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -1909,7 +1909,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)

     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
@@ -2029,7 +2029,7 @@ def _kernel_nvfp4_quantize_rms(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]

     # Get the current thread number.
     pid = tl.program_id(0)
@@ -2117,7 +2117,7 @@ def _kernel_nvfp4_quantize_rms(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -2258,7 +2258,7 @@ def triton_scale_nvfp4_quant_rms(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)

     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -2271,7 +2271,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)

     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
@@ -2395,7 +2395,7 @@ def _kernel_nvfp4_quantize_stacked(
         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
     """
     # Define Constant Expressions.
-    FP16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+    BF16_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]

     # Get the current thread number.
     pid = tl.program_id(0)
@@ -2479,7 +2479,7 @@ def _kernel_nvfp4_quantize_stacked(
     # Next we scale A in preparation for quantization.
     scale_ = group_max / 6.0 * input_global_scale
     # Prevent infinite values in log.
-    group_max = tl.where(group_max == 0, FP16_MIN_NORMAL, group_max)
+    group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

     # Apply scale_ to input. We do this by broadcasting scale.
     scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
@@ -2642,7 +2642,7 @@ def triton_nvfp4_quant_stacked(
     ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

     # Two fp4 values will be packed into an uint8.
-    out = torch.zeros((M, K // 8), device=device, dtype=torch.uint32)
+    out = torch.empty((M, K // 8), device=device, dtype=torch.uint32)

     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -2655,7 +2655,7 @@ def round_up(x: int, y: int) -> int:
     rounded_M = round_up(M + (starting_row_after_padding.numel() - 1) * 128, 128)
     scale_K = K // block_size
     rounded_K = round_up(scale_K, 4)
-    scale = torch.zeros((rounded_M, rounded_K), device=device, dtype=torch.int8)
+    scale = torch.empty((rounded_M, rounded_K), device=device, dtype=torch.int8)

     # In this kernel, we want each row to be divisible by group_size.
     # If the rows are not, then we will pad them. Find the number of
@@ -2730,3 +2730,148 @@ def round_up(x: int, y: int) -> int:

     scale = scale.flatten()
     return out.view(list(orig_shape[:-1]) + [-1]).view(torch.uint8), scale
+
+
+@triton.jit
+def fused_padding_cumsum_and_segmented_arange_kernel(
+    m_sizes_ptr,  # [num_segments] input sizes
+    starting_row_after_padding_ptr,  # [num_segments + 1] output: padded cumsum
+    size_cumulative_ptr,  # [num_segments + 1] input: regular cumsum
+    belong_indices_ptr,  # [N] output: segment index
+    row_within_tensor_ptr,  # [N] output: position within segment
+    num_segments: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    # Part 1: Compute padded cumsum (only first block does this)
+    if pid == 0:
+        offs = tl.arange(0, BLOCK_SIZE)
+        mask = offs < num_segments
+
+        # Load m_sizes
+        m_sizes = tl.load(m_sizes_ptr + offs, mask=mask, other=0)
+
+        # Compute padded sizes
+        padded_sizes = ((m_sizes + 128 - 1) // 128) * 128
+
+        # Compute inclusive cumsum
+        cumsum = tl.cumsum(padded_sizes, axis=0)
+
+        # Store at indices 1 through num_segments
+        tl.store(starting_row_after_padding_ptr + offs + 1, cumsum, mask=mask)
+
+        # Set first element to zero
+        first_elem_mask = offs == 0
+        tl.store(
+            starting_row_after_padding_ptr + offs,
+            tl.zeros([BLOCK_SIZE], dtype=cumsum.dtype),
+            mask=first_elem_mask,
+        )
+
+    # Part 2: Segmented arange (all blocks do this)
+    offs = tl.arange(0, BLOCK_SIZE)
+    row_idx = pid * BLOCK_SIZE + offs
+    mask = row_idx < N
+
+    # Binary search using the regular cumsum
+    left = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
+    right = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + num_segments
+
+    for _ in range(32):  # 32 iterations for binary search
+        mid = (left + right) // 2
+        mid_val = tl.load(size_cumulative_ptr + mid, mask=mask, other=0)
+        cond = mid_val <= row_idx
+        left = tl.where(cond, mid + 1, left)
+        right = tl.where(cond, right, mid)
+
+    belong_idx = left - 1
+    tl.store(belong_indices_ptr + row_idx, belong_idx, mask=mask)
+
+    # Compute row_within_tensor
+    segment_start = tl.load(size_cumulative_ptr + belong_idx, mask=mask, other=0)
+    row_within = row_idx - segment_start
+    tl.store(row_within_tensor_ptr + row_idx, row_within, mask=mask)
+
+
+@triton.jit
+def cumsum_kernel(
+    m_sizes_ptr,
+    size_cumulative_ptr,
+    N: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < N
+
+    # Load m_sizes
+    m_sizes = tl.load(m_sizes_ptr + offs, mask=mask, other=0)
+
+    # Compute inclusive cumsum
+    cumsum = tl.cumsum(m_sizes, axis=0)
+
+    # Store cumsum at indices 1 through N
+    tl.store(size_cumulative_ptr + offs + 1, cumsum, mask=mask)
+
+    # Set first element to zero
+    first_elem_mask = offs == 0
+    tl.store(
+        size_cumulative_ptr + offs,
+        tl.zeros([BLOCK_SIZE], dtype=cumsum.dtype),
+        mask=first_elem_mask,
+    )
+
+
+def nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, N):
+    device = m_sizes.device
+    dtype = m_sizes.dtype
+    num_segments = m_sizes.shape[0]
+
+    # First compute regular cumsum (needed for segmented arange)
+    size_cumulative = nvfp4_triton_cumsum(m_sizes)
+
+    # Allocate outputs
+    starting_row_after_padding = torch.empty(
+        num_segments + 1, dtype=dtype, device=device
+    )
+    belong_indices = torch.empty(N, dtype=dtype, device=device)
+    row_within_tensor = torch.empty(N, dtype=dtype, device=device)
+
+    BLOCK_SIZE = 256
+    # Need enough blocks to cover N, but at least 1 for the padding cumsum
+    grid = (max(1, triton.cdiv(N, BLOCK_SIZE)),)
+
+    fused_padding_cumsum_and_segmented_arange_kernel[grid](
+        m_sizes,
+        starting_row_after_padding,
+        size_cumulative,
+        belong_indices,
+        row_within_tensor,
+        num_segments=num_segments,
+        N=N,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=4,
+    )
+
+    return starting_row_after_padding, belong_indices, row_within_tensor
+
+
+def nvfp4_triton_cumsum(m_sizes):
+    device = m_sizes.device
+    dtype = m_sizes.dtype
+    N = m_sizes.shape[0]
+
+    size_cumulative = torch.empty(N + 1, dtype=dtype, device=device)
+
+    BLOCK_SIZE = triton.next_power_of_2(N)
+    grid = (1,)  # single-block kernel
+
+    cumsum_kernel[grid](
+        m_sizes,
+        size_cumulative,
+        N=N,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=4,
+    )
+    return size_cumulative
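
For illustration only (not part of the patch): a hypothetical usage sketch of the two new helpers, assuming a CUDA device, small illustrative segment sizes, and an eager PyTorch reference for the padded offsets. Values in the comments follow from the 128-row padding the kernel applies per segment.

    import torch

    # Per-segment row counts for a stacked input (values chosen for illustration).
    m_sizes = torch.tensor([3, 130, 64], dtype=torch.int32, device="cuda")
    N = int(m_sizes.sum().item())  # total number of rows across all segments (197)

    # Cumulative sizes with a leading zero: tensor([0, 3, 133, 197]).
    size_cumulative = nvfp4_triton_cumsum(m_sizes)

    # Row offsets after padding each segment to a multiple of 128 rows
    # (tensor([0, 128, 384, 512])), plus, for every row, the segment it belongs
    # to and its position within that segment.
    starting_row_after_padding, belong_indices, row_within_tensor = (
        nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, N)
    )

    # Eager reference for the padded offsets.
    padded = ((m_sizes + 127) // 128) * 128
    reference = torch.cat(
        [torch.zeros(1, dtype=m_sizes.dtype, device=m_sizes.device), torch.cumsum(padded, 0).to(m_sizes.dtype)]
    )
    assert torch.equal(starting_row_after_padding.long(), reference.long())

    # Eager reference for the segment membership of each row.
    expected_belong = torch.repeat_interleave(
        torch.arange(len(m_sizes), device=m_sizes.device), m_sizes.long()
    )
    assert torch.equal(belong_indices.long(), expected_belong)
    # row_within_tensor counts 0..m_sizes[i]-1 within each segment i.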