
Commit b3052b7

jiawenliu64 authored and facebook-github-bot committed
Correct preprocess in NVFP4 stacked grouped gemm (#4475)
Summary:
Pull Request resolved: #4475
X-link: facebookresearch/FBGEMM#1532

as title

Reviewed By: Tianyu-Liang

Differential Revision: D78168569

fbshipit-source-id: 6dbf3d2c0d09d8f39c1df389e95f0475346483e1
1 parent 4551ff7 commit b3052b7

File tree: 1 file changed (+14 −11 lines)


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 14 additions & 11 deletions
@@ -2490,13 +2490,10 @@ class NVFP4StackedGroupedGemm(QuantizeOpBase):
     """
 
     def preprocess(self, x, w):
-
         m_values = [i.shape[0] for i in x]
         m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
         x = torch.concat(x, dim=0).contiguous()
-        return x, w, m_sizes
 
-    def quantize(self, x, w, m_sizes):
         def get_global_scale(x, w):
             G = len(w)
             x_global_scale = []
@@ -2520,21 +2517,25 @@ def get_global_scale(x, w):
 
             return x_global_scale, w_global_scale, global_scale
 
-        starting_row_after_padding, belong_indices, row_within_tensor = (
-            nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, x.shape[0])
-        )
-
         # Compute global scale for each group
         G = m_sizes.numel()
-
         x_global_scale, w_global_scale, global_scale = get_global_scale(x, w)
 
+        global_scale = torch.stack(global_scale, dim=0).contiguous()
+
         wq, w_scale = zip(
             *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
         )
         wq = torch.stack(wq, dim=0).contiguous()
         w_scale = torch.stack(w_scale, dim=0).contiguous()
 
+        return x, wq, w_scale, x_global_scale, global_scale, m_sizes
+
+    def quantize(self, x, wq, w_scale, x_global_scale, global_scale, m_sizes):
+        starting_row_after_padding, belong_indices, row_within_tensor = (
+            nvfp4_fused_padding_cumsum_and_segmented_arange(m_sizes, x.shape[0])
+        )
+
         xq, x_scale = triton_nvfp4_quant_stacked(
             x,
             x_global_scale[0],
@@ -2543,7 +2544,7 @@ def get_global_scale(x, w):
             row_within_tensor,
         )
         x_scale = x_scale.reshape(-1, x.shape[1] // 16)
-        global_scale = torch.stack(global_scale, dim=0).contiguous()
+
         return (
             xq,
             wq,
@@ -2575,9 +2576,11 @@ def compute(
             use_mx=False,
         )
 
-    def quantize_and_compute(self, x, w, m_sizes):
+    def quantize_and_compute(
+        self, x, wq, w_scale, x_global_scale, global_scale, m_sizes
+    ):
         xq, wq, x_scale, w_scale, m_sizes, global_scale, starting_row_after_padding = (
-            self.quantize(x, w, m_sizes)
+            self.quantize(x, wq, w_scale, x_global_scale, global_scale, m_sizes)
         )
         return self.compute(
             xq, wq, x_scale, w_scale, m_sizes, global_scale, starting_row_after_padding
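
For context on how the reshuffled methods fit together: after this change, preprocess does the one-time work (stacking the activations, computing the global scales, and quantizing the weights), while quantize only performs the per-call activation quantization before compute runs the grouped GEMM. The sketch below shows that flow under the new signatures; the import path, shapes, dtype, device, and the direct construction of NVFP4StackedGroupedGemm are illustrative assumptions, not taken from the diff.

import torch

# Assumed import path, derived from the file location shown above.
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import NVFP4StackedGroupedGemm

# Hypothetical per-group inputs (group count, shapes, dtype, and device are assumptions).
G, K, N = 4, 512, 256
x = [torch.randn(64, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]
w = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

op = NVFP4StackedGroupedGemm()

# One-time setup: stacked activations, quantized weights, and global scales,
# matching the tuple preprocess now returns in this commit.
x_stacked, wq, w_scale, x_global_scale, global_scale, m_sizes = op.preprocess(x, w)

# Per-iteration path: activation quantization followed by the grouped GEMM.
out = op.quantize_and_compute(
    x_stacked, wq, w_scale, x_global_scale, global_scale, m_sizes
)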
