From 40393efc2b0d4904259ccecaf69be2216f05b675 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 22 May 2025 01:15:34 +0000
Subject: [PATCH] Add reorder_batch to TPU V1

Signed-off-by: mgoin
---
 vllm/v1/worker/tpu_model_runner.py | 60 +++++++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index b4daf5a3467..7a5434d5b29 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -291,6 +291,57 @@ def _verify_num_xla_graphs(self, case_str):
             " num_xla_graphs = {} curr_cached_graph = {}".format(
                 case_str, self.num_xla_graphs, curr_cached_graph))
 
+    def reorder_batch(self, input_batch: "InputBatch",
+                      scheduler_output: "SchedulerOutput") -> bool:
+        # We now want to reorder the batch so that the "decode" requests are
+        # at the front and the "prefill" requests are at the back, using the
+        # least amount of swaps possible. (NOTE: for now we loosely use
+        # "decode" to mean requests where attention is likely memory-bound
+        # and "prefill" to mean requests where attention is likely
+        # compute-bound, TODO(lucas): figure out a better naming here)
+        decodes = []
+        prefills = []
+        num_decode_tokens = 0
+        num_prefill_tokens = 0
+
+        for i, req_id in enumerate(input_batch.req_ids):
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            # For now, treat 1 scheduled token as "decode" even if it's not;
+            # we should update this to something like < 8 in the future, but
+            # currently the decode run only supports num_tokens = 1.
+            if num_tokens == 1:
+                decodes.append(i)
+                num_decode_tokens += num_tokens
+            else:
+                prefills.append(i)
+                num_prefill_tokens += num_tokens
+
+        # We hope that the number of swaps is fairly minimal since decodes
+        # should be around for a number of iterations, so hopefully they are
+        # relatively stationary (and new requests are generally appended to
+        # the persistent batch, so they should already be at the back).
+        # To achieve this we loop over the decodes in descending order and
+        # the prefills in ascending order. We swap decodes from the "back",
+        # i.e. past where the last decode should be in the reordered batch,
+        # with prefills from the front of the batch.
+        # `decodes` and `prefills` are already in ascending order just based
+        # on the above loop.
+        num_decodes = len(decodes)
+        num_prefills = len(prefills)
+        modified_batch = False
+
+        for i in range(1, min(num_decodes, num_prefills) + 1):
+            # If the i-th decode from the back already sits within the first
+            # num_decodes slots, stop; otherwise swap it with a front prefill.
+            decode_idx = decodes[num_decodes - i]
+            if decode_idx < num_decodes:
+                break
+
+            input_batch.swap_states(prefills[i - 1], decode_idx)
+            modified_batch = True
+
+        return modified_batch
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the
         scheduler output.
@@ -411,7 +462,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+        batch_changed = len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+
+        # We want to reorder the batch so that the decode requests are at
+        # the front and the prefill requests are at the back.
+        batch_reordered = self.reorder_batch(self.input_batch,
+                                             scheduler_output)
+
+        return batch_changed or batch_reordered
 
     def get_model(self) -> nn.Module:
         assert self.model is not None
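
For reviewers, below is a small, self-contained sketch of the swap strategy the new reorder_batch() uses, run against a toy batch. The MiniBatch class and reorder() helper are hypothetical stand-ins invented for this note (they are not vLLM APIs); only the swap_states() name mirrors the InputBatch method called in the patch, and the real method swaps far more per-request state.

class MiniBatch:
    """Hypothetical stand-in for vLLM's InputBatch, tracking only req_ids."""

    def __init__(self, req_ids):
        self.req_ids = list(req_ids)

    def swap_states(self, i, j):
        # The real InputBatch.swap_states also swaps token ids, block tables,
        # sampling params, etc.; swapping req_ids is enough to visualize it.
        self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]


def reorder(batch, num_scheduled_tokens):
    """Toy version of the reordering logic in the patch above."""
    decodes, prefills = [], []
    for i, req_id in enumerate(batch.req_ids):
        # 1 scheduled token -> "decode", anything else -> "prefill".
        (decodes if num_scheduled_tokens[req_id] == 1 else prefills).append(i)

    num_decodes = len(decodes)
    modified = False
    # Walk decodes from the back and prefills from the front. A decode that
    # already sits within the first num_decodes slots needs no swap, and
    # since both index lists are ascending we can stop at the first such one.
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break
        batch.swap_states(prefills[i - 1], decode_idx)
        modified = True
    return modified


if __name__ == "__main__":
    # d* requests have 1 scheduled token (decode), p* have more (prefill).
    batch = MiniBatch(["p0", "d0", "p1", "d1", "d2"])
    tokens = {"p0": 17, "d0": 1, "p1": 9, "d1": 1, "d2": 1}
    print(reorder(batch, tokens), batch.req_ids)
    # Prints: True ['d2', 'd0', 'd1', 'p1', 'p0']
    # The three decodes end up in front using only two swaps (0<->4, 2<->3).

This also shows why the worst case is min(num_decodes, num_prefills) swaps and why the early break is safe: once one decode from the back is found inside the decode region, all remaining decodes are as well.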