
Commit c52a3a0 (1 parent: a4f56fb)

blocked working

Signed-off-by: Bill Nell <bnell@redhat.com>

2 files changed: 16 additions & 7 deletions

tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 3 deletions
@@ -77,7 +77,7 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize(
     "dtype",
     [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [[128, 128]]) # [None])#, [128, 128]])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]]) # [None])#, [128, 128]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])# [False])# ,True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
@@ -141,8 +141,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,

     assert A_q.dtype == B_q.dtype

-    #A_scale.fill_(1)
-    B_scale.fill_(1)
+    #B_scale.fill_(0.5)

     invoke_moe_batched_triton_kernel(
         A_q,
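
A note on the test change above: block_shape=None keeps the existing per-tensor / per-token quantization cases, while [128, 128] adds the block-wise case, so the kernel is now exercised both ways. As a rough illustration of what the extra parameter implies (a sketch only; the function name and the exact scale layout below are assumptions, not code from the test), block-wise quantization of a batched weight of shape (E, N, K) yields one scale per 128x128 tile of each expert's matrix:

    import math

    def expected_b_scale_shape(E: int, N: int, K: int, block_shape):
        """Illustrative only: rough shape of the weight scales."""
        if block_shape is None:
            # per-tensor case: a single scale per expert
            # (the real layout in vLLM may differ)
            return (E,)
        block_n, block_k = block_shape
        # block-wise case: one scale per (block_n x block_k) tile
        return (E, math.ceil(N / block_n), math.ceil(K / block_k))

    # e.g. E=8, N=512, K=256 with block_shape=[128, 128] -> (8, 4, 2)
    print(expected_b_scale_shape(8, 512, 256, [128, 128]))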

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 14 additions & 4 deletions
@@ -39,6 +39,7 @@ def moe_mmk(
         # Offsets and masks
         offs_m,
         offs_n,
+        offs_bn,
         mask_m,
         # Block size for block-wise quantization
         group_n: tl.constexpr,
@@ -64,7 +65,7 @@ def moe_mmk(
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
-            offs_bsn = offs_n // group_n
+            offs_bsn = offs_bn // group_n
             b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn

         # per act token
@@ -142,7 +143,7 @@ def moe_mmk(
     elif use_w8a8:
         if group_k > 0 and group_n > 0:
             accumulator = accumulator.to(compute_type)
-        elif True or not per_act_token_quant:
+        else: #if True or not per_act_token_quant:
             accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
@@ -178,6 +179,8 @@ def expert_triton_kernel(
         stride_bse,
         stride_bsk,
         stride_bsn,
+        # offsets
+        offs_bn,
         # Blockwise quantization data
         group_n,
         group_k,
@@ -222,6 +225,7 @@ def expert_triton_kernel(
         # Offsets and masks
         offs_m,
         offs_n,
+        offs_bn,
         mask_m,
         # Block size for block-wise quantization
         group_n,
@@ -315,12 +319,15 @@ def batched_triton_kernel(
     c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
              cta_n_start * stride_cn)

+    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N
+
     if use_fp8_w8a8:
         a_scale_ptr = a_scale_ptr + expert_id * stride_ase
         b_scale_ptr = b_scale_ptr + expert_id * stride_bse
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
+            #b_scale_ptr = b_scale_ptr + offs_bn * stride_bsn
             # b group advancement?
         elif False and per_act_token_quant:
             a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
@@ -351,6 +358,8 @@ def batched_triton_kernel(
         stride_bse,
         stride_bsk,
         stride_bsn,
+        # offsets
+        offs_bn,
         # Blockwise quantization data
         group_n,
         group_k,
@@ -404,12 +413,13 @@ def invoke_moe_batched_triton_kernel(
     if B_scale is not None:
         if B_scale.ndim == 1:
             stride_bse = 1
-            stride_bsn = 0
             stride_bsk = 0
+            stride_bsn = 0
         else:
             stride_bse = B_scale.stride(0)
-            stride_bsn = B_scale.stride(1)
             stride_bsk = B_scale.stride(2)
+            stride_bsn = B_scale.stride(1)
+
     else:
         stride_bse = 0
         stride_bsk = 0
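
A note on the kernel change above: the N-dimension scale group index in moe_mmk is now derived from offs_bn, the absolute (modulo-N wrapped) output-column offsets computed in batched_triton_kernel and threaded through expert_triton_kernel, instead of from the tile-local offs_n. The stride reordering in invoke_moe_batched_triton_kernel is consistent with a scale layout of (E, N-groups, K-groups), where stride_bsn = B_scale.stride(1) and stride_bsk = B_scale.stride(2). A rough PyTorch reference of the indexing this implements for a single expert (a sketch only; the names w_q and w_scale and the exact layout are assumptions, not the kernel's tensors):

    import torch

    def dequant_blockwise(w_q: torch.Tensor, w_scale: torch.Tensor,
                          group_n: int, group_k: int) -> torch.Tensor:
        # w_q:     (N, K) quantized weight for one expert
        # w_scale: (ceil(N / group_n), ceil(K / group_k)), one scale per tile
        N, K = w_q.shape
        # Each element (n, k) picks the scale of its tile, indexed by its
        # absolute coordinates (the analogue of offs_bsn = offs_bn // group_n).
        n_groups = torch.arange(N, device=w_q.device) // group_n
        k_groups = torch.arange(K, device=w_q.device) // group_k
        scales = w_scale[n_groups[:, None], k_groups[None, :]]  # (N, K)
        return w_q.float() * scales

Using tile-local column offsets for this lookup would map every tile onto the first few scale groups, which appears to be what the offs_bn plumbing fixes.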
