Commit ed1fb0f
[Misc] Add type alias ReqId and EngineId for better readability (vllm-project#19880)
Signed-off-by: Linkun Chen <github@lkchen.net>
1 parent (cff610e) · commit ed1fb0f

File tree: 1 file changed (+20, -17 lines)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 20 additions & 17 deletions
@@ -36,6 +36,8 @@
 from vllm.v1.request import Request
 
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
+EngineId = str
+ReqId = str
 GET_META_MSG = b"get_meta_msg"
 
 logger = init_logger(__name__)
@@ -75,7 +77,7 @@ class ReqMeta:
 class NixlConnectorMetadata(KVConnectorMetadata):
 
     def __init__(self):
-        self.requests: dict[str, ReqMeta] = {}
+        self.requests: dict[ReqId, ReqMeta] = {}
 
     def add_new_req(
         self,
@@ -96,16 +98,17 @@ class NixlConnector(KVConnectorBase_V1):
 
     def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         assert vllm_config.kv_transfer_config is not None
-        self.engine_id = vllm_config.kv_transfer_config.engine_id
+        assert vllm_config.kv_transfer_config.engine_id is not None
+        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
 
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler : Optional[NixlConnectorScheduler] = \
-                NixlConnectorScheduler(vllm_config, str(self.engine_id))
+                NixlConnectorScheduler(vllm_config, self.engine_id)
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
             self.connector_worker = NixlConnectorWorker(
-                vllm_config, str(self.engine_id))
+                vllm_config, self.engine_id)
 
     ############################################################
     # Scheduler Side Methods
@@ -179,7 +182,7 @@ class NixlConnectorScheduler:
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
-        self.engine_id = engine_id
+        self.engine_id: EngineId = engine_id
         self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
@@ -190,7 +193,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
-        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
+        self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -332,19 +335,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
-        self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
+        self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
 
         # NIXL handshake port.
         # NOTE(rob): Within a DP group, each DP rank gets its own
         # base port (which is sent in the KVTransferParams).
         # Each TP rank listens/queries on the base_port + tp_rank.
-        self.side_channel_port = (
+        self.side_channel_port: int = (
            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
            vllm_config.parallel_config.data_parallel_rank_local *
            vllm_config.parallel_config.tensor_parallel_size)
 
         # Metadata.
-        self.engine_id = engine_id
+        self.engine_id: EngineId = engine_id
         self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
         self.tp_group = get_tp_group()
@@ -354,7 +357,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
         # Map of engine_id -> kv_caches_base_addr. For TP case, each local
         # rank will still only pull from a single remote TP worker.
-        self.kv_caches_base_addr: dict[str, list[int]] = {}
+        self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
@@ -364,23 +367,23 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[str, int] = {}
+        self.dst_xfer_side_handles: dict[EngineId, int] = {}
 
         # Map of engine_id -> num_blocks. All ranks in the same deployment will
         # have the same number of blocks.
-        self.dst_num_blocks: dict[str, int] = {}
+        self.dst_num_blocks: dict[EngineId, int] = {}
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
         # [req_id -> list[handle]]
-        self._recving_transfers = defaultdict[str, list[Transfer]](list)
+        self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
         # [req_id -> count]
-        self._done_recving_count: defaultdict[str,
+        self._done_recving_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)
-        self._done_sending_count: defaultdict[str,
+        self._done_sending_count: defaultdict[ReqId,
                                               int] = defaultdict(lambda: 0)
 
         # Background thread for establishing new connections.
@@ -408,10 +411,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         logger.debug("Detected attention backend %s", self.backend_name)
 
-        self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
+        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict[str, int](int)
+        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,

0 commit comments
