Commit dcf59cf

per_act_token working
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent c52a3a0 commit dcf59cf

3 files changed: +15 −35 lines changed


tests/kernels/moe/test_batched_moe.py

Lines changed: 10 additions & 15 deletions
@@ -77,8 +77,8 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize(
     "dtype",
     [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [None, [128, 128]]) # [None])#, [128, 128]])
-@pytest.mark.parametrize("per_act_token_quant", [False, True])# [False])# ,True])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
                     block_shape: Optional[list[int]],
@@ -141,8 +141,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,

     assert A_q.dtype == B_q.dtype

-    #B_scale.fill_(0.5)
-
     invoke_moe_batched_triton_kernel(
         A_q,
         B_q,
@@ -190,7 +188,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     print(f"REF_OUTPUT {q_ref_output.shape}\n{q_ref_output}")
     print(f"TRITON {test_output.shape}\n{test_output}")

-    #torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)

@@ -246,9 +244,6 @@ def test_fused_moe_batched_experts(
         per_act_token_quant=per_act_token_quant,
     )

-    # TODO remove
-    torch.set_printoptions(profile="full")
-
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

@@ -274,9 +269,9 @@ def test_fused_moe_batched_experts(
     else:
         baseline_output = torch_experts(a, w1_16, w2_16, topk_weight, topk_ids)

-    #triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-    #                           w2_s, quant_dtype, per_act_token_quant,
-    #                           block_shape)
+    triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                               w2_s, quant_dtype, per_act_token_quant,
+                               block_shape)

     #print(f"TORCH {baseline_output.shape}\n{baseline_output}")
     #print(f"TRITON {triton_output.shape}\n{triton_output}")
@@ -292,7 +287,7 @@ def test_fused_moe_batched_experts(
     #                            atol=2e-2,
     #                            rtol=2e-2)

-    # torch.testing.assert_close(triton_output,
-    #                            batched_output,
-    #                            atol=2e-2,
-    #                            rtol=2e-2)
+    torch.testing.assert_close(triton_output,
+                               batched_output,
+                               atol=2e-2,
+                               rtol=2e-2)
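
For context, the quantized reference (q_ref_output) these assertions compare against amounts to a dequantized batched matmul over each expert's valid rows. A minimal sketch of that reference (hypothetical names, shapes inferred from the test; not the actual helper):

import torch

def ref_batched_mm(A_q, B_q, a_scale, b_scale, num_expert_tokens):
    # One GEMM per expert over that expert's valid rows; the rest of each
    # (E, max_tokens, K) batch is padding and stays zero.
    # A_q: (E, max_tokens, K) fp8 with per-token scales a_scale,
    # B_q: (E, N, K) fp8 with per-channel scales b_scale.
    E, max_tokens, _ = A_q.shape
    N = B_q.shape[1]
    C = torch.zeros(E, max_tokens, N, dtype=torch.float32)
    for e in range(E):
        n = int(num_expert_tokens[e])
        acc = A_q[e, :n].float() @ B_q[e].float().t()  # accumulate in fp32
        C[e, :n] = acc * a_scale[e, :n].view(n, 1) * b_scale[e].view(1, N)
    return C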

tests/kernels/moe/utils.py

Lines changed: 2 additions & 3 deletions
@@ -162,9 +162,8 @@ def make_quantized_test_activations(
             a_q[e], a_scale[e] = per_token_group_quant_fp8(
                 a[e], block_shape[1])
         else:
-            a_tmp, a_scale[e] = per_token_group_quant_fp8(
-                a[e].view(1, -1), a[e].numel())
-            a_q[e] = a_tmp.view(*a[e].shape)
+            a_q[e], a_scale[e] = ops.scaled_fp8_quant(
+                a[e], None, use_per_token_if_dynamic=per_act_token_quant)
     a_scale = torch.stack(a_scale)

     return a, a_q, a_scale
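
With use_per_token_if_dynamic=True and no precomputed scale, ops.scaled_fp8_quant computes one dynamic scale per row rather than per tensor. A rough plain-PyTorch equivalent of the semantics (a sketch, not the CUDA implementation):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_fp8_quant(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row) of a (num_tokens, K) activation tensor.
    amax = a.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = amax / FP8_MAX                                # (num_tokens, 1)
    a_q = (a / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return a_q, scale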

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 3 additions & 17 deletions
@@ -74,19 +74,9 @@ def moe_mmk(
         a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
         a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]

-        b_scale_ptrs = b_scale_ptr + offs_n[None, :] * stride_bsn
+        b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn
         b_scale = tl.load(b_scale_ptrs)

-
-        # Load per-token scale for activations
-        # + (expert_id * stride_ase)??
-        #a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
-        #a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]
-
-        # TODO: probably not correct
-        #b_scale_ptrs = b_scale_ptr + expert_id * stride_bse #+ offs_n[None, :] * stride_bsn
-        #b_scale = tl.load(b_scale_ptrs)
-
     # tensor-wise
     else:
         a_scale = tl.load(a_scale_ptr)
@@ -134,10 +124,6 @@ def moe_mmk(
         a_ptrs += BLOCK_K * stride_ak
         b_ptrs += BLOCK_K * stride_bk

-        if False and per_act_token_quant:
-            a_scale_ptrs += BLOCK_K * stride_ask
-            b_scale_ptrs += BLOCK_K * stride_bsk
-
     if use_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_w8a8:
@@ -329,9 +315,9 @@ def batched_triton_kernel(
         a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
         #b_scale_ptr = b_scale_ptr + offs_bn * stride_bsn
         # b group advancement?
-    elif False and per_act_token_quant:
+    elif per_act_token_quant:
         a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
-        b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn
+        # b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn

     expert_triton_kernel(
         a_ptr,
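
The deleted `if False and per_act_token_quant` block advanced the scale pointers along K, but per-act-token (per-row) and per-channel (per-column) scales are constant along the reduction dimension: they are loaded once before the K loop and applied to the accumulator afterwards. A schematic of that structure in plain PyTorch (hypothetical shapes, B stored as (K, N) here for readability; a sketch, not the kernel):

import torch

def mmk_w8a8_per_act_token(a_q, b_q, a_scale, b_scale, BLOCK_K=32):
    # a_q: (M, K) fp8, a_scale: (M,) per-token; b_q: (K, N) fp8, b_scale: (N,).
    M, K = a_q.shape
    N = b_q.shape[1]
    acc = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, BLOCK_K):       # the K loop advances only a/b tiles
        acc += a_q[:, k0:k0 + BLOCK_K].float() @ b_q[k0:k0 + BLOCK_K].float()
    # scales applied once, after the reduction
    return acc * a_scale.view(M, 1) * b_scale.view(1, N)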
