
Commit 87b969c

fix test_block_fp8.py test
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: 52f935c

1 file changed
+7 -91 lines changed

tests/kernels/moe/test_block_fp8.py

Lines changed: 7 additions & 91 deletions
@@ -16,10 +16,6 @@
     _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
-from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 
 dg_available = False
@@ -39,19 +35,15 @@
 
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
-NUM_TOKENS = [7, 2050]
-D = [512, 4096, 5120, 13824]
-GROUP_SIZE = [64, 128, 512]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
-M = [1, 2, 83, 128, 2048, 40000]
+M = [1, 83, 128, 2048, 8192]
 M_dg = [128, 192, 1335, 2048]
-N = [128, 256, 1024, 4608]  # [13824]
-K = [256, 512, 7168]  # [13824]
+N = [128, 256, 1024, 4608]
+K = [256, 512, 7168]
 BLOCK_SIZE = [[128, 128]]
-E = [2, 8, 16, 24]  # [128, 256]
+E = [2, 8, 16]  # [128, 256]
 TOP_KS = [1, 2, 6]
-OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
 SEEDS = [0]
 
 
@@ -111,7 +103,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 
     torch.manual_seed(seed)
 
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
 
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
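A minimal sketch (plain Python, not vLLM code; that the fused-MoE path walks the M token rows in slices of at most VLLM_FUSED_MOE_CHUNK_SIZE rows is an assumption here) of how many chunks each M value in the suite would see at a 2048-row chunk size:

import math

CHUNK_SIZE = 2048  # value monkeypatched for VLLM_FUSED_MOE_CHUNK_SIZE above

for m in [1, 83, 128, 2048, 8192]:
    # With M up to 8192, at least one configuration spans multiple chunks.
    num_chunks = math.ceil(m / CHUNK_SIZE)
    print(f"M={m:5d} -> {num_chunks} chunk(s) of at most {CHUNK_SIZE} rows")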
@@ -174,76 +166,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
 
 
-def fp8_perm(m, idx):
-    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
-        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
-    else:
-        return m[idx, ...]
-
-
-def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
-    M, K = a.shape
-
-    sorted_token_ids, m_indices, num_pad = moe_align_block_size(
-        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
-
-    num_tokens = topk * M
-
-    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
-    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
-    inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
-
-    a = fp8_perm(a, sorted_token_ids // topk)
-    if a_s is not None:
-        a_s = a_s[sorted_token_ids // topk]
-
-    return a, a_s, m_indices, inv_perm
-
-
-def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
-    M = topk_weight.shape[0]
-    out = out[inv_perm, ...]
-    tmp_out = out.view(-1, topk, K)
-    return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
-
-
-def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, topk_weight,
-                                 topk_ids, block_shape):
-    """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
-    num_groups = w1.shape[0]
-    M, K = a.shape
-    N = w2.shape[-1]
-    topk = topk_ids.size(1)
-
-    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
-
-    _, block_k = block_shape[0], block_shape[1]
-
-    a_q, a_s = per_token_group_quant_fp8(a, block_m)
-
-    a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
-                                                 num_groups, topk, block_m)
-
-    inter_out = torch.zeros((a_q.shape[0], N * 2),
-                            dtype=torch.bfloat16,
-                            device=a.device)
-
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
-                                                        inter_out, m_indices)
-
-    act_out = SiluAndMul().forward_native(inter_out)
-    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
-
-    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
-
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)
-
-    final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
-
-    return final_out
-
-
 @pytest.mark.parametrize("M,N,K,E,topk,seed",
                          itertools.product(M_dg, N, K, E, TOP_KS, SEEDS))
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@@ -289,14 +211,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
 
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        if M >= 128:
-            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
-                                                   topk_weights, topk_ids,
-                                                   block_size)
-        else:
-            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
-                                               topk_weights, topk_ids,
-                                               block_size)
+        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
+                                           topk_ids, block_size)
 
     if use_compile:
         deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
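A hedged usage sketch (the invocation below is an assumption, not part of the commit) for running the two tests touched by this diff via pytest's Python API from a vLLM checkout:

import pytest

# Selects both the triton-path and the DeepGemm-path tests in this file; the
# DeepGemm test skips itself automatically when dg_available is False.
pytest.main([
    "tests/kernels/moe/test_block_fp8.py",
    "-k", "test_w8a8_block_fp8_fused_moe or test_w8a8_block_fp8_deep_gemm_fused_moe",
])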
