Commit 03b41b6 (1 parent: cad6447)

fix merge

Signed-off-by: Bill Nell <bnell@redhat.com>

2 files changed (+6, -3 lines)


tests/kernels/moe/test_pplx_moe.py

Lines changed: 5 additions & 1 deletion
@@ -63,7 +63,6 @@
     reason="Requires PPLX kernels",
 )
 
-
 @dataclasses.dataclass
 class ProcessGroupInfo:
     world_size: int
@@ -74,6 +73,11 @@ class ProcessGroupInfo:
     device: torch.device
 
 
+@pytest.fixture(scope="function", autouse=True)
+def use_pplx_backend(monkeypatch):
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")
+
+
 def _worker_parallel_launch(
     local_rank: int,
     world_size: int,
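For context: the new autouse fixture pins every test in this module to the PPLX all2all backend, so individual tests no longer need to set VLLM_ALL2ALL_BACKEND themselves. Below is a minimal standalone sketch of the same pattern; the test body is hypothetical and not part of this commit.

# Standalone sketch of the autouse-fixture pattern used above. The test
# function is hypothetical; monkeypatch automatically restores the
# environment variable after each test finishes.
import os

import pytest


@pytest.fixture(scope="function", autouse=True)
def use_pplx_backend(monkeypatch):
    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")


def test_backend_is_forced_to_pplx():
    # Runs after the fixture, so the variable is already set.
    assert os.environ["VLLM_ALL2ALL_BACKEND"] == "pplx"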

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 1 addition & 2 deletions
@@ -429,8 +429,6 @@ def prepare(
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))
 
-        _, block_k = self.block_shape
-
         num_tokens, hidden_dim = a1.size()
         topk = topk_ids.size(1)
 
@@ -453,6 +451,7 @@ def prepare(
                                  device=a1.device)
 
         if self.qtype is not None:
+            _, block_k = self.block_shape
             k_tiles = (hidden_dim + block_k - 1) // block_k
             b_a1_scale = torch.zeros(
                 (num_local_experts, self.max_num_tokens, k_tiles),
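Why the move matters: self.block_shape is only meaningful when block quantization is active; on the unquantized path it is None, so the unconditional unpack would raise a TypeError before any quantization check ran. Relocating it under the qtype guard defers the unpack to the only branch that uses block_k. A minimal sketch of the failure mode follows, using hypothetical standalone names rather than the vLLM class.

# Hypothetical sketch of the bug this hunk fixes: block_shape is None on
# the unquantized path, so unpacking it unconditionally raises
# "TypeError: cannot unpack non-iterable NoneType object".
from typing import Optional


def num_k_tiles(block_shape: Optional[list[int]], qtype: Optional[str],
                hidden_dim: int) -> Optional[int]:
    if qtype is not None:
        # Safe here: this sketch assumes a qtype implies a block shape.
        _, block_k = block_shape
        return (hidden_dim + block_k - 1) // block_k  # ceiling division
    return None  # unquantized path never touches block_shape


assert num_k_tiles([128, 128], "fp8_w8a8", 4096) == 32
assert num_k_tiles(None, None, 4096) is None  # no TypeError after the fix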
