
Commit 6e2032e

feat: simplify 1p1d startup
Eliminates the need to launch the meta server in the 1p1d environment.

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 531bf1e commit 6e2032e

2 files changed: +62 -20 lines changed

examples/disaggregated-prefill-v1/offling_inference.py

Lines changed: 1 addition & 4 deletions
@@ -18,10 +18,7 @@ def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"],
         "kv_buffer_device": "npu",
         "kv_role": "{role}",
         "kv_rank": {kv_rank},
-        "kv_parallel_size": 2,
-        "kv_connector_extra_config": {{
-            "local_server_id": "{local_server_id}"
-        }}
+        "kv_parallel_size": 2
     }}"""
 
 
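For reference, the producer-side string rendered by the simplified template would now parse to something like the following. This is a minimal sketch: only the fields visible in the hunk above are shown, and the role/kv_rank values are illustrative.

    import json

    # Hypothetical expansion of get_kv_transfer_config("kv_producer", kv_rank=0)
    # after this commit; the kv_connector_extra_config block carrying
    # local_server_id is no longer emitted.
    kv_transfer_config = json.loads("""{
        "kv_buffer_device": "npu",
        "kv_role": "kv_producer",
        "kv_rank": 0,
        "kv_parallel_size": 2
    }""")
    assert "kv_connector_extra_config" not in kv_transfer_config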

vllm_ascend/distributed/llmdatadist_connector_v1.py

Lines changed: 61 additions & 16 deletions
@@ -160,7 +160,12 @@ def __init__(self, vllm_config: "VllmConfig") -> None:
             str, Any] = kv_transfer_config.get_from_extra_config("decode", {})
 
         self._servers: List[ServerInfo] = get_servers_from_ranktable(
-            GLOBAL_RANKTABLE, self.prefill_tp, self.decode_tp)
+            GLOBAL_RANKTABLE, self.prefill_tp_size, self.decode_tp_size)
+
+        self._num_prefill_instances = len(
+            self.get_servers_by_role(ServerRole.Prefill))
+        self._num_decode_instances = len(
+            self.get_servers_by_role(ServerRole.Decode))
 
     def get_device(self, server_id: str, dp_rank: int,
                    tp_rank: int) -> Union[DeviceInfo, None]:
@@ -182,6 +187,10 @@ def get_cluster_id(self, server_id: str, dp_rank: int,
     def get_servers_by_role(self, role: ServerRole) -> List[ServerInfo]:
         return [server for server in self._servers if server.role == role]
 
+    def is_1p1d(self) -> bool:
+        return (self._num_prefill_instances == 1
+                and self._num_decode_instances == 1)
+
     @property
     def router_endpoint(self):
         for server in self._servers:
@@ -190,28 +199,28 @@ def router_endpoint(self):
         raise ValueError("Router endpoint not found")
 
     @property
-    def prefill_dp(self):
+    def prefill_dp_size(self):
         candidate_keys = ["data_parallel_size", "dp_size", "dp"]
         return int(
             self._get_first_matching_value(self._prefill_parallel_config,
                                            candidate_keys, self._dp_size))
 
     @property
-    def prefill_tp(self):
+    def prefill_tp_size(self):
         candidate_keys = ["tensor_parallel_size", "tp_size", "tp"]
         return int(
             self._get_first_matching_value(self._prefill_parallel_config,
                                            candidate_keys, self._tp_size))
 
     @property
-    def decode_dp(self):
+    def decode_dp_size(self):
         candidate_keys = ["data_parallel_size", "dp_size", "dp"]
         return int(
             self._get_first_matching_value(self._decode_parallel_config,
                                            candidate_keys, self._dp_size))
 
     @property
-    def decode_tp(self):
+    def decode_tp_size(self):
         candidate_keys = ["tensor_parallel_size", "tp_size", "tp"]
         return int(
             self._get_first_matching_value(self._decode_parallel_config,
@@ -311,7 +320,8 @@ def make_clusters(self):
                 ServerRole.Prefill):
             for device in server.devices:
                 target_tp_rank = self.tp_rank % min(
-                    self.cluster_info.prefill_tp, self.cluster_info.decode_tp)
+                    self.cluster_info.prefill_tp_size,
+                    self.cluster_info.decode_tp_size)
                 if target_tp_rank == device.tp_rank:
                     cluster = self.make_cluster(device.device_ip,
                                                 device.cluster_id)
@@ -399,9 +409,21 @@ def __init__(self, vllm_config: "VllmConfig",
 
         self.local_server_id = kv_transfer_config.get_from_extra_config(
             "local_server_id", None)
-        assert (
-            self.local_server_id is not None
-        ), "Cannot find `local_server_id` from `kv_transfer_config.kv_connector_extra_config`."
+        if self.local_server_id is None:
+            if not self.cluster_info.is_1p1d(
+            ) or self.cluster_info.prefill_dp_size != 1:
+                raise ValueError(
+                    "Cannot find `local_server_id` from"
+                    " `kv_transfer_config.kv_connector_extra_config`.")
+            # In a 1p1d configuration (1 prefill node and 1 decode node), the
+            # server ID can be directly determined from the rank table based on
+            # the KV role.
+            servers = self.cluster_info.get_servers_by_role(
+                ServerRole.Prefill if self.kv_role ==
+                llm_datadist.LLMRole.PROMPT else ServerRole.Decode)
+            assert len(servers) == 1, \
+                f"Expected only one server for {self.kv_role}, but got {len(servers)}"
+            self.local_server_id = servers[0].server_id
 
         self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank
         self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size
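Aside: the fallback above derives the server ID purely from the KV role. A standalone sketch of that inference, using hypothetical ServerRole/ServerInfo stand-ins rather than the real vllm_ascend types:

    from dataclasses import dataclass
    from enum import Enum
    from typing import List


    class ServerRole(Enum):
        Prefill = "prefill"
        Decode = "decode"


    @dataclass
    class ServerInfo:
        server_id: str
        role: ServerRole


    def infer_local_server_id(servers: List[ServerInfo],
                              is_prompt_role: bool) -> str:
        # In 1p1d there is exactly one server per role, so the KV role alone
        # identifies this process's entry in the rank table.
        role = ServerRole.Prefill if is_prompt_role else ServerRole.Decode
        matches = [s for s in servers if s.role == role]
        assert len(matches) == 1, f"Expected one {role} server, got {len(matches)}"
        return matches[0].server_id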
@@ -467,8 +489,25 @@ def start_load_kv(self, forward_context: "ForwardContext",
             self._get_unique_req_id(req.request_id)
             for req in metadata.requests if not req.is_store
         ]
-        prefill_infos = fetch_prefill_info(self.cluster_info.router_endpoint,
-                                           request_ids)
+        if self.cluster_info.is_1p1d(
+        ) and self.cluster_info.prefill_dp_size == 1:
+            # In a 1p1d configuration (1 prefill node and 1 decode node), the
+            # server ID can be directly determined from the rank table based on
+            # the KV role.
+            servers = self.cluster_info.get_servers_by_role(ServerRole.Prefill)
+            assert len(servers) == 1, \
+                f"Expected only one server for {self.kv_role}, but got {len(servers)}"
+            prefill_infos = {
+                request_id: {
+                    "dp_rank": 0,
+                    "server_id": servers[0].server_id,
+                }
+                for request_id in request_ids
+            }
+        else:
+            prefill_infos = fetch_prefill_info(
+                self.cluster_info.router_endpoint, request_ids)
+
         # If prefill_infos is None, it indicates that get_prefill_info failed.
         # Therefore, we need to recalculate the kv cache during the decoding
         # phase. If there is a performance issue, we should consider whether
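Aside: in this fast path the connector builds prefill_infos locally instead of querying the router. For two in-flight requests the mapping would look roughly like this (request and server IDs are illustrative):

    # Illustrative shape of the locally synthesized prefill_infos: every
    # request maps to dp_rank 0 on the single prefill server.
    prefill_infos = {
        "req-0": {"dp_rank": 0, "server_id": "prefill-server-0"},
        "req-1": {"dp_rank": 0, "server_id": "prefill-server-0"},
    }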
@@ -518,8 +557,8 @@ def start_load_kv(self, forward_context: "ForwardContext",
                 self.num_layers, kv_cache_shape, kv_hidden_dtype)
 
             target_tp_rank = self.tp_rank % min(
-                self.cluster_info.prefill_tp,
-                self.cluster_info.decode_tp,
+                self.cluster_info.prefill_tp_size,
+                self.cluster_info.decode_tp_size,
             )
             remote_cluster_id = self.cluster_info.get_cluster_id(
                 server_id, dp_rank, target_tp_rank)
@@ -662,9 +701,15 @@ def wait_for_save(self):
         # Release reference count
         self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer)
 
-        # Report prefill info to meta server
-        report_prefill_info(self.cluster_info.router_endpoint,
-                            prefill_info_input)
+        # If the cluster is configured as 1p1d (1 prefill node and 1 decode
+        # node), and the data parallel size on the prefill node is 1, we don't
+        # need to report the prefill information to the router. This is because
+        # there is only one candidate server for the decode node to request the
+        # KV cache from.
+        if not self.cluster_info.is_1p1d(
+        ) or self.cluster_info.prefill_dp_size != 1:
+            report_prefill_info(self.cluster_info.router_endpoint,
+                                prefill_info_input)
         logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
 
     def _inject_kv_into_layer(
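Taken together, the commit makes the meta server (router) necessary only when routing is ambiguous. A minimal sketch of the decision implied by the checks above, with hypothetical argument names standing in for the ClusterInfo fields:

    def needs_meta_server(num_prefill_instances: int,
                          num_decode_instances: int,
                          prefill_dp_size: int) -> bool:
        # With exactly one prefill and one decode instance and no data
        # parallelism on the prefill side, there is only one candidate server
        # to pull KV cache from, so neither reporting nor lookup is needed.
        is_1p1d = num_prefill_instances == 1 and num_decode_instances == 1
        return not (is_1p1d and prefill_dp_size == 1)


    assert needs_meta_server(1, 1, 1) is False  # 1p1d: meta server skipped
    assert needs_meta_server(1, 1, 2) is True   # prefill DP > 1 needs routing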
