from tests.kernels.moe.utils import (batched_moe,
                                     make_quantized_test_activations,
                                     make_test_weights, triton_moe)
-from tests.kernels.quant_utils import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@@ -68,43 +68,6 @@ def make_tensors(config: BatchedMMConfig):
    return BatchedMMTensors(A, B, C, num_expert_tokens)


-def ref_impl(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    C: torch.Tensor,
-    num_expert_tokens: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    block_shape: Optional[list[int]],
-) -> torch.Tensor:
-    assert (A.dtype.itemsize > 1
-            or (A_scale is not None and B_scale is not None))
-
-    num_expert_tokens_cpu = num_expert_tokens.clone()
-    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
-    num_experts = num_expert_tokens.size(0)
-
-    f32 = torch.float32
-    bf16 = torch.bfloat16
-
-    for e in range(num_experts):
-        num_tokens = num_expert_tokens_cpu[e]
-        if A.dtype.itemsize == 1 and block_shape is not None:
-            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
-                                           block_shape, C.dtype)
-            C[e, :num_tokens, :] = tmp[:num_tokens, :]
-        elif A.dtype.itemsize == 1 and block_shape is None:
-            C[e, :num_tokens, :] = (
-                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16)
-                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16))
-        else:
-            assert A_scale is None
-            assert B_scale is None
-            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
-
-    return C
-
-
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@@ -193,7 +156,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
        block_shape=block_shape,
    )

-    ref_output = ref_impl(
+    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
@@ -203,8 +166,10 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
        None,
    )

-    q_ref_output = ref_impl(A_q, B_q, q_ref_output, num_expert_tokens, A_scale,
-                            B_scale, block_shape)
+    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
+                                                      num_expert_tokens,
+                                                      A_scale, B_scale,
+                                                      block_shape)

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
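
Note: this change drops the test-local ref_impl in favor of native_batched_masked_quant_matmul imported from tests.kernels.quant_utils, which presumably carries the same batched masked-matmul semantics (only the first num_expert_tokens[e] rows of each expert are computed; rows past that count are left untouched). A minimal usage sketch of the unquantized path, assuming the shared helper keeps ref_impl's positional signature (A, B, C, num_expert_tokens, A_scale, B_scale, block_shape) and that the vLLM test tree is importable; shapes are illustrative:

    import torch

    from tests.kernels.quant_utils import native_batched_masked_quant_matmul

    # Illustrative shapes: E experts, up to max_tokens rows each,
    # A is [E, max_tokens, K], B is [E, N, K], C is the [E, max_tokens, N]
    # output buffer.
    E, max_tokens, K, N = 4, 16, 32, 64
    A = torch.randn(E, max_tokens, K, dtype=torch.bfloat16)
    B = torch.randn(E, N, K, dtype=torch.bfloat16)
    C = torch.zeros(E, max_tokens, N, dtype=torch.bfloat16)
    # Per-expert count of valid rows; rows past each count stay masked out.
    num_expert_tokens = torch.randint(1, max_tokens + 1, (E,))

    # Unquantized path: scales and block_shape are None, matching the
    # ref_output call in the diff above.
    out = native_batched_masked_quant_matmul(A, B, C, num_expert_tokens,
                                             None, None, None)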