
Commit 5e5baa9

[Kernels] Use empty for modular MoE workspaces (#19667)
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 836d4ce commit 5e5baa9

File tree

2 files changed (+5 -2 lines)


vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -716,6 +716,9 @@ def apply(
         intermediate_cache2 = _resize_cache(workspace2,
                                             (E, max_num_tokens, N // 2))
 
+        if self.use_fp8_w8a8:
+            intermediate_cache1.fill_(0)
+
         # MM1
         invoke_moe_batched_triton_kernel(A=hidden_states,
                                          B=w1,
```
vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -426,10 +426,10 @@ def forward(
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = torch.zeros(prod(workspace13_shape),
+        workspace13 = torch.empty(prod(workspace13_shape),
                                   device=a1.device,
                                   dtype=workspace_dtype)
-        workspace2 = torch.zeros(prod(workspace2_shape),
+        workspace2 = torch.empty(prod(workspace2_shape),
                                  device=a1.device,
                                  dtype=workspace_dtype)
 
```
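torch.empty reserves memory without launching the fill kernel that torch.zeros issues, which is why the fp8 branch above has to re-add zeroing where uninitialized contents could otherwise leak through. A rough, self-contained sketch of the flat-workspace-plus-view pattern; view_as_cache below is an illustrative stand-in for, not a copy of, vLLM's _resize_cache:

```python
import torch
from math import prod

def view_as_cache(workspace: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Illustrative stand-in: take a prefix of the flat workspace and view it
    # with the requested shape. No copy and no initialization happen here.
    return workspace[:prod(shape)].view(*shape)

# Allocating with torch.empty skips the device-side fill that torch.zeros
# performs; the buffer contents are whatever the allocator hands back.
workspace13 = torch.empty(prod((8, 32, 64)), dtype=torch.bfloat16)

# cache1 and cache3 alias the same storage, which is safe as long as cache1
# is fully consumed before cache3 is written (the comment in forward() above).
cache1 = view_as_cache(workspace13, (8, 32, 64))
cache3 = view_as_cache(workspace13, (8, 32, 32))
assert cache1.data_ptr() == cache3.data_ptr()
```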