@@ -620,10 +620,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
620
620
if batch_changed :
621
621
self .input_batch .refresh_sampling_metadata ()
622
622
623
- def _get_forward_metadata_across_dp (self , num_tokens : int ,
624
- with_prefill : bool ) -> tuple [int , bool ]:
623
+ def _get_forward_metadata_across_dp (
624
+ self , num_tokens : int ,
625
+ with_prefill : bool ) -> tuple [torch .Tensor , bool ]:
625
626
local_forward_metadata = torch .tensor ([num_tokens , with_prefill ],
626
- device = "npu" , dtype = torch .int32 )
627
+ device = "npu" ,
628
+ dtype = torch .int32 )
627
629
global_forward_metadata = get_dp_group ().all_gather (
628
630
local_forward_metadata )
629
631
num_tokens_across_dp = global_forward_metadata [:, 0 ].cpu ()
@@ -1868,10 +1870,11 @@ def _dummy_run(
1868
1870
for k , v in self .intermediate_tensors .items ()
1869
1871
})
1870
1872
1871
- with set_forward_context (None ,
1872
- self .vllm_config ,
1873
- num_tokens = num_tokens ,
1874
- num_tokens_across_dp = num_tokens_across_dp ):
1873
+ with set_forward_context (
1874
+ None ,
1875
+ self .vllm_config ,
1876
+ num_tokens = num_tokens ,
1877
+ num_tokens_across_dp = num_tokens_across_dp ):
1875
1878
if self .torchair_graph_enabled and not with_prefill :
1876
1879
attn_metadata = self .attn_metadata_builder .build_dummy (
1877
1880
num_reqs = num_tokens , num_actual_tokens = 1 )
0 commit comments