16
16
from vllm import envs
17
17
from vllm .config import VllmConfig
18
18
from vllm .distributed .kv_transfer .kv_connector .v1 .base import (
19
- KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole , KVTransferParams )
19
+ KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole )
20
20
from vllm .distributed .parallel_state import (
21
21
get_tensor_model_parallel_rank , get_tensor_model_parallel_world_size ,
22
22
get_tp_group )
44
44
NixlWrapper = None
45
45
46
46
47
- @dataclass
48
- class NixlKVTransferParams (KVTransferParams ):
49
-
50
- def __init__ (
51
- self ,
52
- do_remote_prefill : bool ,
53
- do_remote_decode : bool ,
54
- remote_block_ids : Optional [list [int ]] = None ,
55
- remote_host : Optional [str ] = None ,
56
- remote_port : Optional [int ] = None ,
57
- remote_engine_id : Optional [str ] = None ,
58
- ):
59
- self .do_remote_prefill = do_remote_prefill
60
- self .do_remote_decode = do_remote_decode
61
- self .remote_block_ids = remote_block_ids
62
- self .remote_host = remote_host
63
- self .remote_port = remote_port
64
- self .remote_engine_id = remote_engine_id
65
-
66
- @staticmethod
67
- def from_raw_dict (
68
- raw_dict : Optional [dict [str ,
69
- Any ]]) -> Optional ["NixlKVTransferParams" ]:
70
-
71
- # If no raw transfer params passed, return None.
72
- if raw_dict is None :
73
- return None
74
-
75
- # Validate the request is formatted properly.
76
- if (("do_remote_prefill" not in raw_dict )
77
- or ("do_remote_decode" not in raw_dict )
78
- or ("remote_block_ids" not in raw_dict )
79
- or ("remote_host" not in raw_dict )
80
- or ("remote_port" not in raw_dict )
81
- or ("remote_engine_id" not in raw_dict )):
82
- logger .warning (
83
- "Got invalid KVTransferParams: %s. This "
84
- "request will not utilize KVTransfer" , raw_dict )
85
- return None
86
-
87
- return NixlKVTransferParams (
88
- do_remote_prefill = raw_dict ["do_remote_prefill" ],
89
- do_remote_decode = raw_dict ["do_remote_decode" ],
90
- remote_block_ids = raw_dict ["remote_block_ids" ],
91
- remote_host = raw_dict ["remote_host" ],
92
- remote_port = raw_dict ["remote_port" ],
93
- remote_engine_id = raw_dict ["remote_engine_id" ],
94
- )
95
-
96
-
97
47
class NixlAgentMetadata (
98
48
msgspec .Struct ,
99
49
omit_defaults = True , # type: ignore[call-arg]
@@ -123,25 +73,18 @@ def add_new_req(
123
73
self ,
124
74
request_id : str ,
125
75
local_block_ids : list [int ],
126
- kv_transfer_params : NixlKVTransferParams ,
76
+ kv_transfer_params : dict [ str , Any ] ,
127
77
):
128
- assert request_id not in self .requests
129
- assert kv_transfer_params .remote_block_ids is not None
130
- assert kv_transfer_params .remote_engine_id is not None
131
- assert kv_transfer_params .remote_host is not None
132
- assert kv_transfer_params .remote_port is not None
133
-
134
78
self .requests [request_id ] = ReqMeta (
135
79
local_block_ids = local_block_ids ,
136
- remote_block_ids = kv_transfer_params . remote_block_ids ,
137
- remote_engine_id = kv_transfer_params . remote_engine_id ,
138
- remote_host = kv_transfer_params . remote_host ,
139
- remote_port = kv_transfer_params . remote_port ,
80
+ remote_block_ids = kv_transfer_params [ " remote_block_ids" ] ,
81
+ remote_engine_id = kv_transfer_params [ " remote_engine_id" ] ,
82
+ remote_host = kv_transfer_params [ " remote_host" ] ,
83
+ remote_port = kv_transfer_params [ " remote_port" ] ,
140
84
)
141
85
142
86
143
87
class NixlConnector (KVConnectorBase_V1 ):
144
- _KVTransferParams : type [NixlKVTransferParams ] = NixlKVTransferParams
145
88
146
89
def __init__ (self , vllm_config : VllmConfig , role : KVConnectorRole ):
147
90
assert vllm_config .kv_transfer_config is not None
@@ -253,52 +196,52 @@ def get_num_new_matched_tokens(
253
196
asynchronously (between scheduler steps).
254
197
"""
255
198
199
+ params = request .kv_transfer_params
256
200
logger .debug (
257
201
"NIXLConnector get_num_new_matched_tokens: "
258
202
"num_computed_tokens=%s, kv_transfer_params=%s" ,
259
- num_computed_tokens , request .kv_transfer_params )
260
-
261
- # No KVTransfer for this request.
262
- if request .kv_transfer_params is None :
263
- return 0 , False
264
- assert isinstance (request .kv_transfer_params , NixlKVTransferParams )
203
+ num_computed_tokens , params )
265
204
266
- # Remote prefill: get all prompt blocks from remote.
267
- if request . kv_transfer_params . do_remote_prefill :
205
+ if params is not None and params . get ( "do_remote_prefill" ):
206
+ # Remote prefill: get all prompt blocks from remote.
268
207
assert num_computed_tokens % self .block_size == 0
269
208
rounded_num_prompt_tokens = round_down (
270
209
len (request .prompt_token_ids ), self .block_size )
271
210
count = max (rounded_num_prompt_tokens - num_computed_tokens , 0 )
272
211
return count , count > 0
273
212
213
+ # No remote prefill for this request.
274
214
return 0 , False
275
215
276
216
def update_state_after_alloc (self , request : "Request" ,
277
217
blocks : "KVCacheBlocks" ,
278
218
num_external_tokens : int ):
279
219
220
+ params = request .kv_transfer_params
280
221
logger .debug (
281
222
"NIXLConnector update_state_after_alloc: "
282
223
"num_external_tokens=%s, kv_transfer_params=%s" ,
283
- num_external_tokens , request . kv_transfer_params )
224
+ num_external_tokens , params )
284
225
285
- if request .kv_transfer_params is None :
286
- return
287
-
288
- assert isinstance (request .kv_transfer_params , NixlKVTransferParams )
289
- if request .kv_transfer_params .do_remote_prefill :
226
+ if params is not None and params .get ("do_remote_prefill" ):
290
227
# NOTE(rob): if prompt < block_size, no remote blocks
291
228
# since the remote only sends fully computed blocks, so
292
229
# skip recving for this request. num_external_tokens
293
230
# should be 0 if there are no remote blocks.
294
- if request .kv_transfer_params .remote_block_ids :
295
- # Get unhashed blocks to pull from remote.
296
- self ._reqs_need_recv [request .request_id ] = (
297
- request , blocks .get_unhashed_block_ids ())
231
+ if params .get ("remote_block_ids" ):
232
+ if all (p in params for p in ("remote_engine_id" , "remote_host" ,
233
+ "remote_port" )):
234
+ # Get unhashed blocks to pull from remote.
235
+ self ._reqs_need_recv [request .request_id ] = (
236
+ request , blocks .get_unhashed_block_ids ())
237
+ else :
238
+ logger .warning (
239
+ "Got invalid KVTransferParams: %s. This "
240
+ "request will not utilize KVTransfer" , params )
298
241
else :
299
242
assert num_external_tokens == 0
300
243
# Only trigger 1 KV transfer per request.
301
- request . kv_transfer_params . do_remote_prefill = False
244
+ params [ " do_remote_prefill" ] = False
302
245
303
246
def build_connector_meta (
304
247
self ,
@@ -308,7 +251,7 @@ def build_connector_meta(
308
251
309
252
# Loop through scheduled reqs and convert to ReqMeta.
310
253
for req_id , (req , block_ids ) in self ._reqs_need_recv .items ():
311
- assert isinstance ( req .kv_transfer_params , NixlKVTransferParams )
254
+ assert req .kv_transfer_params is not None
312
255
meta .add_new_req (
313
256
request_id = req_id ,
314
257
local_block_ids = block_ids ,
@@ -330,34 +273,30 @@ def request_finished(
330
273
should be freed now or will be sent asynchronously and freed later.
331
274
"""
332
275
276
+ params = request .kv_transfer_params
333
277
logger .debug (
334
- "NIXLConnector request_finished, "
335
- "request_status=%s, kv_transfer_params=%s" , request .status ,
336
- request .kv_transfer_params )
337
-
338
- if request .kv_transfer_params is None :
339
- return False , None
340
- assert isinstance (request .kv_transfer_params , NixlKVTransferParams )
278
+ "NIXLConnector request_finished, request_status=%s, "
279
+ "kv_transfer_params=%s" , request .status , params )
341
280
342
- if (( not request . kv_transfer_params . do_remote_decode )
343
- or ( request .status != RequestStatus .FINISHED_LENGTH_CAPPED ) ):
281
+ if (params is None or not params . get ( " do_remote_decode" )
282
+ or request .status != RequestStatus .FINISHED_LENGTH_CAPPED ):
344
283
return False , None
345
284
346
285
# Get computed blocks.
347
286
all_full = request .num_computed_tokens % self .block_size == 0
348
- computed_block_ids = ( block_ids if all_full else block_ids [:- 1 ])
287
+ computed_block_ids = block_ids if all_full else block_ids [:- 1 ]
349
288
350
289
# If prompt < block_size, no xfer so free blocks immediately.
351
290
delay_free_blocks = len (computed_block_ids ) > 0
352
291
353
- return delay_free_blocks , NixlKVTransferParams (
292
+ return delay_free_blocks , dict (
354
293
do_remote_prefill = True ,
355
294
do_remote_decode = False ,
356
295
remote_block_ids = computed_block_ids ,
357
296
remote_engine_id = self .engine_id ,
358
297
remote_host = envs .VLLM_NIXL_SIDE_CHANNEL_HOST ,
359
298
remote_port = envs .VLLM_NIXL_SIDE_CHANNEL_PORT ,
360
- ). __dict__
299
+ )
361
300
362
301
363
302
class NixlConnectorWorker :
0 commit comments