36
36
from vllm .v1 .request import Request
37
37
38
38
Transfer = tuple [int , float ] # (xfer_handle, start_time)
39
+ EngineId = str
40
+ ReqId = str
39
41
GET_META_MSG = b"get_meta_msg"
40
42
41
43
logger = init_logger (__name__ )
@@ -75,7 +77,7 @@ class ReqMeta:
75
77
class NixlConnectorMetadata (KVConnectorMetadata ):
76
78
77
79
def __init__ (self ):
78
- self .requests : dict [str , ReqMeta ] = {}
80
+ self .requests : dict [ReqId , ReqMeta ] = {}
79
81
80
82
def add_new_req (
81
83
self ,
@@ -96,16 +98,17 @@ class NixlConnector(KVConnectorBase_V1):
96
98
97
99
def __init__ (self , vllm_config : VllmConfig , role : KVConnectorRole ):
98
100
assert vllm_config .kv_transfer_config is not None
99
- self .engine_id = vllm_config .kv_transfer_config .engine_id
101
+ assert vllm_config .kv_transfer_config .engine_id is not None
102
+ self .engine_id : EngineId = vllm_config .kv_transfer_config .engine_id
100
103
101
104
if role == KVConnectorRole .SCHEDULER :
102
105
self .connector_scheduler : Optional [NixlConnectorScheduler ] = \
103
- NixlConnectorScheduler (vllm_config , str ( self .engine_id ) )
106
+ NixlConnectorScheduler (vllm_config , self .engine_id )
104
107
self .connector_worker : Optional [NixlConnectorWorker ] = None
105
108
elif role == KVConnectorRole .WORKER :
106
109
self .connector_scheduler = None
107
110
self .connector_worker = NixlConnectorWorker (
108
- vllm_config , str ( self .engine_id ) )
111
+ vllm_config , self .engine_id )
109
112
110
113
############################################################
111
114
# Scheduler Side Methods
@@ -179,7 +182,7 @@ class NixlConnectorScheduler:
179
182
def __init__ (self , vllm_config : VllmConfig , engine_id : str ):
180
183
self .vllm_config = vllm_config
181
184
self .block_size = vllm_config .cache_config .block_size
182
- self .engine_id = engine_id
185
+ self .engine_id : EngineId = engine_id
183
186
self .side_channel_host = envs .VLLM_NIXL_SIDE_CHANNEL_HOST
184
187
self .side_channel_port = (
185
188
envs .VLLM_NIXL_SIDE_CHANNEL_PORT +
@@ -190,7 +193,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
190
193
# Requests that need to start recv.
191
194
# New requests are added by update_state_after_alloc in
192
195
# the scheduler. Used to make metadata passed to Worker.
193
- self ._reqs_need_recv : dict [str , tuple [Request , list [int ]]] = {}
196
+ self ._reqs_need_recv : dict [ReqId , tuple [Request , list [int ]]] = {}
194
197
195
198
def get_num_new_matched_tokens (
196
199
self , request : "Request" ,
@@ -332,19 +335,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
332
335
# Agent.
333
336
self .nixl_wrapper = NixlWrapper (str (uuid .uuid4 ()), None )
334
337
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
335
- self ._remote_agents : dict [str , dict [int , str ]] = defaultdict (dict )
338
+ self ._remote_agents : dict [EngineId , dict [int , str ]] = defaultdict (dict )
336
339
337
340
# NIXL handshake port.
338
341
# NOTE(rob): Within a DP group, each DP rank gets its own
339
342
# base port (which is sent in the KVTransferParams).
340
343
# Each TP rank listens/queries on the base_port + tp_rank.
341
- self .side_channel_port = (
344
+ self .side_channel_port : int = (
342
345
envs .VLLM_NIXL_SIDE_CHANNEL_PORT +
343
346
vllm_config .parallel_config .data_parallel_rank_local *
344
347
vllm_config .parallel_config .tensor_parallel_size )
345
348
346
349
# Metadata.
347
- self .engine_id = engine_id
350
+ self .engine_id : EngineId = engine_id
348
351
self .tp_rank = get_tensor_model_parallel_rank ()
349
352
self .world_size = get_tensor_model_parallel_world_size ()
350
353
self .tp_group = get_tp_group ()
@@ -354,7 +357,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
354
357
355
358
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
356
359
# rank will still only pull from a single remote TP worker.
357
- self .kv_caches_base_addr : dict [str , list [int ]] = {}
360
+ self .kv_caches_base_addr : dict [EngineId , list [int ]] = {}
358
361
359
362
# Number of NIXL regions. Currently one region per cache
360
363
# (so 1 per layer for MLA, otherwise 2 per layer)
@@ -364,23 +367,23 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
364
367
# nixl_prepped_dlist_handle.
365
368
self .src_xfer_side_handle : int = 0
366
369
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
367
- self .dst_xfer_side_handles : dict [str , int ] = {}
370
+ self .dst_xfer_side_handles : dict [EngineId , int ] = {}
368
371
369
372
# Map of engine_id -> num_blocks. All ranks in the same deployment will
370
373
# have the same number of blocks.
371
- self .dst_num_blocks : dict [str , int ] = {}
374
+ self .dst_num_blocks : dict [EngineId , int ] = {}
372
375
self ._registered_descs : list [Any ] = []
373
376
374
377
# In progress transfers.
375
378
# [req_id -> list[handle]]
376
- self ._recving_transfers = defaultdict [str , list [Transfer ]](list )
379
+ self ._recving_transfers = defaultdict [ReqId , list [Transfer ]](list )
377
380
378
381
# Complete transfer tracker. Used by the rank 0 to track finished
379
382
# transactions on ranks 1 to N-1.
380
383
# [req_id -> count]
381
- self ._done_recving_count : defaultdict [str ,
384
+ self ._done_recving_count : defaultdict [ReqId ,
382
385
int ] = defaultdict (lambda : 0 )
383
- self ._done_sending_count : defaultdict [str ,
386
+ self ._done_sending_count : defaultdict [ReqId ,
384
387
int ] = defaultdict (lambda : 0 )
385
388
386
389
# Background thread for establishing new connections.
@@ -408,10 +411,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
408
411
self ._use_flashinfer = attn_backend == _Backend .FLASHINFER_VLLM_V1
409
412
logger .debug ("Detected attention backend %s" , self .backend_name )
410
413
411
- self ._tp_size : dict [str , int ] = {self .engine_id : self .world_size }
414
+ self ._tp_size : dict [EngineId , int ] = {self .engine_id : self .world_size }
412
415
# With heterogeneous TP, P must wait for all assigned D TP workers to
413
416
# finish reading before safely freeing the blocks.
414
- self .consumer_notification_counts_by_req = defaultdict [str , int ](int )
417
+ self .consumer_notification_counts_by_req = defaultdict [ReqId , int ](int )
415
418
416
419
@staticmethod
417
420
def _nixl_handshake_listener (metadata : NixlAgentMetadata ,
0 commit comments