
Commit 4548c03

[TPU][Bugfix] fix the MoE OOM issue (vllm-project#20339)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Parent: 40b86aa

File tree

1 file changed (+7 lines, −2 lines)
  • vllm/model_executor/layers/fused_moe

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 2 deletions
@@ -1320,8 +1320,13 @@ def maybe_all_reduce_tensor_model_parallel(
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
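
For readers skimming the diff, below is a minimal, self-contained sketch of the dispatch pattern the change introduces: forward() checks the platform and calls forward_impl directly on TPU, while other backends keep routing through the registered custom op. The names is_tpu, moe_forward_op, and FusedMoELayer are hypothetical stand-ins for vLLM's current_platform.is_tpu(), torch.ops.vllm.moe_forward, and the FusedMoE layer; this is not the project's actual code.

import torch


def is_tpu() -> bool:
    # Stand-in for current_platform.is_tpu(); hardcoded for the sketch.
    return False


def moe_forward_op(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                   layer_name: str) -> torch.Tensor:
    # Stand-in for the torch.ops.vllm.moe_forward custom op.
    return hidden_states


class FusedMoELayer(torch.nn.Module):

    def __init__(self, layer_name: str):
        super().__init__()
        self.layer_name = layer_name

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        # The real layer does expert routing and fused MoE compute here;
        # the sketch just passes the activations through.
        return hidden_states

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        # TPU: call the implementation directly, skipping the custom-op
        # wrapper. Other backends: unchanged custom-op path.
        if is_tpu():
            return self.forward_impl(hidden_states, router_logits)
        return moe_forward_op(hidden_states, router_logits, self.layer_name)


if __name__ == "__main__":
    layer = FusedMoELayer("model.layers.0.mlp")
    x = torch.randn(4, 16)
    logits = torch.randn(4, 8)
    print(layer(x, logits).shape)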
