Skip to content

Commit 574ad60

Browse files
njhillsdavidbd
andauthored
[KVConnector] Always call connector clear_metadata() at end of step (#20756)
Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: David Ben-David <sdavidbd@gmail.com>
1 parent fdadb6f commit 574ad60

File tree

4 files changed

+25
-26
lines changed

4 files changed

+25
-26
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
5757
WORKER = 1
5858

5959

60-
class KVConnectorMetadata:
60+
class KVConnectorMetadata(ABC): # noqa: B024
6161
"""
6262
Abstract Metadata used to communicate between the
6363
Scheduler KVConnector and Worker KVConnector.
@@ -71,7 +71,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
7171
logger.warning(
7272
"Initializing KVConnectorBase_V1. This API is experimental and "
7373
"subject to change in the future as we iterate the design.")
74-
self._connector_metadata = KVConnectorMetadata()
74+
self._connector_metadata: Optional[KVConnectorMetadata] = None
7575
self._vllm_config = vllm_config
7676
self._role = role
7777

@@ -102,7 +102,7 @@ def clear_connector_metadata(self) -> None:
102102
This function should be called by the model runner every time
103103
after the model execution.
104104
"""
105-
self._connector_metadata = KVConnectorMetadata()
105+
self._connector_metadata = None
106106

107107
def _get_connector_metadata(self) -> KVConnectorMetadata:
108108
"""Get the connector metadata.
@@ -112,6 +112,9 @@ def _get_connector_metadata(self) -> KVConnectorMetadata:
112112
Returns:
113113
ConnectorMetadata: the connector metadata.
114114
"""
115+
116+
# Should only be called while set to valid metadata.
117+
assert self._connector_metadata is not None
115118
return self._connector_metadata
116119

117120
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

vllm/v1/executor/multiproc_executor.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -250,28 +250,24 @@ def _aggregate_workers_output(
250250
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
251251
# aggregate finished_sending, finished_recving from all workers
252252

253-
finished_sending = set[str]()
254-
finished_recving = set[str]()
255-
for output in outputs:
256-
# update finished_sending
257-
for req_id in output.finished_sending or []:
258-
new_count = self._send_remaining_count[req_id] - 1
253+
def update_finished_set(req_ids: Optional[set[str]],
254+
remaining_count_dict: dict[str, int],
255+
finished_set: set[str]) -> None:
256+
for req_id in req_ids or ():
257+
new_count = remaining_count_dict[req_id] - 1
259258
if new_count == 0:
260-
# got response from all workers, report back to scheduler
261-
finished_sending.add(req_id)
262-
del self._send_remaining_count[req_id]
259+
finished_set.add(req_id)
260+
del remaining_count_dict[req_id]
263261
else:
264-
self._send_remaining_count[req_id] = new_count
262+
remaining_count_dict[req_id] = new_count
265263

266-
# update finished_recving
267-
for req_id in output.finished_recving or []:
268-
new_count = self._recv_remaining_count[req_id] - 1
269-
if new_count == 0:
270-
# got response from all workers, report back to scheduler
271-
finished_recving.add(req_id)
272-
del self._recv_remaining_count[req_id]
273-
else:
274-
self._recv_remaining_count[req_id] = new_count
264+
finished_sending = set[str]()
265+
finished_recving = set[str]()
266+
for output in outputs:
267+
update_finished_set(output.finished_sending,
268+
self._send_remaining_count, finished_sending)
269+
update_finished_set(output.finished_recving,
270+
self._recv_remaining_count, finished_recving)
275271

276272
# select output of the worker specified by output_rank
277273
output = outputs[self.output_rank]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,10 +1539,6 @@ def execute_model(
15391539
attn_metadata,
15401540
)
15411541

1542-
# Clear KVConnector state after all KVs are generated.
1543-
if has_kv_transfer_group():
1544-
get_kv_transfer_group().clear_connector_metadata()
1545-
15461542
self.eplb_step()
15471543

15481544
return ModelRunnerOutput(

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ def execute_model(
338338
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
339339
output.finished_sending = finished_sending
340340
output.finished_recving = finished_recving
341+
342+
# Clear KVConnector state for this step.
343+
get_kv_transfer_group().clear_connector_metadata()
344+
341345
# with a connector, the scheduler expects output from all workers
342346
return output
343347

0 commit comments

Comments
 (0)