@@ -16,10 +16,6 @@
     _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
-from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 
 dg_available = False
@@ -39,19 +35,15 @@
 
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
-NUM_TOKENS = [7, 2050]
-D = [512, 4096, 5120, 13824]
-GROUP_SIZE = [64, 128, 512]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
-M = [1, 2, 83, 128, 2048, 40000]
+M = [1, 83, 128, 2048, 8192]
 M_dg = [128, 192, 1335, 2048]
-N = [128, 256, 1024, 4608]  # [13824]
-K = [256, 512, 7168]  # [13824]
+N = [128, 256, 1024, 4608]
+K = [256, 512, 7168]
 BLOCK_SIZE = [[128, 128]]
-E = [2, 8, 16, 24]  # [128, 256]
+E = [2, 8, 16]  # [128, 256]
 TOP_KS = [1, 2, 6]
-OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
 SEEDS = [0]
 
 
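These lists feed pytest parametrization; the DeepGemm test later in this diff builds the full cross-product with itertools.product(M_dg, N, K, E, TOP_KS, SEEDS). A minimal, self-contained sketch of that pattern follows (the test name and body here are placeholders, not part of the patch):

import itertools

import pytest
import torch

M_dg = [128, 192, 1335, 2048]
N = [128, 256, 1024, 4608]
K = [256, 512, 7168]
E = [2, 8, 16]
TOP_KS = [1, 2, 6]
SEEDS = [0]


# pytest generates one test case per (M, N, K, E, topk, seed) combination.
@pytest.mark.parametrize("M,N,K,E,topk,seed",
                         itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
def test_shape_cross_product(M, N, K, E, topk, seed):
    torch.manual_seed(seed)
    a = torch.randn((M, K), dtype=torch.bfloat16) / 10
    assert a.shape == (M, K)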
@@ -111,7 +103,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 
     torch.manual_seed(seed)
 
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
 
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
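The chunk size drops from 8192 to 2048 while the largest M grows to 8192, so the fused MoE path presumably has to iterate over several chunks instead of covering every test batch in a single pass. An illustrative back-of-the-envelope check, assuming VLLM_FUSED_MOE_CHUNK_SIZE caps the number of tokens processed per chunk:

import math

chunk_size = 2048  # value exported via VLLM_FUSED_MOE_CHUNK_SIZE above
for m in [1, 83, 128, 2048, 8192]:
    # Number of chunks the fused MoE loop would need for a batch of m tokens.
    print(f"M={m}: {math.ceil(m / chunk_size)} chunk(s)")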
@@ -174,76 +166,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
 
 
-def fp8_perm(m, idx):
-    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
-        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
-    else:
-        return m[idx, ...]
-
-
-def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
-    M, K = a.shape
-
-    sorted_token_ids, m_indices, num_pad = moe_align_block_size(
-        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
-
-    num_tokens = topk * M
-
-    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
-    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
-    inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
-
-    a = fp8_perm(a, sorted_token_ids // topk)
-    if a_s is not None:
-        a_s = a_s[sorted_token_ids // topk]
-
-    return a, a_s, m_indices, inv_perm
-
-
-def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
-    M = topk_weight.shape[0]
-    out = out[inv_perm, ...]
-    tmp_out = out.view(-1, topk, K)
-    return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
-
-
-def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, topk_weight,
-                                 topk_ids, block_shape):
-    """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
-    num_groups = w1.shape[0]
-    M, K = a.shape
-    N = w2.shape[-1]
-    topk = topk_ids.size(1)
-
-    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
-
-    _, block_k = block_shape[0], block_shape[1]
-
-    a_q, a_s = per_token_group_quant_fp8(a, block_m)
-
-    a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
-                                                 num_groups, topk, block_m)
-
-    inter_out = torch.zeros((a_q.shape[0], N * 2),
-                            dtype=torch.bfloat16,
-                            device=a.device)
-
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
-                                                        inter_out, m_indices)
-
-    act_out = SiluAndMul().forward_native(inter_out)
-    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
-
-    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
-
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)
-
-    final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
-
-    return final_out
-
-
 @pytest.mark.parametrize("M,N,K,E,topk,seed",
                          itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
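With the hand-rolled DeepGemm reference (deep_gemm_w8a8_block_fp8_moe and its permute/unpermute helpers) removed, the next hunk compares against torch_w8a8_block_fp8_moe for every M. That reference is defined earlier in this test file and is not shown in the diff; the sketch below only illustrates the general shape of such a pure-PyTorch check, assuming w1 has shape [E, 2N, K], w2 has shape [E, K, N], per-block weight scales, and skipping activation quantization for brevity:

import torch


def _dequant_block_fp8(w, w_s, block_n, block_k):
    # Expand per-block scales to per-element scales, then dequantize.
    n, k = w.shape[-2], w.shape[-1]
    scale = w_s.repeat_interleave(block_n, dim=-2)[..., :n, :]
    scale = scale.repeat_interleave(block_k, dim=-1)[..., :k]
    return w.to(torch.float32) * scale


def ref_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids,
                      block_size):
    block_n, block_k = block_size
    w1_f = _dequant_block_fp8(w1, w1_s, block_n, block_k)  # [E, 2N, K]
    w2_f = _dequant_block_fp8(w2, w2_s, block_n, block_k)  # [E, K, N]
    out = torch.zeros((a.shape[0], w2_f.shape[1]),
                      dtype=torch.float32,
                      device=a.device)
    for e in range(w1_f.shape[0]):
        # Tokens routed to expert e, plus the top-k slot they occupy.
        token_idx, slot = torch.where(topk_ids == e)
        if token_idx.numel() == 0:
            continue
        x = a[token_idx].to(torch.float32)
        h = x @ w1_f[e].t()                      # [n_e, 2N]
        gate, up = h.chunk(2, dim=-1)
        h = torch.nn.functional.silu(gate) * up  # SiluAndMul
        y = h @ w2_f[e].t()                      # [n_e, K]
        w_topk = topk_weights[token_idx, slot].to(torch.float32).unsqueeze(-1)
        out.index_add_(0, token_idx, y * w_topk)
    return out.to(a.dtype)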
@@ -289,14 +211,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
 
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        if M >= 128:
-            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
-                                                   topk_weights, topk_ids,
-                                                   block_size)
-        else:
-            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
-                                               topk_weights, topk_ids,
-                                               block_size)
+        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
+                                           topk_ids, block_size)
 
     if use_compile:
         deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
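The use_compile branch above wraps deep_gemm_moe_fp8 in torch.compile before running the test. A minimal sketch of that compile-vs-eager pattern on a stand-in function (the function and compile arguments here are illustrative, not the ones used by the test):

import torch


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up


# Compile the function and check that it still matches eager execution.
compiled = torch.compile(silu_and_mul, fullgraph=True)

x = torch.randn(4, 8)
torch.testing.assert_close(compiled(x), silu_and_mul(x))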