Skip to content

Commit 19a53b2

Browse files
authored
[V1] Decouple GPU and TPU InputBatch (#19778)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
1 parent eccdc83 commit 19a53b2

File tree

5 files changed

+597
-4
lines changed

5 files changed

+597
-4
lines changed

vllm/v1/sample/tpu/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from vllm.v1.worker.gpu_input_batch import InputBatch
8+
from vllm.v1.worker.tpu_input_batch import InputBatch
99

1010
DEFAULT_SAMPLING_PARAMS = dict(
1111
temperature=-1.0,

vllm/v1/worker/gpu_input_batch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
# Datastructures defining an input batch
3+
# Datastructures defining a GPU input batch
44

55
from dataclasses import dataclass
66
from typing import Optional, cast
@@ -453,6 +453,11 @@ def swap_states(self, i1: int, i2: int) -> None:
453453
self.block_table.swap_row(i1, i2)
454454

455455
def condense(self, empty_req_indices: list[int]) -> None:
456+
"""Move non-empty requests down into lower, empty indices.
457+
458+
Args:
459+
empty_req_indices: empty batch indices, sorted descending.
460+
"""
456461
num_reqs = self.num_reqs
457462
if num_reqs == 0:
458463
# The batched states are empty.

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from contextlib import contextmanager
8+
from typing import Union
89

910
import numpy as np
1011
import torch.nn as nn
@@ -15,7 +16,10 @@
1516
from vllm.lora.request import LoRARequest
1617
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
1718
from vllm.model_executor.models import supports_lora, supports_multimodal
18-
from vllm.v1.worker.gpu_input_batch import InputBatch
19+
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
20+
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
21+
22+
InputBatch = Union[TPUInputBatch, GPUInputBatch]
1923

2024
logger = init_logger(__name__)
2125

0 commit comments

Comments (0)