
Commit 90ea3c7

per token + grouped broken

Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent d7bb199

File tree: 3 files changed, 15 insertions(+), 18 deletions(-)

tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize(
     "dtype",
     [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [None, [128, 128]])
-@pytest.mark.parametrize("per_act_token_quant", [False, True])
+@pytest.mark.parametrize("block_shape", [None])#, [128, 128]])
+@pytest.mark.parametrize("per_act_token_quant", [False])#, True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
                     block_shape: Optional[list[int]],
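The two commented-out parametrizations are exactly the cases the commit message flags as broken: grouped (block-wise) quantization and per-token activation quantization. For context, a small standalone sketch of the scale-tensor shape each mode implies for an (M, K) activation; `quant_scale_shape` is an illustrative helper, not vLLM API:

    def quant_scale_shape(M: int, K: int, per_act_token_quant: bool,
                          block_shape) -> tuple:
        # Illustrative only: shape of the quantization-scale tensor
        # for an (M, K) activation under each mode.
        if block_shape is not None:
            block_k = block_shape[1]
            return (M, -(K // -block_k))  # grouped: one scale per (token, k-block)
        if per_act_token_quant:
            return (M, 1)                 # per-token: one scale per row
        return (1,)                       # per-tensor: a single scalar

    quant_scale_shape(4, 256, False, [128, 128])  # -> (4, 2)
    quant_scale_shape(4, 256, True, None)         # -> (4, 1)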

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 12 additions & 15 deletions
@@ -63,17 +63,15 @@ def moe_mmk(
     if use_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm
-                                          ) #+ (expert_id * stride_ase)
+            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm #+ (expert_id * stride_ase)
             offs_bsn = offs_n // group_n
-            b_scale_ptrs = (b_scale_ptr +
-                            offs_bsn * stride_bsn) + expert_id * stride_bse
+            b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
+                            offs_bsn * stride_bsn)
 
         # channel-wise
         elif per_channel_quant:
             # TODO: probably not correct
-            b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[
-                None, :] * stride_bsn
+            b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
             b_scale = tl.load(b_scale_ptrs)
             # Load per-token scale for activations
             # + (expert_id * stride_ase)??
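In the block-wise branch each token row gets its own activation scale and every `group_n` output channels share one weight scale; the expert offset on `a_scale_ptrs` is commented out because the next hunk moves it into the caller. A rough Python analogue of the pointer arithmetic (illustrative, not Triton; the names mirror the kernel's):

    # Block-wise (grouped) W8A8 scale indexing, written as flat offsets.
    def a_scale_index(offs_m: int, stride_asm: int) -> int:
        # One scale row per token; the expert offset is now applied by the caller.
        return offs_m * stride_asm

    def b_scale_index(offs_n: int, group_n: int, expert_id: int,
                      stride_bse: int, stride_bsn: int) -> int:
        offs_bsn = offs_n // group_n  # one scale per group_n output channels
        return expert_id * stride_bse + offs_bsn * stride_bsn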
@@ -300,16 +298,14 @@ def batched_triton_kernel(
                  cta_n_start * stride_cn)
 
     if use_fp8_w8a8:
+        a_scale_ptr = a_scale_ptr + (expert_id * stride_ase)
         # block-wise
-        if (group_k > 0 and group_n > 0) or per_channel_quant:
-            a_scale_ptr = a_scale_ptr + (expert_id *
-                                         stride_ase) + cta_m_start * stride_asm
-            #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
-            # (?) b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn
-        # channel-wise or tensor-wise
-        else:
-            a_scale_ptr = a_scale_ptr + (expert_id * stride_ase)
-            #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
+        if group_k > 0 and group_n > 0:
+            a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
+            b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
+        elif per_channel_quant:
+            a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
+            b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + cta_n_start * stride_bsn
 
     expert_triton_kernel(
         a_ptr,
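This hunk hoists the expert offset that every quantized path adds to `a_scale_ptr`, and starts applying the `b_scale_ptr` offsets the old code left commented out. A quick sanity sketch in plain Python (the stride values are made up for the example) showing that the hoisted form composes to the same offset as the old block-wise branch:

    expert_id, cta_m_start = 2, 64      # example values
    stride_ase, stride_asm = 1024, 8    # hypothetical strides

    old_offset = expert_id * stride_ase + cta_m_start * stride_asm
    hoisted = expert_id * stride_ase    # now applied once, up front
    new_offset = hoisted + cta_m_start * stride_asm
    assert old_offset == new_offset     # 2560 either way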
@@ -532,6 +528,7 @@ def prepare(
                          self.max_num_tokens,
                          hidden_dim)
 
+        # empty?
         b_a1_scale = torch.zeros(scale_shape,
                                  dtype=torch.float32,
                                  device=a1.device)
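The new `# empty?` note (repeated at the `expert_x_scale` allocation in pplx_prepare_finalize.py below) asks whether the zero fill is needed at all. If every slot is written before it is read, `torch.empty` would suffice; a hedged sketch of the trade-off, assuming that invariant holds:

    import torch

    scale_shape = (8, 16, 2)  # example shape; the real one is computed in prepare()

    # Current: zero-filled, so slots that are never written read back as 0.0.
    b_a1_scale = torch.zeros(scale_shape, dtype=torch.float32)

    # Hypothetical alternative: skips the fill, but unwritten slots hold
    # garbage, so it is only safe if every slot is written before use.
    # b_a1_scale = torch.empty(scale_shape, dtype=torch.float32)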

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,6 @@ def prepare(
         else:
             assert a1q_scale.numel() == a1.shape[0] * cdiv(a1.shape[1], quant_config.block_shape[1])
             assert a1q_scale.shape == (a1.shape[0], cdiv(a1.shape[1], quant_config.block_shape[1]))
-            #a1q_scale = group_broadcast(scale, a1q.shape)
 
         if a1q_scale is not None:
             scalar_scales = a1q_scale.numel() == 1
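The two asserts check the same per-token, per-k-group scale count, once flattened and once as a shape. A standalone illustration of the arithmetic, with `cdiv` (ceiling division; the real code takes it from vllm.utils) reimplemented just for the example:

    def cdiv(a: int, b: int) -> int:
        return -(a // -b)  # ceiling division

    M, K, block_k = 5, 300, 128    # say a1 is (5, 300) and block_shape[1] == 128
    num_groups = cdiv(K, block_k)  # 3 k-groups per token
    assert (M, num_groups) == (5, 3)  # expected a1q_scale.shape
    assert M * num_groups == 15       # expected a1q_scale.numel()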
@@ -208,6 +207,7 @@ def prepare(
 
     #print(f"EXPERT_X_SCALE {expert_x_scale_shape}")
 
+    # empty?
     expert_x_scale = torch.zeros(
         expert_x_scale_shape,
         dtype=torch.float32,
