|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
4 | 4 | from dataclasses import dataclass
|
5 |
| -from typing import Any, Optional |
| 5 | +from typing import Any, ClassVar, Optional |
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 |
|
@@ -63,63 +63,91 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
63 | 63 |
|
64 | 64 |
|
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Builds AITER MLA attention metadata for the decode path.

    When the runner is configured for full CUDA graph capture, persistent
    device buffers are pre-allocated in ``__init__`` and re-filled on every
    build, so that the tensors handed to the attention kernel live at stable
    addresses across graph replays.
    """
    # Full CUDA graph capture is only supported for the decode path.
    full_cudagraph_supported: ClassVar[bool] = True  # decode only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        """Initialize the builder.

        Args:
            runner: the model runner; ``full_cuda_graph``, ``device`` and
                ``max_num_reqs`` are read from it here.
            kv_cache_spec: KV-cache layout spec; AITER MLA requires
                ``block_size == 1`` (asserted below).
            block_table: block table whose device tensor bounds the maximum
                number of pages we may ever need to index.
        """
        super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
        assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
            "only supports block size 1."

        # Preparing persistent buffers: sized for the worst case so the same
        # storage can back every CUDA graph capture/replay.
        if self.runner.full_cuda_graph:
            device = self.runner.device
            max_num_reqs = self.runner.max_num_reqs
            # indptr has one extra leading 0 entry, hence max_num_reqs + 1.
            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device=device)
            self.paged_kv_indices = torch.zeros(
                block_table.get_device_tensor().numel(
                ),  # max num pages possible
                dtype=torch.int32,
                device=device)
            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                      dtype=torch.int32,
                                                      device=device)

            # qo_indptr is a fixed 0..max_num_reqs ramp (step 1), i.e. one
            # query token per decode request; it never needs re-filling,
            # only slicing in _build_decode.
            self.qo_indptr = torch.arange(0,
                                          max_num_reqs + 1,
                                          dtype=torch.int32,
                                          device=device)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
        """Build per-step decode metadata (paged-KV index tensors).

        Args:
            block_table_tensor: (num_decodes, max_blocks) page table; rows
                beyond each request's block count are masked out below.
                (Shape assumed from the size(1)/mask usage — TODO confirm.)
            seq_lens: per-request sequence lengths used to derive how many
                pages each request occupies.

        Returns:
            AiterMLADecodeMetadata with flattened page indices, CSR-style
            indptr, last-page lengths and query indptr. In full-CUDA-graph
            mode these are views into the persistent buffers allocated in
            ``__init__`` (slices share the same underlying storage).
        """
        page_size = self.kv_cache_spec.block_size
        # Ceil-divide: number of pages each request's sequence occupies.
        block_table_bounds = (seq_lens + page_size - 1) // page_size
        device = self.runner.device

        # Boolean mask selecting, per request, only the valid leading
        # entries of its block-table row.
        mask = (torch.arange(block_table_tensor.size(1),
                             dtype=block_table_tensor.dtype,
                             device=device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        # Flattened (1-D) list of valid page indices across all requests.
        paged_kv_indices = block_table_tensor[mask]

        # Tokens used in each request's final page; a remainder of 0 means
        # the last page is exactly full, so map 0 -> page_size.
        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size, paged_kv_last_page_len)

        # CSR-style offsets into paged_kv_indices: [0, cumsum(bounds)].
        paged_kv_indptr = torch.cat([
            torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        if self.runner.full_cuda_graph:
            num_reqs = self._num_decodes

            num_actual_pages = paged_kv_indices.size(0)

            # Copy the freshly computed tensors into the persistent buffers
            # (async copies), then pad the unused tails with fixed values —
            # presumably so graph replays never read stale data; the padding
            # values (-1 / last indptr / 1) mirror "empty" entries.
            # TODO(review): confirm padding semantics against the kernel.
            self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
                                                           non_blocking=True)
            self.paged_kv_indices[num_actual_pages:].fill_(-1)
            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]

            self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
                                                      non_blocking=True)
            # Padded entries repeat the final offset -> zero-length ranges.
            self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
            paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]

            self.paged_kv_last_page_len[:num_reqs].copy_(
                paged_kv_last_page_len, non_blocking=True)
            self.paged_kv_last_page_len[num_reqs:].fill_(1)
            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

            # View into the pre-built 0..max ramp; no copy needed.
            qo_indptr = self.qo_indptr[:1 + num_reqs]

        else:
            # Eager mode: build the query indptr fresh each step.
            qo_indptr = torch.arange(0,
                                     self._num_decodes + 1,
                                     step=1,
                                     dtype=torch.int32,
                                     device=device)

        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr)

        return attn_metadata
|
|
0 commit comments