[Draft][WIP][Feature]cpu offload connector #1659

Open · wants to merge 1 commit into base: main
7 changes: 6 additions & 1 deletion vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.layer import (wait_for_kv_layer_from_connector,
+                                  maybe_save_kv_layer_to_connector)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -444,8 +446,11 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
@@ -456,7 +461,7 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
-    return
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
 
 def unified_attention_with_output_fake(
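Note on the hooks above: wait_for_kv_layer_from_connector and maybe_save_kv_layer_to_connector come from vllm.attention.layer and are the per-layer entry points of vLLM's KV-connector machinery. The wait call blocks until the configured connector has finished loading the named layer's KV blocks onto the device, and the save call hands the freshly written KV cache back to the connector so it can be offloaded (here, to CPU). The new isinstance(attn_metadata, dict) branch covers the case where forward_context.attn_metadata is a per-layer dict keyed by layer name rather than a single metadata object. A minimal sketch of the connector contract these hooks rely on; the class and method names below are illustrative assumptions, not code from this PR:

# Toy per-layer CPU-offload connector, illustrating the ordering the hooks
# above depend on: start load -> wait -> run attention -> save.
# Names and signatures are assumptions, not the actual vLLM connector API.
import torch


class ToyCPUOffloadConnector:
    """Keeps a host-memory copy of each attention layer's KV cache."""

    def __init__(self) -> None:
        self._cpu_copies: dict[str, torch.Tensor] = {}

    def start_load_kv(self, layer_name: str, kv_cache: torch.Tensor) -> None:
        # Begin copying any previously offloaded blocks back into the device
        # buffer that attention will read from.
        cpu_copy = self._cpu_copies.get(layer_name)
        if cpu_copy is not None:
            kv_cache.copy_(cpu_copy, non_blocking=True)

    def wait_for_layer_load(self, layer_name: str) -> None:
        # unified_ascend_attention_with_output waits here before running the
        # layer, so attention never reads half-loaded KV blocks. A real
        # connector would wait on the copy stream/event for this layer.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def save_kv_layer(self, layer_name: str, kv_cache: torch.Tensor) -> None:
        # Called after attention has written new K/V for the layer; snapshot
        # it to host memory so the device blocks can be freed or reused.
        self._cpu_copies[layer_name] = kv_cache.detach().cpu()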
6 changes: 6 additions & 0 deletions vllm_ascend/attention/mla_v1.py
@@ -10,6 +10,8 @@
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.attention.layer import (wait_for_kv_layer_from_connector,
+                                  maybe_save_kv_layer_to_connector)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -1078,6 +1080,8 @@ def forward(
             prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
+        if has_prefill:
+            wait_for_kv_layer_from_connector(layer.layer_name)
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
@@ -1208,5 +1212,7 @@ def forward(
                     current_ms_metadata.after_comm_event.record()
             else:
                 output[:num_decode_tokens] = output_decode
+        if has_prefill:
+            maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)
 
         return output_padded
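On the MLA path the same hooks are gated on has_prefill, so decode-only steps skip the connector round-trip and only batches that prefill new tokens wait for and save KV blocks. For completeness, a hedged sketch of how such a connector is typically enabled through vLLM's KV-transfer configuration; the connector registration name and the model are placeholder assumptions, since this diff does not show how the PR registers the connector:

# Usage sketch only. "CPUOffloadConnector" is a placeholder registration name
# and the model is arbitrary; adjust both to whatever this PR actually defines.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    kv_transfer_config=KVTransferConfig(
        kv_connector="CPUOffloadConnector",  # assumed name
        kv_role="kv_both",  # this worker both saves (offloads) and loads KV
    ),
)
print(llm.generate("Hello, world")[0].outputs[0].text)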