@@ -612,10 +612,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
612
612
if batch_changed :
613
613
self .input_batch .refresh_sampling_metadata ()
614
614
615
- def _get_forward_metadata_across_dp (self , num_tokens : int ,
616
- with_prefill : bool ) -> tuple [int , bool ]:
615
+ def _get_forward_metadata_across_dp (
616
+ self , num_tokens : int ,
617
+ with_prefill : bool ) -> tuple [torch .Tensor , bool ]:
617
618
local_forward_metadata = torch .tensor ([num_tokens , with_prefill ],
618
- device = "npu" , dtype = torch .int32 )
619
+ device = "npu" ,
620
+ dtype = torch .int32 )
619
621
global_forward_metadata = get_dp_group ().all_gather (
620
622
local_forward_metadata )
621
623
num_tokens_across_dp = global_forward_metadata [:, 0 ].cpu ()
@@ -1830,10 +1832,11 @@ def _dummy_run(
1830
1832
for k , v in self .intermediate_tensors .items ()
1831
1833
})
1832
1834
1833
- with set_forward_context (None ,
1834
- self .vllm_config ,
1835
- num_tokens = num_tokens ,
1836
- num_tokens_across_dp = num_tokens_across_dp ):
1835
+ with set_forward_context (
1836
+ None ,
1837
+ self .vllm_config ,
1838
+ num_tokens = num_tokens ,
1839
+ num_tokens_across_dp = num_tokens_across_dp ):
1837
1840
if self .torchair_graph_enabled and not with_prefill :
1838
1841
attn_metadata = self .attn_metadata_builder .build_dummy (
1839
1842
num_reqs = num_tokens , num_actual_tokens = 1 )
0 commit comments