1 file changed: +7 −2 lines

vllm/model_executor/layers/fused_moe

@@ -1320,8 +1320,13 @@ def maybe_all_reduce_tensor_model_parallel(
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
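
For context, below is a minimal, self-contained sketch of the dispatch pattern this diff introduces: route through a registered custom op by default, but fall back to the plain Python implementation on a platform where the op causes problems. It assumes PyTorch 2.4+'s torch.library.custom_op API; the op name demo::moe_forward, the placeholder bodies, and is_tpu() are illustrative stand-ins, not vLLM's actual registration code (the real op also threads self.layer_name through, as shown in the diff).

import torch
from torch.library import custom_op


@custom_op("demo::moe_forward", mutates_args=())
def demo_moe_forward(hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real vLLM op resolves the layer by name and
    # calls its forward_impl. Cloning avoids returning an aliased input,
    # which custom ops disallow.
    return hidden_states.clone()


@demo_moe_forward.register_fake
def _(hidden_states: torch.Tensor,
      router_logits: torch.Tensor) -> torch.Tensor:
    # Fake (meta) kernel so the op can be traced by torch.compile.
    return torch.empty_like(hidden_states)


def is_tpu() -> bool:
    # Hypothetical stand-in for vLLM's current_platform.is_tpu() check.
    return False


class DemoMoELayer(torch.nn.Module):
    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        # Placeholder for the fused-MoE computation.
        return hidden_states.clone()

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        if is_tpu():
            # Bypass the custom-op indirection on the affected backend,
            # mirroring the TPU branch added in this diff.
            return self.forward_impl(hidden_states, router_logits)
        return torch.ops.demo.moe_forward(hidden_states, router_logits)


layer = DemoMoELayer()
out = layer(torch.randn(4, 8), torch.randn(4, 2))

Routing through the custom op keeps the MoE forward opaque to torch.compile's tracing; the TPU branch sidesteps that indirection until the OOM issue mentioned in the TODO is resolved.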