@@ -619,10 +619,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
619
619
if batch_changed :
620
620
self .input_batch .refresh_sampling_metadata ()
621
621
622
- def _get_forward_metadata_across_dp (self , num_tokens : int ,
623
- with_prefill : bool ) -> tuple [int , bool ]:
622
+ def _get_forward_metadata_across_dp (
623
+ self , num_tokens : int ,
624
+ with_prefill : bool ) -> tuple [torch .Tensor , bool ]:
624
625
local_forward_metadata = torch .tensor ([num_tokens , with_prefill ],
625
- device = "npu" , dtype = torch .int32 )
626
+ device = "npu" ,
627
+ dtype = torch .int32 )
626
628
global_forward_metadata = get_dp_group ().all_gather (
627
629
local_forward_metadata )
628
630
num_tokens_across_dp = global_forward_metadata [:, 0 ].cpu ()
@@ -1861,10 +1863,11 @@ def _dummy_run(
1861
1863
for k , v in self .intermediate_tensors .items ()
1862
1864
})
1863
1865
1864
- with set_forward_context (None ,
1865
- self .vllm_config ,
1866
- num_tokens = num_tokens ,
1867
- num_tokens_across_dp = num_tokens_across_dp ):
1866
+ with set_forward_context (
1867
+ None ,
1868
+ self .vllm_config ,
1869
+ num_tokens = num_tokens ,
1870
+ num_tokens_across_dp = num_tokens_across_dp ):
1868
1871
if self .torchair_graph_enabled and not with_prefill :
1869
1872
attn_metadata = self .attn_metadata_builder .build_dummy (
1870
1873
num_reqs = num_tokens , num_actual_tokens = 1 )
0 commit comments