@@ -160,7 +160,12 @@ def __init__(self, vllm_config: "VllmConfig") -> None:
             str, Any] = kv_transfer_config.get_from_extra_config("decode", {})
 
         self._servers: List[ServerInfo] = get_servers_from_ranktable(
-            GLOBAL_RANKTABLE, self.prefill_tp, self.decode_tp)
+            GLOBAL_RANKTABLE, self.prefill_tp_size, self.decode_tp_size)
+
+        self._num_prefill_instances = len(
+            self.get_servers_by_role(ServerRole.Prefill))
+        self._num_decode_instances = len(
+            self.get_servers_by_role(ServerRole.Decode))
 
     def get_device(self, server_id: str, dp_rank: int,
                    tp_rank: int) -> Union[DeviceInfo, None]:
@@ -182,6 +187,10 @@ def get_cluster_id(self, server_id: str, dp_rank: int,
     def get_servers_by_role(self, role: ServerRole) -> List[ServerInfo]:
         return [server for server in self._servers if server.role == role]
 
+    def is_1p1d(self) -> bool:
+        return (self._num_prefill_instances == 1
+                and self._num_decode_instances == 1)
+
     @property
     def router_endpoint(self):
         for server in self._servers:
@@ -190,28 +199,28 @@ def router_endpoint(self):
         raise ValueError("Router endpoint not found")
 
     @property
-    def prefill_dp(self):
+    def prefill_dp_size(self):
         candidate_keys = ["data_parallel_size", "dp_size", "dp"]
         return int(
             self._get_first_matching_value(self._prefill_parallel_config,
                                            candidate_keys, self._dp_size))
 
     @property
-    def prefill_tp(self):
+    def prefill_tp_size(self):
         candidate_keys = ["tensor_parallel_size", "tp_size", "tp"]
         return int(
             self._get_first_matching_value(self._prefill_parallel_config,
                                            candidate_keys, self._tp_size))
 
     @property
-    def decode_dp(self):
+    def decode_dp_size(self):
         candidate_keys = ["data_parallel_size", "dp_size", "dp"]
         return int(
             self._get_first_matching_value(self._decode_parallel_config,
                                            candidate_keys, self._dp_size))
 
     @property
-    def decode_tp(self):
+    def decode_tp_size(self):
         candidate_keys = ["tensor_parallel_size", "tp_size", "tp"]
         return int(
             self._get_first_matching_value(self._decode_parallel_config,
@@ -311,7 +320,8 @@ def make_clusters(self):
                 ServerRole.Prefill):
             for device in server.devices:
                 target_tp_rank = self.tp_rank % min(
-                    self.cluster_info.prefill_tp, self.cluster_info.decode_tp)
+                    self.cluster_info.prefill_tp_size,
+                    self.cluster_info.decode_tp_size)
                 if target_tp_rank == device.tp_rank:
                     cluster = self.make_cluster(device.device_ip,
                                                 device.cluster_id)
@@ -399,9 +409,21 @@ def __init__(self, vllm_config: "VllmConfig",
 
         self.local_server_id = kv_transfer_config.get_from_extra_config(
             "local_server_id", None)
-        assert (
-            self.local_server_id is not None
-        ), "Cannot find `local_server_id` from `kv_transfer_config.kv_connector_extra_config`."
+        if self.local_server_id is None:
+            if not self.cluster_info.is_1p1d(
+            ) or self.cluster_info.prefill_dp_size != 1:
+                raise ValueError(
+                    "Cannot find `local_server_id` from"
+                    " `kv_transfer_config.kv_connector_extra_config`.")
+            # In a 1p1d configuration (1 prefill node and 1 decode node), the
+            # server ID can be directly determined from the rank table based on
+            # the KV role.
+            servers = self.cluster_info.get_servers_by_role(
+                ServerRole.Prefill if self.kv_role ==
+                llm_datadist.LLMRole.PROMPT else ServerRole.Decode)
+            assert len(servers) == 1, \
+                f"Expected only one server for {self.kv_role}, but got {len(servers)}"
+            self.local_server_id = servers[0].server_id
 
         self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank
         self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size
@@ -467,8 +489,25 @@ def start_load_kv(self, forward_context: "ForwardContext",
             self._get_unique_req_id(req.request_id)
             for req in metadata.requests if not req.is_store
         ]
-        prefill_infos = fetch_prefill_info(self.cluster_info.router_endpoint,
-                                           request_ids)
+        if self.cluster_info.is_1p1d(
+        ) and self.cluster_info.prefill_dp_size == 1:
+            # In a 1p1d configuration (1 prefill node and 1 decode node), the
+            # server ID can be directly determined from the rank table based on
+            # the KV role.
+            servers = self.cluster_info.get_servers_by_role(ServerRole.Prefill)
+            assert len(servers) == 1, \
+                f"Expected only one server for {self.kv_role}, but got {len(servers)}"
+            prefill_infos = {
+                request_id: {
+                    "dp_rank": 0,
+                    "server_id": servers[0].server_id,
+                }
+                for request_id in request_ids
+            }
+        else:
+            prefill_infos = fetch_prefill_info(
+                self.cluster_info.router_endpoint, request_ids)
+
         # If prefill_infos is None, it indicates that get_prefill_info failed.
         # Therefore, we need to recalculate the kv cache during the decoding
         # phase. If there is a performance issue, we should consider whether
@@ -518,8 +557,8 @@ def start_load_kv(self, forward_context: "ForwardContext",
                 self.num_layers, kv_cache_shape, kv_hidden_dtype)
 
             target_tp_rank = self.tp_rank % min(
-                self.cluster_info.prefill_tp,
-                self.cluster_info.decode_tp,
+                self.cluster_info.prefill_tp_size,
+                self.cluster_info.decode_tp_size,
             )
             remote_cluster_id = self.cluster_info.get_cluster_id(
                 server_id, dp_rank, target_tp_rank)
@@ -662,9 +701,15 @@ def wait_for_save(self):
             # Release reference count
             self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
 
-        # Report prefill info to meta server
-        report_prefill_info(self.cluster_info.router_endpoint,
-                            prefill_info_input)
+        # If the cluster is configured as 1p1d (1 prefill node and 1 decode
+        # node), and the data parallel size on the prefill node is 1, we don't
+        # need to report the prefill information to the router. This is because
+        # there is only one candidate server for the decode node to request the
+        # KV cache from.
+        if not self.cluster_info.is_1p1d(
+        ) or self.cluster_info.prefill_dp_size != 1:
+            report_prefill_info(self.cluster_info.router_endpoint,
+                                prefill_info_input)
         logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
 
     def _inject_kv_into_layer(
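
Taken together, the hunks above implement one shortcut: when the rank table describes exactly one prefill server and one decode server (`is_1p1d()`) and the prefill node runs with data-parallel size 1, both the `local_server_id` and the per-request prefill info can be derived locally from the rank table, so the router (`fetch_prefill_info` / `report_prefill_info`) is bypassed. The sketch below is a minimal, self-contained illustration of that decision logic only; `ServerInfo`, `ServerRole`, and `build_prefill_infos_without_router` here are stand-in types and a hypothetical helper, not the actual connector classes from the diff.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List


class ServerRole(Enum):
    Prefill = "prefill"
    Decode = "decode"


@dataclass
class ServerInfo:
    server_id: str
    role: ServerRole


def build_prefill_infos_without_router(servers: List[ServerInfo],
                                       prefill_dp_size: int,
                                       request_ids: List[str]) -> Dict[str, dict]:
    """Resolve prefill info locally when the cluster is 1p1d with prefill DP=1."""
    prefill_servers = [s for s in servers if s.role == ServerRole.Prefill]
    decode_servers = [s for s in servers if s.role == ServerRole.Decode]
    is_1p1d = len(prefill_servers) == 1 and len(decode_servers) == 1
    if not (is_1p1d and prefill_dp_size == 1):
        # More than one candidate: the router lookup is still required.
        raise RuntimeError("fetch_prefill_info via the router is still needed")
    # Exactly one prefill server and one DP rank: map every request to it.
    return {
        request_id: {"dp_rank": 0, "server_id": prefill_servers[0].server_id}
        for request_id in request_ids
    }


if __name__ == "__main__":
    servers = [ServerInfo("prefill-0", ServerRole.Prefill),
               ServerInfo("decode-0", ServerRole.Decode)]
    print(build_prefill_infos_without_router(servers, 1, ["req-1", "req-2"]))
```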