 from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul,
                                        per_block_cast_to_fp8)
+from tests.kernels.moe.utils import make_test_weights
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
 SEEDS = [0]


-def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
+def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
     """Fused moe with block-wise quantization using native torch."""
     B, D = a.shape
+    topk = topk_ids.size(1)
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
-    score = torch.softmax(score, dim=-1, dtype=torch.float32)
-    topk_weight, topk_ids = torch.topk(score, topk)
+
     topk_weight = topk_weight.view(-1)
     topk_ids = topk_ids.view(-1)
@@ -112,80 +113,59 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

-    factor_for_scale = 1e-2
-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
     a = torch.randn((M, K), dtype=dtype) / 10
-
-    w1_bf16 = (torch.rand(
-        (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
-    w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-    del w1_bf16
-
-    w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
-    w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-    del w2_bf16
-
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = (2 * N + block_n - 1) // block_n
-    n_tiles_w2 = (K + block_n - 1) // block_n
-    k_tiles_w1 = (K + block_k - 1) // block_k
-    k_tiles_w2 = (N + block_k - 1) // block_k
-
-    w1_s = torch.rand(
-        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
-    w2_s = torch.rand(
-        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
-
     score = torch.randn((M, E), dtype=dtype)

+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.float8_e4m3fn,
+                                                 per_act_token_quant=False,
+                                                 block_shape=block_size)
+
     m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                            use_int8_w8a8=False,
                                            use_int8_w8a16=False,
                                            use_int4_w4a16=False,
                                            per_act_token_quant=False,
                                            block_shape=block_size)

+    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        out = fused_moe(
+        ref_out = torch_w8a8_block_fp8_moe(
             a,
             w1,
             w2,
-            score,
-            topk,
-            renormalize=False,
+            w1_s,
+            w2_s,
+            topk_weights,
+            topk_ids,
+            block_size,
+        )
+
+        out = fused_experts(
+            a,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
             use_fp8_w8a8=True,
             w1_scale=w1_s,
             w2_scale=w2_s,
             block_shape=block_size,
         )
-        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
-                                           block_size)

-        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
-        m_out = m_fused_moe(a,
-                            w1,
-                            w2,
-                            topk_weights,
-                            topk_ids,
-                            global_num_experts=E,
-                            w1_scale=w1_s,
-                            w2_scale=w2_s)
-
-    #print(f"{out.sum()=}")
-    #print(f"{ref_out.sum()=}")
-
-    rel_diff = (torch.mean(
-        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
-                torch.mean(torch.abs(ref_out.to(torch.float32))))
-    assert rel_diff < 0.03
+        m_out = m_fused_moe(
+            a,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+        )

-    rel_diff = (torch.mean(
-        torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
-                torch.mean(torch.abs(ref_out.to(torch.float32))))
-    assert rel_diff < 0.03
+    torch.testing.assert_close(out, ref_out, atol=0.03, rtol=0.03)
+    torch.testing.assert_close(m_out, ref_out, atol=0.03, rtol=0.03)


 def fp8_perm(m, idx):
@@ -221,15 +201,13 @@ def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
     return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


-def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
+def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                                  block_shape):
     """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
     num_groups = w1.shape[0]
     M, K = a.shape
     N = w2.shape[-1]
-
-    topk_weight, topk_ids, token_expert_indices = fused_topk(
-        a, score.float(), topk, False)
+    topk = topk_ids.size(1)

     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
@@ -282,40 +260,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     block_size = [block_m, block_m]
     dtype = torch.bfloat16

-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
     a = torch.randn((M, K), dtype=dtype) / 10
-
-    w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
-               fp8_max).clamp(min=fp8_min, max=fp8_max)
-
-    w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
-               fp8_max).clamp(min=fp8_min, max=fp8_max)
-
     score = torch.randn((M, E), dtype=dtype)

-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
-    k_tiles_w1 = (K + block_k - 1) // block_k
-    n_tiles_w2 = (K + block_n - 1) // block_n
-    k_tiles_w2 = (N + block_k - 1) // block_k
-
-    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
-    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
-
-    w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
-    w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
-
-    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
-    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
-
-    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
-    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
-
-    for i in range(E):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.float8_e4m3fn,
+                                                 per_act_token_quant=False,
+                                                 block_shape=block_size)

     # Note: for now use_compile will error out if the problem size is
     # large enough to trigger chunking. I'm leaving the flag and
@@ -325,17 +275,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                      and current_platform.is_cuda_alike())

+    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        if M >= 128:
+        if False and M >= 128:
             ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
-                                                   score, topk, block_size)
+                                                   topk_weights, topk_ids, block_size)
         else:
-            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
-                                               topk, block_size)
-
-        topk_weights, topk_ids, token_expert_indices = fused_topk(
-            a, score.float(), topk, False)
+            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
+                                               topk_ids, block_size)

         if use_compile:
             deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
@@ -361,11 +310,4 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
         graph.replay()
         torch.cuda.synchronize()

-    #print(f"{out.sum()=}")
-    #print(f"{ref_out.sum()=}")
-
-    rel_diff = (torch.mean(
-        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
-                torch.mean(torch.abs(ref_out.to(torch.float32))))
-
-    assert rel_diff < 0.03
+    torch.testing.assert_close(out, ref_out, atol=0.03, rtol=0.03)