@@ -626,10 +626,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
626
626
if batch_changed :
627
627
self .input_batch .refresh_sampling_metadata ()
628
628
629
- def _get_forward_metadata_across_dp (self , num_tokens : int ,
630
- with_prefill : bool ) -> tuple [int , bool ]:
629
+ def _get_forward_metadata_across_dp (
630
+ self , num_tokens : int ,
631
+ with_prefill : bool ) -> tuple [torch .Tensor , bool ]:
631
632
local_forward_metadata = torch .tensor ([num_tokens , with_prefill ],
632
- device = "npu" , dtype = torch .int32 )
633
+ device = "npu" ,
634
+ dtype = torch .int32 )
633
635
global_forward_metadata = get_dp_group ().all_gather (
634
636
local_forward_metadata )
635
637
num_tokens_across_dp = global_forward_metadata [:, 0 ].cpu ()
@@ -1874,10 +1876,11 @@ def _dummy_run(
1874
1876
for k , v in self .intermediate_tensors .items ()
1875
1877
})
1876
1878
1877
- with set_forward_context (None ,
1878
- self .vllm_config ,
1879
- num_tokens = num_tokens ,
1880
- num_tokens_across_dp = num_tokens_across_dp ):
1879
+ with set_forward_context (
1880
+ None ,
1881
+ self .vllm_config ,
1882
+ num_tokens = num_tokens ,
1883
+ num_tokens_across_dp = num_tokens_across_dp ):
1881
1884
if self .torchair_graph_enabled and not with_prefill :
1882
1885
attn_metadata = self .attn_metadata_builder .build_dummy (
1883
1886
num_reqs = num_tokens , num_actual_tokens = 1 )
0 commit comments