
Commit 90772e8

review comments + fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 1b2a672

File tree (5 files changed: +14, -11 lines)

  tests/kernels/moe/test_batched_moe.py
  tests/kernels/moe/test_block_fp8.py
  tests/kernels/moe/test_block_int8.py
  vllm/model_executor/layers/fused_moe/cutlass_moe.py
  vllm/model_executor/layers/fused_moe/layer.py

tests/kernels/moe/test_batched_moe.py

Lines changed: 0 additions & 4 deletions
@@ -95,8 +95,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     act_dtype = dtype
     quant_dtype = None

-    #print(f"TYPES {dtype}, {act_dtype}, {quant_dtype}")
-
     num_expert_tokens = torch.randint(low=0,
                                       high=max_tokens_per_expert,
                                       size=(num_experts, ),
@@ -226,8 +224,6 @@ def test_fused_moe_batched_experts(
                               in_dtype=act_dtype,
                               quant_dtype=quant_dtype)

-    torch.set_printoptions(profile="full")
-
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,

tests/kernels/moe/test_block_fp8.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-# Adapted from https://github.com/sgl-project/sglang/pull/2575
 import itertools

 import pytest

tests/kernels/moe/test_block_int8.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
 import itertools

 import pytest

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 6 additions & 2 deletions
@@ -231,8 +231,12 @@ def __init__(
     def activation_formats(
             self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (mk.FusedMoEActivationFormat.Standard,
-                mk.FusedMoEActivationFormat.Standard)
+        if self.use_batched_format:
+            return (mk.FusedMoEActivationFormat.BatchedExperts,
+                    mk.FusedMoEActivationFormat.BatchedExperts)
+        else:
+            return (mk.FusedMoEActivationFormat.Standard,
+                    mk.FusedMoEActivationFormat.Standard)

     def supports_chunking(self) -> bool:
         return not self.use_batched_format
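The behavioural change in this hunk is that `activation_formats` now reports the batched layout when the experts object was constructed with `use_batched_format=True`, instead of always reporting `Standard`. Below is a self-contained sketch of that switch; `_ExpertsSketch` and the locally defined `FusedMoEActivationFormat` enum are simplified stand-ins for the real classes in `cutlass_moe.py` and `mk.FusedMoEActivationFormat`, not the actual implementation.

```python
# Stand-alone sketch only: _ExpertsSketch and this local enum are simplified
# stand-ins for the real cutlass_moe.py classes and mk.FusedMoEActivationFormat.
from enum import Enum


class FusedMoEActivationFormat(Enum):
    Standard = "standard"
    BatchedExperts = "batched_experts"


class _ExpertsSketch:

    def __init__(self, use_batched_format: bool):
        self.use_batched_format = use_batched_format

    @property
    def activation_formats(
        self
    ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
        # Mirrors the change above: batched experts advertise the batched
        # activation layout for both their input and output.
        if self.use_batched_format:
            return (FusedMoEActivationFormat.BatchedExperts,
                    FusedMoEActivationFormat.BatchedExperts)
        else:
            return (FusedMoEActivationFormat.Standard,
                    FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        # Chunking stays limited to the standard (non-batched) layout.
        return not self.use_batched_format


# The batched flavor now reports BatchedExperts instead of Standard.
assert (_ExpertsSketch(True).activation_formats ==
        (FusedMoEActivationFormat.BatchedExperts,
         FusedMoEActivationFormat.BatchedExperts))
assert _ExpertsSketch(False).supports_chunking()
```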

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 8 additions & 3 deletions
@@ -607,7 +607,13 @@ def __init__(
         if params_dtype is None:
            params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
-        all2all_manager = get_ep_group().device_communicator.all2all_manager
+
+        if ep_size is not None:
+            all2all_manager = get_ep_group().device_communicator.all2all_manager
+            world_size = (all2all_manager.world_size
+                          if all2all_manager is not None else 1)
+        else:
+            world_size = 1

         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
@@ -616,8 +622,7 @@ def __init__(
                      get_tensor_model_parallel_world_size()),
            dp_size_=(dp_size if dp_size is not None else
                      get_dp_group().world_size),
-            world_size_=(all2all_manager.world_size
-                         if all2all_manager is not None else 1),
+            world_size_=world_size,
            vllm_parallel_config=vllm_config.parallel_config))

         self.global_num_experts = num_experts
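These hunks move the all2all-manager lookup behind the `ep_size is not None` check, so a construction without an explicit `ep_size` no longer touches `get_ep_group()` and simply uses a world size of 1. A minimal sketch of that fallback in isolation follows; `resolve_world_size` and `_FakeAll2All` are illustrative names introduced here, not part of vLLM.

```python
from typing import Optional


def resolve_world_size(ep_size: Optional[int], all2all_manager) -> int:
    """Sketch of the world-size fallback shown in the diff above.

    The all2all manager is consulted only when an explicit ep_size was
    passed; in every other case the layer falls back to a world size of 1.
    """
    if ep_size is not None:
        return (all2all_manager.world_size
                if all2all_manager is not None else 1)
    return 1


class _FakeAll2All:
    # Stand-in for device_communicator.all2all_manager in a 4-rank setup.
    world_size = 4


assert resolve_world_size(None, None) == 1          # no ep_size -> 1
assert resolve_world_size(2, None) == 1             # ep_size but no manager -> 1
assert resolve_world_size(2, _FakeAll2All()) == 4   # manager present -> its size
```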
