Skip to content

Commit 7a95679

Browse files
committed
disable buggy fp8 tests
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 1bae03b commit 7a95679

File tree

2 files changed

+7
-54
lines changed

2 files changed

+7
-54
lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -67,8 +67,6 @@ def make_tensors(config: BatchedMMConfig):
6767
device="cuda",
6868
dtype=torch.int32)
6969

70-
71-
7270
return BatchedMMTensors(A, B, C, num_expert_tokens)
7371

7472

@@ -111,9 +109,7 @@ def ref_impl(
111109
[32, 64, 128, 192, 224, 256, 512])
112110
@pytest.mark.parametrize("K", [128, 256, 1024])
113111
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
114-
@pytest.mark.parametrize(
115-
"dtype",
116-
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
112+
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
117113
@pytest.mark.parametrize("block_shape", [None])
118114
@pytest.mark.parametrize("per_act_token_quant", [False])
119115
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
@@ -223,7 +219,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
223219
@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
224220
@pytest.mark.parametrize("e", NUM_EXPERTS)
225221
@pytest.mark.parametrize("topk", TOP_KS)
226-
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
222+
@pytest.mark.parametrize("dtype", [torch.bfloat16])
227223
@pytest.mark.parametrize("per_act_token_quant", [False])
228224
@pytest.mark.parametrize("block_shape", [None])
229225
def test_fused_moe_batched_experts(

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 5 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -318,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
318318
expert_num_tokens: torch.Tensor, # [E]
319319
compute_type: tl.dtype,
320320
# Quantization data
321-
A_scale: torch.Tensor, # Optional
322-
B_scale: torch.Tensor, # Optional
321+
A_scale: Optional[torch.Tensor],
322+
B_scale: Optional[torch.Tensor],
323323
B_zp: torch.Tensor,
324324
# Quantization schemes
325325
use_fp8_w8a8: bool,
@@ -453,61 +453,18 @@ def prepare(
453453
dtype=b_type,
454454
device=a1.device)
455455

456-
if quant_config.quant_dtype is not None:
457-
if quant_config.block_shape is not None:
458-
_, block_k = quant_config.block_shape
459-
k_tiles = (hidden_dim + block_k - 1) // block_k
460-
scale_shape = (num_local_experts, self.max_num_tokens, k_tiles)
461-
else:
462-
if quant_config.per_act_token_quant:
463-
num = self.max_num_tokens
464-
else:
465-
num = 1
466-
scale_shape = (num_local_experts, num, 1)
456+
b_a1_scale = None
467457

468-
#print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}")
469-
470-
b_a1_scale = torch.zeros(scale_shape,
471-
dtype=torch.float32,
472-
device=a1.device)
473-
else:
474-
assert a1_scale is None
475-
b_a1_scale = None
458+
assert quant_config.quant_dtype is None, "quantization NYI"
476459

477460
first_expert = num_local_experts * self.rank
478461
last_expert = first_expert + num_local_experts
479462

480463
for expert_id in range(first_expert, last_expert):
481464
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
482465
rows = torch.count_nonzero(topks.flatten())
483-
rhs = a1[:topks.numel()][topks]
484466
idx = expert_id - first_expert
485-
if quant_config.quant_dtype is not None:
486-
if a1_scale is not None:
487-
assert False, "NYI"
488-
rhs_a1_scale = a1_scale[:topks.numel()][topks]
489-
else:
490-
rhs_a1_scale = None
491-
b_a1[idx, :rows, :], b_s = moe_kernel_quantize_input(
492-
rhs,
493-
rhs_a1_scale,
494-
quant_config.quant_dtype,
495-
quant_config.per_act_token_quant,
496-
quant_config.block_shape,
497-
)
498-
assert b_s is not None
499-
if (quant_config.block_shape is None
500-
and not quant_config.per_act_token_quant):
501-
print(f"SCALE {idx}, {b_a1_scale[idx, :].shape} {b_s.shape}")
502-
b_a1_scale[idx, :] = b_s
503-
else:
504-
#print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}")
505-
assert rows == b_s.shape[0] and b_a1_scale.shape[
506-
-1] == b_s.shape[-1]
507-
b_a1_scale[idx, :rows] = b_s
508-
else:
509-
b_a1[idx, :rows, :] = rhs
510-
467+
b_a1[idx, :rows, :] = a1[:topks.numel()][topks]
511468
tokens_per_expert[idx] = rows
512469

513470
assert b_a1_scale is None or b_a1_scale.ndim == 3

0 commit comments

Comments
 (0)