Commit 9cfebf5

basic working test
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 77f95b9 commit 9cfebf5

2 files changed: +54 −51 lines changed


tests/kernels/moe/test_batched_moe.py

Lines changed: 6 additions & 3 deletions
@@ -198,9 +198,11 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
         A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
         B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
+        quant_block_shape = [1, 1]
     else:
         A_scale = None
         B_scale = None
+        quant_block_shape = None

     invoke_moe_batched_triton_kernel(
         tensors.A,
@@ -220,7 +222,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             "BLOCK_SIZE_M": block_shape[0],
             "BLOCK_SIZE_N": block_shape[1],
             "BLOCK_SIZE_K": block_shape[2],
-        })
+        },
+        block_shape=quant_block_shape,
+    )

     ref_output = ref_output.to(dtype=out_dtype)
     ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
@@ -246,5 +250,4 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     }[test_output.dtype]

     torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
-    if not use_fp8_w8a8:
-        torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
+    torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
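The fp8 branch above pairs scalar unit scales with quant_block_shape = [1, 1], i.e. every element is its own quantization group. Below is a minimal sketch, not part of the commit, of why that combination lets the final assert_close drop the fp8 carve-out: with unit scales, block-wise dequantization is an identity, so the quantized path is directly comparable to the unquantized reference. dequant_blockwise is a hypothetical helper, and it uses a per-element scale tensor where the test uses a scalar.

import torch

def dequant_blockwise(x: torch.Tensor, scale: torch.Tensor,
                      block_shape: list[int]) -> torch.Tensor:
    # Expand one scale per [group_k, group_n] block over the elements
    # it covers, then apply it. With block_shape == [1, 1] both
    # repeat_interleave calls are no-ops: one scale per element.
    group_k, group_n = block_shape
    scale_full = scale.repeat_interleave(group_k, dim=0) \
                      .repeat_interleave(group_n, dim=1)
    return x.float() * scale_full[:x.size(0), :x.size(1)]

x = torch.randn(4, 8)
scales = torch.ones(4, 8)  # unit scales, as in the test
assert torch.equal(dequant_blockwise(x, scales, [1, 1]), x)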

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 48 additions & 48 deletions
@@ -15,38 +15,38 @@

 @triton.jit
 def moe_mmk(
-        a_ptrs,
-        b_ptrs,
-        K,
-        expert_id,
-        a_scale_ptr,
-        b_scale_ptr,
-        # The stride variables represent how much to increase the ptr by when
-        # moving by 1 element in a particular dimension. E.g. `stride_am` is
-        # how much to increase `a_ptr` by to get the element one row down
-        # (A has M rows).
-        stride_ak,
-        stride_bk,
-        stride_asm,
-        stride_ask,
-        stride_bse,
-        stride_bsk,
-        stride_bsn,
-        # Offsets and masks
-        offs_m,
-        offs_n,
-        mask_m,
-        # Block size for block-wise quantization
-        group_n: tl.constexpr,
-        group_k: tl.constexpr,
-        # Meta-parameters
-        BLOCK_M: tl.constexpr,
-        BLOCK_N: tl.constexpr,
-        BLOCK_K: tl.constexpr,
-        compute_type: tl.constexpr,
-        use_w8a8: tl.constexpr,
-        use_w8a16: tl.constexpr):
-
+    a_ptrs,
+    b_ptrs,
+    K,
+    expert_id,
+    a_scale_ptr,
+    b_scale_ptr,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_ak,
+    stride_bk,
+    stride_asm,
+    stride_ask,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    # Offsets and masks
+    offs_m,
+    offs_n,
+    mask_m,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    compute_type: tl.constexpr,
+    use_w8a8: tl.constexpr,
+    use_w8a16: tl.constexpr
+):
     offs_k = tl.arange(0, BLOCK_K)

     if use_w8a16:
@@ -310,22 +310,22 @@ def batched_triton_kernel(


 def invoke_moe_batched_triton_kernel(
-        A: torch.Tensor,  # [E, max_tokens, K]
-        B: torch.Tensor,  # [E, K, N]
-        C: torch.Tensor,  # [E, max_tokens, N]
-        expert_num_tokens: torch.Tensor,  # [E]
-        compute_type: tl.dtype,
-        # Quantization data
-        A_scale: Optional[torch.Tensor],
-        B_scale: Optional[torch.Tensor],
-        B_zp: torch.Tensor,
-        # Quantization schemes
-        use_fp8_w8a8: bool,
-        use_int8_w8a16: bool,
-        use_int4_w4a16: bool,
-        config: dict[str, int],
-        block_shape: Optional[list[int]] = None):
-
+    A: torch.Tensor,  # [E, max_tokens, K]
+    B: torch.Tensor,  # [E, K, N]
+    C: torch.Tensor,  # [E, max_tokens, N]
+    expert_num_tokens: torch.Tensor,  # [E]
+    compute_type: tl.dtype,
+    # Quantization data
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    B_zp: torch.Tensor,
+    # Quantization schemes
+    use_fp8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    config: dict[str, int],
+    block_shape: Optional[list[int]] = None
+):
     assert not use_int4_w4a16
     max_num_tokens = A.size(1)
     K = A.size(2)
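For orientation, here is a minimal driver sketch for the entry point above. It is illustrative only, not code from this commit: it assumes a CUDA device and an importable vLLM checkout, the block-size values in config are arbitrary, and it passes None for B_zp on the unquantized path even though that parameter is annotated torch.Tensor (the scales are None in the test's unquantized branch, so the same is assumed to hold for the zero point).

import torch
import triton.language as tl
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)

E, max_tokens, K, N = 8, 32, 128, 256
# Shapes follow the comments on the signature above.
A = torch.randn(E, max_tokens, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")
C = torch.zeros(E, max_tokens, N, dtype=torch.bfloat16, device="cuda")
# Per-expert valid-token counts; int32 dtype is an assumption.
expert_num_tokens = torch.randint(0, max_tokens + 1, (E,),
                                  dtype=torch.int32, device="cuda")

invoke_moe_batched_triton_kernel(
    A, B, C, expert_num_tokens,
    compute_type=tl.bfloat16,
    A_scale=None, B_scale=None, B_zp=None,  # unquantized path (assumption)
    use_fp8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False,
    config={"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16},
    block_shape=None,  # no block-wise quantization on this path
)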
