@@ -2490,13 +2490,10 @@ class NVFP4StackedGroupedGemm(QuantizeOpBase):
2490
2490
"""
2491
2491
2492
2492
def preprocess (self , x , w ):
2493
-
2494
2493
m_values = [i .shape [0 ] for i in x ]
2495
2494
m_sizes = torch .tensor (m_values ).to (dtype = torch .int64 , device = x [0 ].device )
2496
2495
x = torch .concat (x , dim = 0 ).contiguous ()
2497
- return x , w , m_sizes
2498
2496
2499
- def quantize (self , x , w , m_sizes ):
2500
2497
def get_global_scale (x , w ):
2501
2498
G = len (w )
2502
2499
x_global_scale = []
@@ -2520,21 +2517,25 @@ def get_global_scale(x, w):
2520
2517
2521
2518
return x_global_scale , w_global_scale , global_scale
2522
2519
2523
- starting_row_after_padding , belong_indices , row_within_tensor = (
2524
- nvfp4_fused_padding_cumsum_and_segmented_arange (m_sizes , x .shape [0 ])
2525
- )
2526
-
2527
2520
# Compute global scale for each group
2528
2521
G = m_sizes .numel ()
2529
-
2530
2522
x_global_scale , w_global_scale , global_scale = get_global_scale (x , w )
2531
2523
2524
+ global_scale = torch .stack (global_scale , dim = 0 ).contiguous ()
2525
+
2532
2526
wq , w_scale = zip (
2533
2527
* [triton_scale_nvfp4_quant (w [i ], w_global_scale [i ]) for i in range (G )]
2534
2528
)
2535
2529
wq = torch .stack (wq , dim = 0 ).contiguous ()
2536
2530
w_scale = torch .stack (w_scale , dim = 0 ).contiguous ()
2537
2531
2532
+ return x , wq , w_scale , x_global_scale , global_scale , m_sizes
2533
+
2534
+ def quantize (self , x , wq , w_scale , x_global_scale , global_scale , m_sizes ):
2535
+ starting_row_after_padding , belong_indices , row_within_tensor = (
2536
+ nvfp4_fused_padding_cumsum_and_segmented_arange (m_sizes , x .shape [0 ])
2537
+ )
2538
+
2538
2539
xq , x_scale = triton_nvfp4_quant_stacked (
2539
2540
x ,
2540
2541
x_global_scale [0 ],
@@ -2543,7 +2544,7 @@ def get_global_scale(x, w):
2543
2544
row_within_tensor ,
2544
2545
)
2545
2546
x_scale = x_scale .reshape (- 1 , x .shape [1 ] // 16 )
2546
- global_scale = torch . stack ( global_scale , dim = 0 ). contiguous ()
2547
+
2547
2548
return (
2548
2549
xq ,
2549
2550
wq ,
@@ -2575,9 +2576,11 @@ def compute(
2575
2576
use_mx = False ,
2576
2577
)
2577
2578
2578
- def quantize_and_compute (self , x , w , m_sizes ):
2579
+ def quantize_and_compute (
2580
+ self , x , wq , w_scale , x_global_scale , global_scale , m_sizes
2581
+ ):
2579
2582
xq , wq , x_scale , w_scale , m_sizes , global_scale , starting_row_after_padding = (
2580
- self .quantize (x , w , m_sizes )
2583
+ self .quantize (x , wq , w_scale , x_global_scale , global_scale , m_sizes )
2581
2584
)
2582
2585
return self .compute (
2583
2586
xq , wq , x_scale , w_scale , m_sizes , global_scale , starting_row_after_padding
0 commit comments