@@ -8,27 +8,16 @@
 import torch
 import triton.language as tl
 
-import vllm._custom_ops as ops
+from tests.kernels.moe.utils import (batched_moe, make_test_weights,
+                                     torch_moe2, triton_moe)
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize, BatchedTritonExperts,
     invoke_moe_batched_triton_kernel)
-from vllm.model_executor.layers.fused_moe.utils import (
-    moe_kernel_quantize_input)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
-from tests.kernels.quant_utils import native_w8a8_block_matmul
-from tests.kernels.moe.utils import (
-    torch_moe2,
-    triton_moe,
-    batched_moe,
-    make_test_weights,
-)
 
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]
@@ -104,18 +93,13 @@ def ref_impl(
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
         if A.dtype.itemsize == 1 and block_shape is not None:
-            tmp = native_w8a8_block_matmul(A[e],
-                                           B[e],
-                                           A_scale[e],
-                                           B_scale[e],
-                                           block_shape,
-                                           C.dtype)
+            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
+                                           block_shape, C.dtype)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         elif A.dtype.itemsize == 1 and block_shape is None:
             C[e, :num_tokens, :] = (
-                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16) @
-                (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16)
-            )
+                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16)
+                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16))
         else:
             assert A_scale is None
             assert B_scale is None
@@ -124,7 +108,8 @@ def ref_impl(
     return C
 
 
-def make_quantized_test_activations(E, m, k, dtype, block_shape, per_act_token):
+def make_quantized_test_activations(E, m, k, dtype, block_shape,
+                                    per_act_token):
     assert not per_act_token, "NYI"
 
     a_type = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
@@ -138,9 +123,11 @@ def make_quantized_test_activations(E, m, k, dtype, block_shape, per_act_token):
     a_scale = [None] * E
     for e in range(E):
         if block_shape is not None:
-            a_q[e], a_scale[e] = per_token_group_quant_fp8(a[e], block_shape[1])
+            a_q[e], a_scale[e] = per_token_group_quant_fp8(
+                a[e], block_shape[1])
         else:
-            a_tmp, a_scale[e] = per_token_group_quant_fp8(a[e].view(1, -1), a[e].numel())
+            a_tmp, a_scale[e] = per_token_group_quant_fp8(
+                a[e].view(1, -1), a[e].numel())
             a_q[e] = a_tmp.view(*a[e].shape)
     a_scale = torch.stack(a_scale)
 
@@ -173,14 +160,10 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                                       device="cuda",
                                       dtype=torch.int32)
 
-    A, A_q, A_scale = make_quantized_test_activations(
-        num_experts,
-        max_tokens_per_expert,
-        K,
-        dtype,
-        block_shape,
-        per_act_token_quant
-    )
+    A, A_q, A_scale = make_quantized_test_activations(num_experts,
+                                                      max_tokens_per_expert, K,
+                                                      dtype, block_shape,
+                                                      per_act_token_quant)
 
     B_q, _, B_scale, _, B, _ = make_test_weights(
         num_experts,
@@ -206,7 +189,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     #print(f"A {use_fp8_w8a8} {A_q.dtype} {B_q.dtype} {A_scale.shape} {B_scale.shape}")
     if False:
-        from vllm.model_executor.layers.fused_moe.batched_moe2 import fused_moe_kernel2
+        from vllm.model_executor.layers.fused_moe.batched_moe2 import (
+            fused_moe_kernel2)
         fused_moe_kernel2(
             A_q,
             B_q,
@@ -238,7 +222,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             config_block_shape[1],
             config_block_shape[2],
             1,
-            1, # topk hack
+            1,  # topk hack
             compute_tl_dtype,
             use_fp8_w8a8,
             False,
@@ -279,15 +263,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         None,
     )
 
-    q_ref_output = ref_impl(
-        A_q,
-        B_q,
-        q_ref_output,
-        num_expert_tokens,
-        A_scale,
-        B_scale,
-        block_shape
-    )
+    q_ref_output = ref_impl(A_q, B_q, q_ref_output, num_expert_tokens, A_scale,
+                            B_scale, block_shape)
 
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
@@ -393,11 +370,14 @@ def test_fused_moe_batched_experts(
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant, block_shape)
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
         baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant, block_shape)
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
         triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                   w2_s, quant_type, per_act_token_quant, block_shape)
+                                   w2_s, quant_type, per_act_token_quant,
+                                   block_shape)
 
         torch.testing.assert_close(triton_output,
                                    baseline_output,
@@ -411,4 +391,4 @@ def test_fused_moe_batched_experts(
         torch.testing.assert_close(triton_output,
                                    batched_output,
                                    atol=2e-2,
-                                   rtol=2e-2) # 0
+                                   rtol=2e-2)  # 0
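For context when reading the `ref_impl` hunk above: the non-block fp8 branch dequantizes by upcasting to float32, applying the scales, rounding to bfloat16, and then doing a plain matmul. Below is a minimal standalone sketch of that reference math, assuming simple per-tensor scales; the helper name and shapes are illustrative, not part of the test.

import torch

def dequant_ref_matmul(a_q, b_q, a_scale, b_scale):
    # Hypothetical helper mirroring ref_impl's non-block branch:
    # dequantize in float32, round to bfloat16, then matmul.
    f32, bf16 = torch.float32, torch.bfloat16
    a = (a_q.to(f32) * a_scale).to(bf16)  # (tokens, K)
    b = (b_q.to(f32) * b_scale).to(bf16)  # (N, K)
    return a @ b.transpose(0, 1)          # (tokens, N)

Any float dtype works for a quick sanity check, e.g.
out = dequant_ref_matmul(torch.randn(4, 8), torch.randn(16, 8), torch.tensor(0.5), torch.tensor(0.25))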