@@ -434,19 +434,25 @@ def __init__(self, vllm_config: "VllmConfig",
434
434
self .local_server_id )
435
435
self .llm_datadist_engine .prepare_data_dist ()
436
436
if self.kv_role == llm_datadist.LLMRole.DECODER:
    # Each decoding rank should correspond to each prefilling rank.
    clusters = self.llm_datadist_engine.make_clusters()

    # Keep retrying until nothing is pending; each round resubmits only
    # the clusters that have not been linked yet.
    # NOTE(review): there is no backoff and no upper bound on attempts,
    # and the previous `except LLMException` handler is gone -- confirm
    # link_clusters() reports failures via return codes and blocks for
    # up to `timeout` rather than failing fast.
    while clusters:
        overall_ret, link_rets = \
            self.llm_datadist_engine.datadist_engine.link_clusters(
                clusters, timeout=3000)

        if overall_ret != LLMStatusCode.LLM_SUCCESS:
            logger.warning(f"Failed to link clusters, { overall_ret = } ")
            continue

        # Presumably link_rets[i] is the status for clusters[i] (the
        # index-based removal relies on that). Collect the successful
        # positions first and rebuild the list in one pass: popping by
        # index while walking link_rets shifts the remaining entries,
        # which drops the wrong clusters or raises IndexError.
        linked = {
            idx
            for idx, link_ret in enumerate(link_rets)
            if link_ret == LLMStatusCode.LLM_SUCCESS
        }
        clusters = [
            cluster for idx, cluster in enumerate(clusters)
            if idx not in linked
        ]

        if clusters:
            logger.warning(f"Still {len(clusters)} clusters to link")
        else:
            logger.info("Successfully linked clusters")
450
456
451
457
def start_load_kv (self , forward_context : "ForwardContext" ,
452
458
** kwargs ) -> None :
0 commit comments