1 file changed, +10 -9 lines changed

@@ -1102,17 +1102,18 @@ def _process_reqs(
 
         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  padded_batch_size)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+        else:
+            # If torchair graph is not enabled, or if with_prefill is True, the
+            # dummy run batch size is set to 1.
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
 
         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
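
For context, the new branch picks the token count to pad (the maximum across DP ranks when dp_size > 1, otherwise the local num_input_tokens), pads it, and writes the padded size into the -1 placeholder slots of num_tokens_across_dp. Below is a minimal, self-contained sketch of that behavior; select_torchair_padded_batch_size here is a hypothetical stand-in that rounds up to the next power of two, not the actual vllm-ascend helper.

import torch


def select_torchair_padded_batch_size(num_tokens: int) -> int:
    # Hypothetical stand-in for the real helper: round up to the next
    # power of two so every DP rank runs the same padded graph size.
    padded = 1
    while padded < num_tokens:
        padded *= 2
    return padded


# -1 marks DP ranks whose token count is still a placeholder.
num_tokens_across_dp = torch.tensor([7, -1, 5, -1])

dp_size = 4
max_num_tokens = 7      # max token count across DP ranks
num_input_tokens = 5    # local token count, used when dp_size == 1

# Mirrors the new branch: choose the value to pad, pad it, then fill
# the placeholder slots in place with the padded batch size.
max_num_tokens = (max_num_tokens
                  if dp_size > 1 else num_input_tokens)
padded_batch_size = select_torchair_padded_batch_size(max_num_tokens)
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                  padded_batch_size)

print(padded_batch_size)     # 8
print(num_tokens_across_dp)  # tensor([7, 8, 5, 8])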