
Commit 203dece ("wip test")
Parent: 680ecc5
Signed-off-by: Bill Nell <bnell@redhat.com>

2 files changed (+77, -14 lines)

tests/kernels/moe/test_batched_moe.py

Lines changed: 22 additions & 8 deletions
@@ -133,6 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         act_dtype = dtype
         quant_dtype = None
 
+    #print(f"TYPES {dtype}, {act_dtype}, {quant_dtype}")
+
     num_expert_tokens = torch.randint(low=0,
                                       high=max_tokens_per_expert,
                                       size=(num_experts, ),
@@ -153,7 +155,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         num_experts,
         N // 2,
         K,
-        quant_dtype=dtype,
+        in_dtype=act_dtype,
+        quant_dtype=quant_dtype,
         block_shape=block_shape,
     )
 
@@ -168,6 +171,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         torch.float32: tl.float32
     }[test_output.dtype]
 
+    assert A_q.dtype == B_q.dtype
+
     invoke_moe_batched_triton_kernel(
         A_q,
         B_q,
@@ -185,7 +190,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         config={
             "BLOCK_SIZE_M": 16,
             "BLOCK_SIZE_N": 16,
-            "BLOCK_SIZE_K": 16
+            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
         },
         block_shape=block_shape,
     )
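Note: the hunk above keys BLOCK_SIZE_K off the element size: 16 for two-byte-and-wider dtypes, 32 for one-byte (fp8) dtypes. My assumption for the motivation (not stated in the commit) is that 8-bit matmuls in Triton want a deeper K block; a standalone sketch of the selection, with an illustrative helper name:

import torch

def pick_block_size_k(dtype: torch.dtype) -> int:
    # One-byte dtypes (e.g. torch.float8_e4m3fn) get a deeper K block;
    # wider dtypes keep the original BLOCK_SIZE_K of 16.
    return 16 if dtype.itemsize > 1 else 32

assert pick_block_size_k(torch.bfloat16) == 16
assert pick_block_size_k(torch.float8_e4m3fn) == 32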
@@ -209,7 +214,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]
 
-    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
 
 
@@ -234,7 +239,6 @@ def test_fused_moe_batched_experts(
    current_platform.seed_everything(7)
 
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
-    quant_type = torch.float8_e4m3fn if use_fp8_w8a8 else None
 
    if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
        pytest.skip("Skip quantization test for non-quantized type")
@@ -244,20 +248,30 @@ def test_fused_moe_batched_experts(
 
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e, n, k, block_shape=block_shape, quant_dtype=dtype)
+
+    if dtype.itemsize == 1:
+        act_dtype = torch.bfloat16
+        quant_dtype = dtype
+    else:
+        act_dtype = dtype
+        quant_dtype = None
+
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e, n, k, block_shape=block_shape,
+                                                 in_dtype=act_dtype,
+                                                 quant_dtype=quant_dtype)
 
     torch.set_printoptions(profile="full")
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant,
+                                     w2_s, quant_dtype, per_act_token_quant,
                                      block_shape)
         baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant,
+                                     w2_s, quant_dtype, per_act_token_quant,
                                      block_shape)
         triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                   w2_s, quant_type, per_act_token_quant,
+                                   w2_s, quant_dtype, per_act_token_quant,
                                    block_shape)
 
         torch.testing.assert_close(triton_output,
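Note: both tests in this file now split the parametrized dtype into an activation dtype and an optional quantization dtype, keying off itemsize. A minimal standalone sketch of that dispatch (the helper name is illustrative, not part of the diff):

import torch

def pick_dtypes(dtype: torch.dtype) -> tuple:
    if dtype.itemsize == 1:
        # One-byte dtypes (fp8 variants) are quantization targets;
        # activations are kept in bfloat16.
        return torch.bfloat16, dtype
    # Wider dtypes are used directly, with no quantization.
    return dtype, None

assert pick_dtypes(torch.float8_e4m3fn) == (torch.bfloat16, torch.float8_e4m3fn)
assert pick_dtypes(torch.bfloat16) == (torch.bfloat16, None)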

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 55 additions & 6 deletions
@@ -446,26 +446,73 @@ def prepare(
 
         num_local_experts = num_experts // self.world_size
 
-        assert quant_config.quant_dtype is None, "NYI"
-
         b_type = a1.dtype if quant_config.quant_dtype is None else quant_config.quant_dtype
 
         b_a1 = torch.zeros(
             (num_local_experts, self.max_num_tokens, hidden_dim),
             dtype=b_type,
             device=a1.device)
 
+        if quant_config.quant_dtype is not None:
+            if quant_config.block_shape is not None:
+                _, block_k = quant_config.block_shape
+                k_tiles = (hidden_dim + block_k - 1) // block_k
+                scale_shape = (num_local_experts, self.max_num_tokens, k_tiles)
+            else:
+                if quant_config.per_act_token_quant:
+                    num = self.max_num_tokens
+                else:
+                    num = 1
+                scale_shape = (num_local_experts, num, 1)
+
+            #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}")
+
+            b_a1_scale = torch.zeros(scale_shape,
+                                     dtype=torch.float32,
+                                     device=a1.device)
+        else:
+            assert a1_scale is None
+            b_a1_scale = None
+
         first_expert = num_local_experts * self.rank
         last_expert = first_expert + num_local_experts
 
         for expert_id in range(first_expert, last_expert):
             topks = torch.any(topk_ids == expert_id, dim=1).flatten()
             rows = torch.count_nonzero(topks.flatten())
-            b_a1[expert_id -
-                 first_expert, :rows, :] = a1[:topks.numel()][topks]
-            tokens_per_expert[expert_id - first_expert] = rows
+            rhs = a1[:topks.numel()][topks]
+            idx = expert_id - first_expert
+            if quant_config.quant_dtype is not None:
+                if a1_scale is not None:
+                    assert False, "NYI"
+                    rhs_a1_scale = a1_scale[:topks.numel()][topks]
+                else:
+                    rhs_a1_scale = None
+                b_a1[idx, :rows, :], b_s = moe_kernel_quantize_input(
+                    rhs,
+                    rhs_a1_scale,
+                    quant_config.quant_dtype,
+                    quant_config.per_act_token_quant,
+                    quant_config.block_shape,
+                )
+                assert b_s is not None
+                if (quant_config.block_shape is None
+                        and not quant_config.per_act_token_quant):
+                    print(f"SCALE {idx}, {b_a1_scale[idx, :].shape} {b_s.shape}")
+                    b_a1_scale[idx, :] = b_s
+                else:
+                    #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}")
+                    assert rows == b_s.shape[0] and b_a1_scale.shape[
+                        -1] == b_s.shape[-1]
+                    b_a1_scale[idx, :rows] = b_s
+            else:
+                b_a1[idx, :rows, :] = rhs
 
-        return b_a1, a1_scale, tokens_per_expert, None, None
+            tokens_per_expert[idx] = rows
+
+        assert b_a1_scale is None or b_a1_scale.ndim == 3
+
+        return b_a1, b_a1_scale, tokens_per_expert, None, None
 
     def finalize(
         self,
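Note: the new scale buffer in prepare takes one of three shapes depending on the quantization scheme: blockwise quantization stores one scale per K-tile per token, per-token quantization one scale per token row, and per-tensor quantization a single scale per expert. A minimal sketch of that shape computation, mirroring the diff's logic (ceil_div is an illustrative helper):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def batched_scale_shape(num_local_experts: int, max_num_tokens: int,
                        hidden_dim: int, block_shape, per_act_token_quant: bool):
    if block_shape is not None:
        # Blockwise: one scale per (token, K-tile) pair.
        _, block_k = block_shape
        return (num_local_experts, max_num_tokens, ceil_div(hidden_dim, block_k))
    if per_act_token_quant:
        # Per-token: one scale per token row.
        return (num_local_experts, max_num_tokens, 1)
    # Per-tensor: a single scale per expert batch.
    return (num_local_experts, 1, 1)

assert batched_scale_shape(4, 64, 1024, [128, 128], False) == (4, 64, 8)
assert batched_scale_shape(4, 64, 1024, None, True) == (4, 64, 1)
assert batched_scale_shape(4, 64, 1024, None, False) == (4, 1, 1)

Keeping the expert dimension first means each expert's tokens and scales live in one contiguous slice, matching the layout of b_a1 itself.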
@@ -770,6 +817,8 @@ def apply(
             config=config,
             block_shape=self.block_shape)
 
+        intermediate_cache2.fill_(0)
+
         # TODO: would be nice to use expert_num_tokens here to reduce
         # garbage compute
         self.activation(activation, intermediate_cache2.view(-1, N // 2),
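Note: the lone change to apply zeroes intermediate_cache2 in place before the activation runs. A plausible motivation (an assumption, not stated in the commit) is that rows past each expert's real token count hold uninitialized data, and zeroing keeps that padding from surfacing as NaN/Inf in later full-buffer operations, the same "garbage compute" the adjacent TODO mentions. A toy demonstration of the in-place zeroing pattern:

import torch

# Padded per-expert buffer: only `rows` of `max_tokens` rows are meaningful.
max_tokens, hidden = 8, 4
rows = 3
cache = torch.empty((max_tokens, hidden))  # uninitialized: may hold garbage

cache.fill_(0)                             # in-place zero of the whole buffer
cache[:rows] = torch.randn(rows, hidden)   # only valid rows get real data

# A full-buffer reduction now sees zeros, not garbage, in the padding.
assert torch.isfinite(cache).all()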
