diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7d7f488f47..8a680b3ed6 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.layer import (wait_for_kv_layer_from_connector,
+                                  maybe_save_kv_layer_to_connector)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -444,8 +446,11 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
@@ -456,7 +461,7 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
-    return
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


 def unified_attention_with_output_fake(
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b9e51a3e61..f035c67a3e 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -10,6 +10,8 @@
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.attention.layer import (wait_for_kv_layer_from_connector,
+                                  maybe_save_kv_layer_to_connector)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -1078,6 +1080,8 @@ def forward(
             prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
+        if has_prefill:
+            wait_for_kv_layer_from_connector(layer.layer_name)
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
@@ -1208,5 +1212,7 @@ def forward(
                 current_ms_metadata.after_comm_event.record()
             else:
                 output[:num_decode_tokens] = output_decode
+        if has_prefill:
+            maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)

         return output_padded
diff --git a/vllm_ascend/distributed/kv_transfer/cpu_offloading_connector.py b/vllm_ascend/distributed/kv_transfer/cpu_offloading_connector.py
new file mode 100644
index 0000000000..837f3935fc
--- /dev/null
+++ b/vllm_ascend/distributed/kv_transfer/cpu_offloading_connector.py
@@ -0,0 +1,355 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections import defaultdict
+from dataclasses import dataclass
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Sequence, Optional
+import queue
+import threading
+import torch
+from vllm.attention import AttentionType
+from vllm.attention.layer import Attention
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.parallel_state import (get_pp_group, get_tp_group)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.v1.engine.metadata import MLAConfig,
MetadataServer +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec) +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request +logger = init_logger(__name__) +@dataclass +class ReqMeta: + gpu_block_ids: list[int] + cpu_block_ids: list[int] + num_scheduled_tokens: int + num_computed_tokens: int + num_gpu_computed_tokens: int + num_cpu_computed_tokens: int + def update(self, other: "ReqMeta"): + self.gpu_block_ids.extend(other.gpu_block_ids) + self.cpu_block_ids.extend(other.cpu_block_ids) + self.num_scheduled_tokens = other.num_scheduled_tokens + self.num_computed_tokens = other.num_computed_tokens + self.num_gpu_computed_tokens = other.num_gpu_computed_tokens + self.num_cpu_computed_tokens = other.num_cpu_computed_tokens +@dataclass +class CPUOffloadingConnectorMetadata(KVConnectorMetadata): + requests: dict[str, ReqMeta] + finished_req_ids: set[str] +class CPUOffloadingConnector(KVConnectorBase_V1): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + if not vllm_config.cache_config.enable_prefix_caching: + self.connector_scheduler: Optional[ + CPUOffloadingConnectorScheduler] = None + self.connector_worker: Optional[ + CPUOffloadingConnectorWorker] = None + elif role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = CPUOffloadingConnectorScheduler( + vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = CPUOffloadingConnectorWorker(vllm_config) + # ============================== + # Worker-side methods + # ============================== + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + if self.connector_worker is not None: + assert isinstance(connector_metadata, + CPUOffloadingConnectorMetadata) + self.connector_worker.bind_connector_metadata(connector_metadata) + def clear_connector_metadata(self) -> None: + assert self.connector_worker is not None + self.connector_worker.clear_connector_metadata() + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + if self.connector_worker is not None: + self.connector_worker.register_kv_caches(kv_caches) + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + if self.connector_worker is not None: + self.connector_worker.start_load_kv() + def wait_for_layer_load(self, layer_name: str) -> None: + if self.connector_worker is not None: + self.connector_worker.wait_for_layer_load() + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + pass + def wait_for_save(self): + pass + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished(), None + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + if self.connector_scheduler is not None: + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + return 0, False + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + if self.connector_scheduler is not 
None: + return self.connector_scheduler.update_state_after_alloc(request) + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + if self.connector_scheduler is not None: + return self.connector_scheduler.build_connector_meta( + scheduler_output) + return KVConnectorMetadata() + def request_finished( + self, request: "Request", + block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]: + if self.connector_scheduler is not None: + self.connector_scheduler.request_finished(request) + return True, None + def sending_finished(self, request: "Request"): + assert self.connector_scheduler is not None + self.connector_scheduler.sending_finished(request) +class CPUOffloadingConnectorScheduler: + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.use_mla = vllm_config.model_config.use_mla + self.num_gpu_computed_tokens: dict[str, int] = {} + self.num_cpu_computed_tokens: dict[str, int] = {} + self.allocated_req_ids: set[str] = set() + self.finished_req_ids: list[str] = [] + self.zmq_rpc_client = MetadataServer.ZMQRPCClient() + self.zmq_rpc_client.call("post_init") + if vllm_config.kv_transfer_config is not None: + self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config( + "swap_in_threshold", 0) + else: + self.swap_in_threshold = 0 + logger.info(f"swap_in_threshold: {self.swap_in_threshold}") + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call( + "get_matched_num_and_touch", request) + self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens + self.num_cpu_computed_tokens[ + request.request_id] = num_cpu_computed_tokens + if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold: + return num_cpu_computed_tokens - num_computed_tokens, load_async + else: + return 0, load_async + def update_state_after_alloc(self, request: "Request"): + self.allocated_req_ids.add(request.request_id) + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + num_tokens = { + req.req_id: + req.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req.req_id] + for req in chain(scheduler_output.scheduled_new_reqs, + scheduler_output.scheduled_cached_reqs) + } + unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() - + self.allocated_req_ids - + scheduler_output.num_scheduled_tokens.keys()) + new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", + num_tokens, + unallocated_req_ids) + metadata = CPUOffloadingConnectorMetadata( + requests={}, + finished_req_ids=set(self.finished_req_ids), + ) + for req in scheduler_output.scheduled_new_reqs: + req_id = req.req_id + metadata.requests[req_id] = ReqMeta( + gpu_block_ids=req.block_ids[0], + cpu_block_ids=new_cpu_block_ids.get(req_id, []), + num_scheduled_tokens=scheduler_output. + num_scheduled_tokens[req_id], + num_computed_tokens=req.num_computed_tokens, + num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id], + num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id]) + for new_req in scheduler_output.scheduled_cached_reqs: + req_id = new_req.req_id + metadata.requests[req_id] = ReqMeta( + gpu_block_ids=new_req.new_block_ids[0], + cpu_block_ids=new_cpu_block_ids.get(req_id, []), + num_scheduled_tokens=scheduler_output. 
+ num_scheduled_tokens[req_id], + num_computed_tokens=new_req.num_computed_tokens, + num_gpu_computed_tokens=new_req.num_computed_tokens, + num_cpu_computed_tokens=new_req.num_computed_tokens) + self.num_gpu_computed_tokens.clear() + self.num_cpu_computed_tokens.clear() + self.allocated_req_ids.clear() + self.finished_req_ids.clear() + return metadata + def request_finished(self, request: "Request"): + self.finished_req_ids.append(request.request_id) + def sending_finished(self, request: "Request"): + self.zmq_rpc_client.call("cache_and_free_slots", request) +class CPUOffloadingConnectorWorker: + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.pp_rank = get_pp_group().rank_in_group + self.tp_group = get_tp_group() + self.tp_rank = self.tp_group.rank_in_group + self.tp_world_size = self.tp_group.world_size + self.use_mla = vllm_config.model_config.use_mla + self.requests: dict[str, ReqMeta] = {} + self.load_stream = torch.npu.Stream() + self.save_stream = torch.npu.Stream() + self.zmq_rpc_client = MetadataServer.ZMQRPCClient() + self.load_block_mapping = [] + self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue() + self.save_output_queue: queue.Queue[str] = queue.Queue() + self.save_thread = threading.Thread(target=self._save_listener) + self.save_thread.start() + self.done_sending_count: defaultdict[str, int] = defaultdict(int) + def bind_connector_metadata( + self, connector_metadata: CPUOffloadingConnectorMetadata) -> None: + for req_id, req in connector_metadata.requests.items(): + if req_id in self.requests: + self.requests[req_id].update(req) + req = self.requests[req_id] + else: + self.requests[req_id] = req + for i in range(req.num_gpu_computed_tokens // self.block_size, + req.num_computed_tokens // self.block_size): + self.load_block_mapping.append( + (req.cpu_block_ids[i], req.gpu_block_ids[i])) + for req_id in connector_metadata.finished_req_ids: + self.save_input_queue.put((req_id, self.requests[req_id])) + def clear_connector_metadata(self) -> None: + self.load_block_mapping.clear() + def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]): + self.gpu_kv_caches = kv_caches + model_config = self.vllm_config.model_config + mla_config: Optional[MLAConfig] = None + if model_config.use_mla: + mla_config = MLAConfig( + model_config.hf_text_config.kv_lora_rank, + model_config.hf_text_config.qk_rope_head_dim) + self.cpu_kv_caches = list( + self.zmq_rpc_client.call( + "init_cpu_kv_caches", + self.pp_rank, + self.tp_rank, + get_kv_cache_spec(self.vllm_config), + mla_config, + ).values()) + def start_load_kv(self) -> None: + self.current_layer = 0 + self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values()) + self.load_kv_layer(0) + def wait_for_layer_load(self) -> None: + # TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug. 
+ self.load_stream.synchronize() + self.current_layer += 1 + self.load_kv_layer(self.current_layer) + def load_kv_layer(self, layer: int): + if layer == len(self.gpu_kv_caches): + return + gpu_kv_caches = next(self.gpu_kv_caches_load_iter) + cpu_kv_caches = self.cpu_kv_caches[layer] + with torch.npu.stream(self.load_stream): + for cpu_block_id, gpu_block_id in self.load_block_mapping: + for gpu_layer_part, cpu_layer_part in zip( + gpu_kv_caches, cpu_kv_caches): + gpu_layer_part[gpu_block_id].copy_( + cpu_layer_part[cpu_block_id], non_blocking=True) + def get_finished(self) -> set[str]: + done_sending: set[str] = set() + while True: + try: + id = self.save_output_queue.get_nowait() + except queue.Empty: + break + done_sending.add(id) + for id in done_sending: + del self.requests[id] + if self.tp_world_size == 1: + return done_sending + if self.tp_rank == 0: + for req_id in done_sending: + self.done_sending_count[req_id] += 1 + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.tp_world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + self.done_sending_count[req_id] += 1 + all_done_sending: set[str] = set() + for req_id in list(self.done_sending_count.keys()): + if self.done_sending_count[req_id] == self.tp_world_size: + del self.done_sending_count[req_id] + all_done_sending.add(req_id) + return all_done_sending + else: + self.tp_group.send_object(done_sending, dst=0) + return done_sending + def _save_listener(self): + save_block_mapping = [] + while True: + req_id, req = self.save_input_queue.get() + for i in range( + req.num_cpu_computed_tokens // self.block_size, + min((req.num_computed_tokens + req.num_scheduled_tokens) // + self.block_size, len(req.cpu_block_ids))): + save_block_mapping.append( + (req.gpu_block_ids[i], req.cpu_block_ids[i])) + with torch.npu.stream(self.save_stream): + # MLA: kv_layer is tuple[tensor, tensor] means (rope, nope). + # non-MLA: kv_layer is list[tensor], typically means [k, v]. + if self.use_mla: + start, step = self.tp_rank, self.tp_world_size + else: + start, step = 0, 1 + for i in range(start, len(save_block_mapping), step): + gpu_block_id, cpu_block_id = save_block_mapping[i] + for cpu_kv_caches, gpu_kv_caches in zip( + self.cpu_kv_caches, self.gpu_kv_caches.values()): + for cpu_layer_part, gpu_layer_part in zip( + cpu_kv_caches, gpu_kv_caches): + cpu_layer_part[cpu_block_id].copy_( + gpu_layer_part[gpu_block_id], + non_blocking=True) + self.save_stream.synchronize() + self.save_output_queue.put(req_id) + save_block_mapping.clear() +# Copied from vllm_ascend/worker/model_runner_v1.py. 
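
The swap-in path in `load_kv_layer` and the swap-out path in `_save_listener` both reduce to the same primitive: walk a list of (source block, destination block) pairs and copy one cache block per layer part on a side stream. Below is a minimal, device-agnostic sketch of that block-mapping copy; the shapes, block ids and token counts are invented for illustration, and plain CPU tensors stand in for the NPU-resident cache and the dedicated load/save streams.

```python
import torch

# Invented shape: [num_blocks, block_size, num_kv_heads, head_size] for one
# layer part (e.g. the K part of a non-MLA layer).
block_size = 16
npu_cache = torch.zeros(8, block_size, 2, 64)   # stands in for the device-side cache
cpu_cache = torch.randn(32, block_size, 2, 64)  # stands in for the pooled host-side cache

# Swap-in: copy the blocks that are cached on the CPU side but not yet present
# on the device, i.e. the range between the two computed-token watermarks.
num_gpu_computed_tokens, num_computed_tokens = 32, 96
cpu_block_ids = [3, 7, 11, 19, 23, 29]
gpu_block_ids = [0, 1, 2, 3, 4, 5]

load_block_mapping = [
    (cpu_block_ids[i], gpu_block_ids[i])
    for i in range(num_gpu_computed_tokens // block_size,
                   num_computed_tokens // block_size)
]
for cpu_block_id, gpu_block_id in load_block_mapping:
    # In the connector this copy runs under torch.npu.stream(self.load_stream).
    npu_cache[gpu_block_id].copy_(cpu_cache[cpu_block_id], non_blocking=True)
```

The save path is the mirror image (device to host); for MLA models, where the latent KV cache is replicated across TP ranks, `_save_listener` strides the mapping by `tp_rank`/`tp_world_size` so each rank writes out only its share of the blocks.
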
+def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
+    forward_ctx = vllm_config.compilation_config.static_forward_context
+    block_size = vllm_config.cache_config.block_size
+    use_mla = vllm_config.model_config.use_mla
+    kv_cache_spec: dict[str, KVCacheSpec] = {}
+    for layer_name, attn_module in forward_ctx.items():
+        if isinstance(attn_module, FusedMoE):
+            continue
+        assert isinstance(attn_module, Attention)
+        if attn_module.attn_type == AttentionType.DECODER:
+            kv_cache_spec[layer_name] = FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=attn_module.num_kv_heads,
+                head_size=attn_module.head_size,
+                dtype=attn_module.dtype,
+                use_mla=use_mla)
+        elif attn_module.attn_type in (AttentionType.ENCODER,
+                                       AttentionType.ENCODER_ONLY):
+            continue
+        elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+            raise NotImplementedError
+        else:
+            raise ValueError(
+                f"Unknown attention type: {attn_module.attn_type}")
+    return kv_cache_spec
\ No newline at end of file
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index d7f68a12c7..1d00a5ad8b 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -669,6 +669,7 @@ def __init__(
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.layer_name = prefix

     def forward(
         self,
@@ -676,9 +677,14 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
-        attn_metadata: Optional[AttentionMetadata] = None,
+        attn_metadata: Optional[Union["AttentionMetadata",
+                                      dict[str, "AttentionMetadata"]]] = None,
         replace_allreduce: bool = False,
     ) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        if isinstance(attn_metadata, dict):
+            attn_metadata = attn_metadata[f"{self.layer_name}.self_attn.attn"]
         # Self Attention
         if attn_metadata is not None and attn_metadata.num_decodes > 0:
             mla_moe_communication = self.mla_moe_communication and replace_allreduce
@@ -803,7 +809,8 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: Optional[List[torch.Tensor]] = None,
-        attn_metadata: Optional[AttentionMetadata] = None,
+        attn_metadata: Optional[Union["AttentionMetadata",
+                                      dict[str, "AttentionMetadata"]]] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -975,7 +982,8 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: Optional[List[torch.Tensor]] = None,
-        attn_metadata: Optional[AttentionMetadata] = None,
+        attn_metadata: Optional[Union["AttentionMetadata",
+                                      dict[str, "AttentionMetadata"]]] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 50d610e94b..da4751de52 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -24,7 +24,7 @@
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import numpy as np
 import numpy.typing as npt
@@ -37,6 +37,9 @@
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import get_dp_group, get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -1116,24 +1119,31 @@ def _process_reqs(
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

-        if self.vllm_config.model_config.use_mla:
-            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
-                common_attn_metadata=common_attn_metadata,
-                common_prefix_len=None,
-                **extra_builder_kwargs,
-            )
-        else:
-            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
-                num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
-                common_prefix_len=None,
-                **extra_builder_kwargs,
-            )
-        attn_metadata.num_input_tokens = num_input_tokens
+        attn_metadata: dict[str, Any] = {}
+        # Prepare the attention metadata for each KV cache group and make layers
+        # in the same group share the same metadata.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            if self.vllm_config.model_config.use_mla:
+                attn_metadata_i = self.attn_metadata_builder.build(  # type: ignore
+                    num_reqs=num_reqs,
+                    num_actual_tokens=total_num_scheduled_tokens,
+                    max_query_len=max_num_scheduled_tokens,
+                    common_attn_metadata=common_attn_metadata,
+                    common_prefix_len=None,
+                    **extra_builder_kwargs,
+                )
+            else:
+                attn_metadata_i = self.attn_metadata_builder.build(  # type: ignore
+                    num_reqs=num_reqs,
+                    num_actual_tokens=total_num_scheduled_tokens,
+                    max_query_len=max_num_scheduled_tokens,
+                    common_prefix_len=None,
+                    **extra_builder_kwargs,
+                )
+            attn_metadata_i.num_input_tokens = num_input_tokens
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i

         # Prepare input_ids
         token_indices = (positions_np +
@@ -1400,7 +1410,7 @@ def _get_spec_token_ids(
         positions: torch.Tensor,
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
-        attn_metadata: SpecDecodeMetadata,
+        attn_metadata: dict[str, SpecDecodeMetadata],
         aux_hidden_states: torch.Tensor = None,
     ) -> Optional[list[list[int]]]:
         if not self.use_spec_decode:
@@ -1682,6 +1692,11 @@ def execute_model(
                 attn_metadata,
                 aux_hidden_states,
             )
+
+        # Clear KVConnector state after all KVs are generated.
+        if has_kv_transfer_group():
+            get_kv_transfer_group().clear_connector_metadata()
+
         if vllm_version_is("0.9.1"):
             model_runner_output = ModelRunnerOutput(
                 req_ids=self.input_batch.req_ids,
@@ -1860,8 +1875,14 @@ def _dummy_run(
                 self.vllm_config,
                 num_tokens=num_tokens):
             if self.torchair_graph_enabled and not with_prefill:
-                attn_metadata = self.attn_metadata_builder.build_dummy(
-                    num_reqs=num_tokens, num_actual_tokens=1)
+                attn_metadata: dict[str, Any] = {}
+                # Prepare the attention metadata for each KV cache group and make layers
+                # in the same group share the same metadata.
+                for kv_cache_group_spec in self.kv_cache_config.kv_cache_groups:
+                    attn_metadata_i = self.attn_metadata_builder.build_dummy(
+                        num_reqs=num_tokens, num_actual_tokens=1)
+                    for layer_name in kv_cache_group_spec.layer_names:
+                        attn_metadata[layer_name] = attn_metadata_i
             # Only mark static while compiling
             if is_compile:
                 torch._dynamo.mark_static(input_ids)
@@ -2298,7 +2319,7 @@ def _generate_mtp_token_ids(
         positions: torch.Tensor,
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
-        attn_metadata: SpecDecodeMetadata,
+        attn_metadata: dict[str, SpecDecodeMetadata],
     ):
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(valid_sampled_token_ids):
@@ -2317,14 +2338,18 @@ def _generate_mtp_token_ids(
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
                                       device=self.device)
+
+        # At this moment, we assume all eagle layers belong to the same KV
+        # cache group, thus using the same attention metadata.
+        eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_names[0]]
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
             target_hidden_states = hidden_states[:num_scheduled_tokens]
-            target_slot_mapping = attn_metadata.slot_mapping
-            cu_num_tokens = attn_metadata.query_start_loc
+            target_slot_mapping = eagle_attn_metadata.slot_mapping
+            cu_num_tokens = eagle_attn_metadata.query_start_loc
         else:
             # TODO(woosuk): Refactor this.
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -2339,13 +2364,13 @@ def _generate_mtp_token_ids(
             )
             assert self.drafter is not None
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                attn_metadata.query_start_loc,
+                eagle_attn_metadata.query_start_loc,
                 num_rejected_tokens,
             )
             target_token_ids = self.input_ids[token_indices]
             target_positions = positions[token_indices]
             target_hidden_states = hidden_states[token_indices]
-            target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+            target_slot_mapping = eagle_attn_metadata.slot_mapping[token_indices]
         assert self.drafter is not None
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
@@ -2354,7 +2379,7 @@ def _generate_mtp_token_ids(
             target_slot_mapping=target_slot_mapping,
             next_token_ids=next_token_ids,
             cu_num_tokens=cu_num_tokens,
-            block_table=attn_metadata.block_tables,
+            block_table=eagle_attn_metadata.block_tables,
             sampling_metadata=sampling_metadata,
         )
         spec_token_ids = draft_token_ids.tolist()
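
Taken together, the model-runner and attention changes replace the single shared `attn_metadata` object with a dict keyed by layer name: the runner builds one metadata object per KV-cache group and fans it out to every layer in that group, while consumers such as `unified_ascend_attention_with_output`, the MLA forward (via `layer.layer_name`) and the DeepSeek decoder layer (via `f"{self.layer_name}.self_attn.attn"`) accept either the legacy single object or the per-layer dict. A minimal sketch of that pattern, with made-up layer names and stand-in types for the KV-cache group spec and builder output, is shown below.

```python
from dataclasses import dataclass

@dataclass
class KVCacheGroupSpec:        # stand-in for the runner's kv_cache_groups entries
    layer_names: list[str]

@dataclass
class AttnMetadata:            # stand-in for the attention metadata builder output
    num_input_tokens: int = 0

kv_cache_groups = [
    KVCacheGroupSpec(layer_names=["model.layers.0.self_attn.attn",
                                  "model.layers.1.self_attn.attn"]),
]

# Runner side: build one metadata object per group and share it across all
# layers in that group by fanning it out into a per-layer dict.
attn_metadata: dict[str, AttnMetadata] = {}
for group in kv_cache_groups:
    attn_metadata_i = AttnMetadata(num_input_tokens=128)
    for layer_name in group.layer_names:
        attn_metadata[layer_name] = attn_metadata_i

# Consumer side: accept either the legacy single object or the per-layer dict.
def resolve(metadata, layer_name: str) -> AttnMetadata:
    if isinstance(metadata, dict):
        return metadata[layer_name]
    return metadata

assert resolve(attn_metadata, "model.layers.1.self_attn.attn").num_input_tokens == 128
```

Because the dict is built per KV-cache group, layers in the same group see the exact same object, which is why the MTP path above can pick `attn_metadata[self.drafter.attn_layer_names[0]]` and reuse it for every eagle layer.
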