
Commit 86338cd

refactor triton
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 4e9d4ec commit 86338cd

File tree

1 file changed (+30, -35 lines)

vllm/v1/attention/backends/triton_attn.py

Lines changed: 30 additions & 35 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Optional
+from typing import Any, ClassVar, Optional
 
 import torch
 
@@ -14,17 +14,14 @@
     chunked_prefill_paged_decode)
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.triton_unified_attention import unified_attention
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
@@ -75,12 +72,21 @@ class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
 
-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        self.runner = runner
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.device = device
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
+
+        model_config = vllm_config.model_config
+        self.num_heads_q = model_config.get_num_attention_heads(
+            vllm_config.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
+            vllm_config.parallel_config)
+        self.headdim = model_config.get_head_size()
+
+        self.attention_chunk_size = getattr(vllm_config.scheduler_config,
+                                            'attention_chunk_size', None)
 
     def build_for_cudagraph_capture(
             self, common_attn_metadata: CommonAttentionMetadata
@@ -96,42 +102,32 @@ def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> TritonAttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
 
-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping
 
         # for local attention
         local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.attention_chunk_size,
+                    common_attn_metadata.query_start_loc_cpu.numpy(),
+                    common_attn_metadata.seq_lens_cpu.numpy(),
                     block_table_tensor,
                     self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()
 
             local_attn_metadata = TritonAttentionMetadata \
                 .LocalAttentionMetadata(
@@ -148,14 +144,13 @@ def build(self,
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
-                                                device=self.runner.device)
+                                                device=self.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
-                                          device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+                                          device=self.device)
+            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                               common_prefix_len)
-            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
-                self.runner.device)
+            suffix_kv_lens = suffix_kv_lens.to(self.device)
         else:
             cu_prefix_query_lens = None
             prefix_kv_lens = None
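
Note: the snippet below is a toy, standalone illustration of the static shape information the builder now caches in __init__ from VllmConfig instead of reaching into the runner. The model and parallelism sizes are made up, and the per-rank division is an assumption about how get_num_attention_heads / get_num_kv_heads behave under tensor parallelism; it is not code from this commit.

# Hypothetical model/parallel sizes, not taken from this commit.
total_num_heads = 32        # query heads in the model config
total_num_kv_heads = 8      # KV heads (GQA)
head_size = 128
tensor_parallel_size = 4

# Assumed behaviour of get_num_attention_heads / get_num_kv_heads:
# split the global head counts across tensor-parallel ranks.
num_heads_q = total_num_heads // tensor_parallel_size               # -> 8 per rank
num_heads_kv = max(1, total_num_kv_heads // tensor_parallel_size)   # -> 2 per rank
headdim = head_size

# attention_chunk_size falls back to None when the scheduler config does
# not define it (mirrors the getattr(...) default in the diff above).
attention_chunk_size = None

print(num_heads_q, num_heads_kv, headdim, attention_chunk_size)      # 8 2 128 None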

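Similarly, a minimal sketch of the per-step data that build() now reads from CommonAttentionMetadata rather than from runner-owned numpy buffers; the tensor values are invented examples and only torch is needed to run it.

import torch

# Invented per-request metadata standing in for CommonAttentionMetadata fields.
seq_lens_cpu = torch.tensor([17, 33, 8], dtype=torch.int32)           # KV length per request
query_start_loc_cpu = torch.tensor([0, 4, 9, 10], dtype=torch.int32)  # cumulative query starts
common_prefix_len = 4

# The quantities the refactored build() derives from them:
max_seq_len = int(seq_lens_cpu.max())               # was self.runner.seq_lens_np[:num_reqs].max()
seq_lens_np = seq_lens_cpu.numpy()                  # was self.runner.seq_lens_np[:num_reqs]
query_start_loc_np = query_start_loc_cpu.numpy()    # was self.runner.query_start_loc_np[:num_reqs + 1]
suffix_kv_lens = seq_lens_cpu - common_prefix_len   # stays on CPU, moved to the device afterwards

print(max_seq_len, suffix_kv_lens.tolist())         # 33 [13, 29, 4]
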
Comments (0)