Commit e59258a

disable buggy fp8 tests

Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 58a5c18

File tree

4 files changed: +11 additions, -82 deletions


tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 6 deletions

@@ -67,8 +67,6 @@ def make_tensors(config: BatchedMMConfig):
                                    device="cuda",
                                    dtype=torch.int32)
 
-
-
     return BatchedMMTensors(A, B, C, num_expert_tokens)
 
 
@@ -111,9 +109,7 @@ def ref_impl(
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize(
-    "dtype",
-    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("block_shape", [None])
 @pytest.mark.parametrize("per_act_token_quant", [False])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
@@ -223,7 +219,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 @pytest.mark.parametrize("k", [128, 512, 1024, 2048])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("per_act_token_quant", [False])
 @pytest.mark.parametrize("block_shape", [None])
 def test_fused_moe_batched_experts(
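
Both test parametrizations simply drop torch.float8_e4m3fn, matching the commit message "disable buggy fp8 tests". A minimal sketch of an alternative (not what this commit does) that keeps the fp8 case visible in the test report but marked as skipped, using pytest's standard pytest.param/skip machinery:

import pytest
import torch

# Hypothetical parametrization: fp8 stays listed but is skipped with a reason,
# so it shows up as "skipped" instead of silently disappearing from the matrix.
DTYPES = [
    pytest.param(torch.float8_e4m3fn,
                 marks=pytest.mark.skip(reason="batched fp8 path is buggy")),
    torch.float32,
    torch.float16,
    torch.bfloat16,
]


@pytest.mark.parametrize("dtype", DTYPES)
def test_batched_mm_dtypes(dtype: torch.dtype):
    # Placeholder body; the real test builds batched MM inputs for this dtype.
    assert dtype in (torch.float8_e4m3fn, torch.float32, torch.float16,
                     torch.bfloat16)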

tests/kernels/utils.py

Lines changed: 2 additions & 1 deletion

@@ -1063,7 +1063,8 @@ def torch_experts(a: torch.Tensor,
                   expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
     assert (global_num_experts == -1
             or (global_num_experts == w1.shape[0] and expert_map is None)
-            or global_num_experts == expert_map.shape[0])
+            or (expert_map is not None
+                and global_num_experts == expert_map.shape[0]))
     topk = topk_ids.shape[1]
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
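
The added `expert_map is not None` guard matters because the third disjunct dereferences `expert_map.shape`; without it, calling with `expert_map=None` and a mismatched `global_num_experts` raises AttributeError instead of a clean assertion failure. A small standalone sketch (hypothetical helper and tensors, not the test utility itself) of the corrected check:

from typing import Optional

import torch


def check_expert_counts(global_num_experts: int, w1: torch.Tensor,
                        expert_map: Optional[torch.Tensor]) -> None:
    # Short-circuit evaluation: expert_map.shape is only touched once
    # expert_map is known to be a tensor.
    assert (global_num_experts == -1
            or (global_num_experts == w1.shape[0] and expert_map is None)
            or (expert_map is not None
                and global_num_experts == expert_map.shape[0]))


w1 = torch.empty(8, 16, 32)                     # 8 local experts
check_expert_counts(-1, w1, None)               # ok: "use all experts"
check_expert_counts(8, w1, None)                # ok: matches w1.shape[0]
check_expert_counts(16, w1, torch.arange(16))   # ok: matches expert_map length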

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 5 additions & 48 deletions

@@ -318,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
         expert_num_tokens: torch.Tensor,  # [E]
         compute_type: tl.dtype,
         # Quantization data
-        A_scale: torch.Tensor,  # Optional
-        B_scale: torch.Tensor,  # Optional
+        A_scale: Optional[torch.Tensor],
+        B_scale: Optional[torch.Tensor],
         B_zp: torch.Tensor,
         # Quantization schemes
         use_fp8_w8a8: bool,
@@ -453,61 +453,18 @@ def prepare(
                            dtype=b_type,
                            device=a1.device)
 
-        if quant_config.quant_dtype is not None:
-            if quant_config.block_shape is not None:
-                _, block_k = quant_config.block_shape
-                k_tiles = (hidden_dim + block_k - 1) // block_k
-                scale_shape = (num_local_experts, self.max_num_tokens, k_tiles)
-            else:
-                if quant_config.per_act_token_quant:
-                    num = self.max_num_tokens
-                else:
-                    num = 1
-                scale_shape = (num_local_experts, num, 1)
+        b_a1_scale = None
 
-            #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}")
-
-            b_a1_scale = torch.zeros(scale_shape,
-                                     dtype=torch.float32,
-                                     device=a1.device)
-        else:
-            assert a1_scale is None
-            b_a1_scale = None
+        assert quant_config.quant_dtype is None, "quantization NYI"
 
         first_expert = num_local_experts * self.rank
         last_expert = first_expert + num_local_experts
 
         for expert_id in range(first_expert, last_expert):
             topks = torch.any(topk_ids == expert_id, dim=1).flatten()
             rows = torch.count_nonzero(topks.flatten())
-            rhs = a1[:topks.numel()][topks]
             idx = expert_id - first_expert
-            if quant_config.quant_dtype is not None:
-                if a1_scale is not None:
-                    assert False, "NYI"
-                    rhs_a1_scale = a1_scale[:topks.numel()][topks]
-                else:
-                    rhs_a1_scale = None
-                b_a1[idx, :rows, :], b_s = moe_kernel_quantize_input(
-                    rhs,
-                    rhs_a1_scale,
-                    quant_config.quant_dtype,
-                    quant_config.per_act_token_quant,
-                    quant_config.block_shape,
-                )
-                assert b_s is not None
-                if (quant_config.block_shape is None
-                        and not quant_config.per_act_token_quant):
-                    print(f"SCALE {idx}, {b_a1_scale[idx, :].shape} {b_s.shape}")
-                    b_a1_scale[idx, :] = b_s
-                else:
-                    #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}")
-                    assert rows == b_s.shape[0] and b_a1_scale.shape[
-                        -1] == b_s.shape[-1]
-                    b_a1_scale[idx, :rows] = b_s
-            else:
-                b_a1[idx, :rows, :] = rhs
-
+            b_a1[idx, :rows, :] = a1[:topks.numel()][topks]
             tokens_per_expert[idx] = rows
 
         assert b_a1_scale is None or b_a1_scale.ndim == 3
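
With the quantization path removed, prepare() is left with a plain per-expert scatter of activations into the batched buffer. A self-contained sketch of that scatter with dummy shapes and names (the real method also handles ranks, expert maps, and activation scales):

import torch

num_tokens, hidden_dim = 16, 32
num_local_experts, max_tokens_per_expert, topk = 4, num_tokens, 2

a1 = torch.randn(num_tokens, hidden_dim)
topk_ids = torch.randint(0, num_local_experts, (num_tokens, topk))

b_a1 = torch.zeros(num_local_experts, max_tokens_per_expert, hidden_dim)
tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32)

for expert_id in range(num_local_experts):
    # Tokens that routed to this expert in any of their top-k slots.
    topks = torch.any(topk_ids == expert_id, dim=1)
    rows = int(torch.count_nonzero(topks))
    # Pack those tokens at the front of this expert's slice of the batch.
    b_a1[expert_id, :rows, :] = a1[topks]
    tokens_per_expert[expert_id] = rows

print(tokens_per_expert)  # how many rows of each expert slice hold valid tokens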

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 27 deletions

@@ -843,14 +843,14 @@ def try_get_optimal_moe_config_list(
     config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                 is_marlin, block_shape)
 
-    return [
+    return (
         config['BLOCK_SIZE_M'],
         config['BLOCK_SIZE_N'],
         config['BLOCK_SIZE_K'],
         config['GROUP_SIZE_M'],
         config.get('num_warps', 4),
         config.get('num_stages', 3 if not current_platform.is_rocm() else 2),
-    ]
+    )
 
 
 direct_register_custom_op(
@@ -1213,31 +1213,6 @@ def fused_experts(hidden_states: torch.Tensor,
             a2_scale=a2_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
-    elif True:
-        fn = modular_triton_fused_moe(use_fp8_w8a8=use_fp8_w8a8,
-                                      use_int8_w8a8=use_int8_w8a8,
-                                      use_int8_w8a16=use_int8_w8a16,
-                                      use_int4_w4a16=use_int4_w4a16,
-                                      per_channel_quant=per_channel_quant,
-                                      block_shape=block_shape)
-
-        return fn(
-            hidden_states=hidden_states,
-            w1=w1,
-            w2=w2,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            w1_zp=w1_zp,
-            w2_zp=w2_zp,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-        )
     else:
         return dispatch_fused_experts_func(inplace)(
             hidden_states=hidden_states,
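
Two cleanups ride along in this file: the kernel config is now returned as a tuple rather than a list, and a debugging `elif True:` arm is deleted. The second change matters for control flow, since an always-true `elif` makes the trailing `else` unreachable. A minimal sketch of that pitfall (hypothetical function, not the vLLM code):

def pick_kernel(use_batched: bool) -> str:
    if use_batched:
        return "batched"
    elif True:
        # This arm matches whenever the first one does not...
        return "modular_triton"
    else:
        # ...so the intended default dispatch path below can never run.
        return "dispatch_default"


assert pick_kernel(False) == "modular_triton"  # never "dispatch_default"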
