File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed
vllm/model_executor/layers/fused_moe Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -55,6 +55,7 @@ def _silu_mul_fp8_quant_deep_gemm(
55
55
56
56
# Meta ---------------------------------------------------------------
57
57
BLOCK : tl .constexpr ,
58
+ NUM_STAGES : tl .constexpr ,
58
59
):
59
60
G = H // GROUP_SIZE
60
61
@@ -73,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm(
73
74
cols = cols .to (tl .int64 )
74
75
mask_h = cols < BLOCK
75
76
76
- t = tl .zeros ([], tl .int64 )
77
- while t < n_tokens :
77
+ for t in tl .range (0 , n_tokens , num_stages = NUM_STAGES ):
78
78
base_i_offset = (e * stride_i_e + t * stride_i_t +
79
79
g * GROUP_SIZE * stride_i_h )
80
80
base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
@@ -102,8 +102,6 @@ def _silu_mul_fp8_quant_deep_gemm(
102
102
tl .store (y_q_ptr + base_yq_offset + cols * stride_yq_h , y_q , mask = mask )
103
103
tl .store (y_s_ptr + base_ys_offset , y_s )
104
104
105
- t += 1
106
-
107
105
108
106
def silu_mul_fp8_quant_deep_gemm (
109
107
y : torch .Tensor , # (E, T, 2*H) float32
@@ -180,7 +178,8 @@ def silu_mul_fp8_quant_deep_gemm(
180
178
fp8_max ,
181
179
is_blackwell_deep_gemm_used (),
182
180
BLOCK = group_size ,
183
- num_warps = 4 ,
181
+ NUM_STAGES = 8 ,
182
+ num_warps = 1 ,
184
183
)
185
184
186
185
return y_q , y_s
You can’t perform that action at this time.
0 commit comments