|
26 | 26 | from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
27 | 27 | from vllm.pooling_params import PoolingParams
|
28 | 28 | from vllm.sampling_params import SamplingParams, SamplingType
|
| 29 | +from vllm.utils import swap_dict_values |
29 | 30 | from vllm.v1.outputs import LogprobsTensors
|
30 | 31 | from vllm.v1.sample.metadata import SamplingMetadata
|
31 | 32 | from vllm.v1.utils import copy_slice
|
@@ -423,6 +424,64 @@ def remove_request(self, req_id: str) -> Optional[int]:
|
423 | 424 | self.pooling_params.pop(req_id, None)
|
424 | 425 | return req_index
|
425 | 426 |
|
def swap_states(self, i1: int, i2: int) -> None:
    """Exchange all per-request state held at batch slots ``i1`` and ``i2``.

    Swaps the request ids, the id -> index mapping, the per-slot sampling
    parameters and token counters, the token-id rows, the index-keyed
    dictionaries (generators, min_tokens, bad-words), the LoRA mapping,
    the logit bias, the optional allowed-token-ids mask row and the
    block-table rows.

    Args:
        i1: Index of the first slot.
        i2: Index of the second slot.

    Raises:
        AssertionError: If either slot has no request id (``None``).
            The check runs before any state is modified.
    """
    req_id_1 = self._req_ids[i1]
    req_id_2 = self._req_ids[i2]
    # Validate up front so a failure cannot leave half-swapped state.
    assert req_id_1 is not None and req_id_2 is not None

    self.req_id_to_index[req_id_1], self.req_id_to_index[req_id_2] = \
        self.req_id_to_index[req_id_2], self.req_id_to_index[req_id_1]

    # Containers holding one element per slot. Tuple-swap is safe here as
    # long as indexing yields a value rather than a view (Python lists,
    # 1-D NumPy arrays).
    # NOTE(review): assumes none of these is a torch tensor or a
    # multi-dimensional array -- confirm against __init__.
    per_slot_values = (
        self._req_ids,
        self.req_output_token_ids,
        self.num_tokens,
        self.num_tokens_no_spec,
        self.num_prompt_tokens,
        self.num_computed_tokens_cpu,
        self.temperature_cpu,
        self.top_p_cpu,
        self.top_k_cpu,
        self.frequency_penalties_cpu,
        self.presence_penalties_cpu,
        self.repetition_penalties_cpu,
        self.min_p_cpu,
        self.request_lora_mapping,
        self.logit_bias,
    )
    for values in per_slot_values:
        values[i1], values[i2] = values[i2], values[i1]

    # NOTE: the following is unsafe
    #   self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...] = \
    #       self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
    # because the right-hand side rows are views into the array being
    # written. Instead, temporarily copy the data for one of the indices.
    # TODO(lucas): optimize this by only copying valid indices
    saved_tokens = self.token_ids_cpu[i1, ...].copy()
    self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
    self.token_ids_cpu[i2, ...] = saved_tokens

    # Index-keyed dictionaries.
    swap_dict_values(self.generators, i1, i2)
    swap_dict_values(self.min_tokens, i1, i2)
    swap_dict_values(self.bad_words_token_ids, i1, i2)

    mask = self.allowed_token_ids_mask_cpu_tensor
    if mask is not None:
        # Integer-indexing a torch tensor returns a view, so a tuple-swap
        # would overwrite row i1 before it is read back and leave both
        # rows equal to row i2. Snapshot one row first, as done for
        # token_ids_cpu above.
        # NOTE(review): assumes this attribute is a torch.Tensor, as its
        # name suggests (hence .clone()).
        saved_mask_row = mask[i1].clone()
        mask[i1] = mask[i2]
        mask[i2] = saved_mask_row

    self.block_table.swap_row(i1, i2)
| 484 | + |
426 | 485 | def condense(self, empty_req_indices: list[int]) -> None:
|
427 | 486 | """Move non-empty requests down into lower, empty indices.
|
428 | 487 |
|
|
0 commit comments