|
| 1 | +# Usage: python pytorch/xla/test/benchmarks/test_ragged_paged_attention_benchmark.py --kernel ragged-paged-attention |
| 2 | + |
| 3 | +import argparse |
| 4 | +import time |
| 5 | +from typing import List, Optional, Tuple |
| 6 | +import functools |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch_xla |
| 10 | +import torch_xla.core.xla_model as xm |
| 11 | +import jax |
| 12 | +from jax._src import test_util as jtu |
| 13 | +from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_single_query_paged_attention |
| 14 | +import jax.numpy as jnp |
| 15 | +import numpy as np |
| 16 | + |
| 17 | +from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE |
| 18 | + |
| 19 | + |
| 20 | +def _ref_ragged_paged_attention( |
| 21 | + queries: jax.Array, # [num_tokens, num_q_heads, head_dim] |
| 22 | + k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] |
| 23 | + v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] |
| 24 | + kv_lens: jax.Array, # i32[num_tokens] |
| 25 | + page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] |
| 26 | + cu_q_lens: jax.Array, # i32[num_tokens + 1] |
| 27 | + num_seqs: int, |
| 28 | +): |
| 29 | + """This is the reference ragged paged attention implementation.""" |
| 30 | + num_kv_heads, _, page_size, head_dim = k_pages.shape |
| 31 | + num_q_heads = queries.shape[1] |
| 32 | + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." |
| 33 | + num_query_per_kv = num_q_heads // num_kv_heads |
| 34 | + start_idx = 0 |
| 35 | + outputs: List[jax.Array] = [] |
| 36 | + for i in range(num_seqs): |
| 37 | + cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] |
| 38 | + q = queries[start_idx:start_idx + |
| 39 | + cur_q_len] # [cur_q_len, num_q_heads, head_dim] |
| 40 | + |
| 41 | + cur_kv_len = kv_lens[i] |
| 42 | + num_pages = (cur_kv_len + page_size - 1) // page_size |
| 43 | + page_indices_to_use = page_indices[i, :num_pages] |
| 44 | + k = k_pages[:, |
| 45 | + page_indices_to_use, :, :] # [num_kv_heads, page_indices_to_use, page_size, head_dim] |
| 46 | + k = jnp.permute_dims( |
| 47 | + k, (1, 2, 0, |
| 48 | + 3)) # [page_indices_to_use, page_size, num_kv_heads, head_dim] |
| 49 | + k = jnp.reshape( |
| 50 | + k, (-1, num_kv_heads, head_dim)) # [kv_len, num_kv_heads, head_dim] |
| 51 | + k = k[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] |
| 52 | + |
| 53 | + v = v_pages[:, page_indices_to_use, :, :] |
| 54 | + v = jnp.permute_dims(v, (1, 2, 0, 3)) |
| 55 | + v = jnp.reshape(v, (-1, num_kv_heads, head_dim)) |
| 56 | + v = v[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] |
| 57 | + |
| 58 | + if num_query_per_kv != 1: |
| 59 | + k = jnp.repeat(k, num_query_per_kv, axis=1) |
| 60 | + v = jnp.repeat(v, num_query_per_kv, axis=1) |
| 61 | + |
| 62 | + attn = jnp.einsum("qhd,khd->hqk", q, k) |
| 63 | + attn = attn.astype('float32') |
| 64 | + q_span = (cur_kv_len - cur_q_len) + jax.lax.broadcasted_iota( |
| 65 | + jnp.int32, (cur_q_len, cur_kv_len), 0) |
| 66 | + kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1) |
| 67 | + # Use the same DEFAULT_MASK_VALUE as in the kernel instead of float("-inf") so that the kernel can match the ref implement better. |
| 68 | + mask = jnp.where(q_span < kv_span, DEFAULT_MASK_VALUE, 0.) |
| 69 | + with jax.numpy_rank_promotion("allow"): |
| 70 | + attn = attn + mask |
| 71 | + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) |
| 72 | + out = jnp.einsum("hqk,khd->qhd", attn, |
| 73 | + v) # [cur_q_len, num_q_heads, head_dim] |
| 74 | + |
| 75 | + outputs.append(out) |
| 76 | + start_idx += cur_q_len |
| 77 | + |
| 78 | + return jnp.concatenate(outputs, axis=0) |
| 79 | + |
| 80 | + |
| 81 | +def _get_closest_power_of_two(x): |
| 82 | + if x <= 0: |
| 83 | + raise ValueError(f"x must be positive. Got {x}") |
| 84 | + return 2**int(np.ceil(np.log2(x))) |
| 85 | + |
| 86 | + |
| 87 | +def benchmark(args): |
| 88 | + seq_lens = [ |
| 89 | + (1, 1328), |
| 90 | + (5, 18), |
| 91 | + (1, 129), |
| 92 | + (120, 229), |
| 93 | + (1, 122), # end of the first physical q block |
| 94 | + (1, 64), |
| 95 | + (32, 100), |
| 96 | + (250, 463), |
| 97 | + (1, 18), |
| 98 | + (1, 17), |
| 99 | + (99, 123), # last 3 physical q blocks [(q_len, kv_len),...] |
| 100 | + ] |
| 101 | + num_heads = (4, 4) |
| 102 | + head_dim = 128 |
| 103 | + dtype = jnp.float32 |
| 104 | + page_size = 16 |
| 105 | + num_pages = 32768 |
| 106 | + num_queries_per_block = 128 |
| 107 | + |
| 108 | + num_seqs = len(seq_lens) |
| 109 | + for i in range(num_seqs): |
| 110 | + cur_q_len = seq_lens[i][0] |
| 111 | + cur_kv_len = seq_lens[i][1] |
| 112 | + # Make sure the q_len is no longer than the kv_len. For example, |
| 113 | + # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because |
| 114 | + # the 3rd sequence has q_len(506) > kv_len(463). |
| 115 | + assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" |
| 116 | + |
| 117 | + query_lens = [seq_len[0] for seq_len in seq_lens] |
| 118 | + num_q_tokens = sum(query_lens) |
| 119 | + kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) |
| 120 | + num_q_heads = num_heads[0] |
| 121 | + num_kv_heads = num_heads[1] |
| 122 | + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." |
| 123 | + |
| 124 | + prng_key = jax.random.key(0) |
| 125 | + k1, k2, k3, k4 = jax.random.split(prng_key, 4) |
| 126 | + queries = jax.random.normal( |
| 127 | + k1, (num_q_tokens, num_q_heads, head_dim), dtype=dtype) |
| 128 | + k_pages = jax.random.normal( |
| 129 | + k2, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) |
| 130 | + v_pages = jax.random.normal( |
| 131 | + k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) |
| 132 | + |
| 133 | + # Create a kv_lens: i32[num_tokens] |
| 134 | + kv_lens_with_paddings = [0] * num_q_tokens |
| 135 | + kv_lens_with_paddings[:num_seqs] = kv_lens[:num_seqs] |
| 136 | + kv_lens_np = jnp.array(kv_lens_with_paddings) |
| 137 | + |
| 138 | + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] |
| 139 | + max_kv_len = max([seq_len[1] for seq_len in seq_lens]) |
| 140 | + max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size |
| 141 | + # The reason why we need to pad max_num_pages_per_seq is that |
| 142 | + # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 |
| 143 | + max_num_pages_per_seq = _get_closest_power_of_two(max_num_pages_per_seq) |
| 144 | + page_indices = jax.random.randint( |
| 145 | + k4, (num_q_tokens, max_num_pages_per_seq), 0, num_pages, dtype=jnp.int32) |
| 146 | + |
| 147 | + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] |
| 148 | + q_lens_with_paddings = [0] * num_q_tokens |
| 149 | + for i in range(num_seqs): |
| 150 | + q_lens_with_paddings[i] = query_lens[i] |
| 151 | + cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings)) |
| 152 | + |
| 153 | + err, actual_output = ragged_paged_attention( |
| 154 | + queries, |
| 155 | + k_pages, |
| 156 | + v_pages, |
| 157 | + kv_lens_np, |
| 158 | + page_indices, |
| 159 | + cu_q_lens, |
| 160 | + num_seqs, |
| 161 | + num_queries_per_block=num_queries_per_block, |
| 162 | + ) |
| 163 | + err.throw() # noop if there is no error. |
| 164 | + actual_output = jax.block_until_ready(actual_output) |
| 165 | + profile_path = "/workspaces/persist/myprofiles/plugins/profile" |
| 166 | + |
| 167 | + def run_benchmark(num_iters: int, profile: bool = False) -> float: |
| 168 | + start_time = time.perf_counter() |
| 169 | + if profile: |
| 170 | + jax.profiler.start_trace(profile_path) |
| 171 | + |
| 172 | + actual_output = None |
| 173 | + for _ in range(num_iters): |
| 174 | + if args.kernel == "ragged-paged-attention": |
| 175 | + err, actual_output = ragged_paged_attention( |
| 176 | + queries, |
| 177 | + k_pages, |
| 178 | + v_pages, |
| 179 | + kv_lens_np, |
| 180 | + page_indices, |
| 181 | + cu_q_lens, |
| 182 | + num_seqs, |
| 183 | + ) |
| 184 | + err.throw() |
| 185 | + elif args.kernel == "ragged-paged-attention-ref-impl": |
| 186 | + actual_output = _ref_ragged_paged_attention( |
| 187 | + queries, |
| 188 | + k_pages, |
| 189 | + v_pages, |
| 190 | + kv_lens_np, |
| 191 | + page_indices, |
| 192 | + cu_q_lens, |
| 193 | + num_seqs, |
| 194 | + ) |
| 195 | + else: |
| 196 | + assert False, f"Invalid kernel name {args.kernel}" |
| 197 | + |
| 198 | + jax.block_until_ready(actual_output) |
| 199 | + |
| 200 | + end_time = time.perf_counter() |
| 201 | + if profile: |
| 202 | + jax.profiler.stop_trace() |
| 203 | + return (end_time - start_time) / num_iters |
| 204 | + |
| 205 | + # Warmup. |
| 206 | + print("Warming up...") |
| 207 | + run_benchmark(num_iters=3, profile=False) |
| 208 | + |
| 209 | + print("Run benchmark...") |
| 210 | + if args.profile: |
| 211 | + latency = run_benchmark(num_iters=1, profile=True) |
| 212 | + else: |
| 213 | + latency = run_benchmark(num_iters=10, profile=False) |
| 214 | + print(f"Kernel running time: {latency * 1000000:.3f} us") |
| 215 | + |
| 216 | + |
| 217 | +if __name__ == "__main__": |
| 218 | + parser = argparse.ArgumentParser() |
| 219 | + parser.add_argument( |
| 220 | + "--kernel", |
| 221 | + type=str, |
| 222 | + choices=[ |
| 223 | + "ragged-paged-attention", |
| 224 | + "ragged-paged-attention-ref-impl", |
| 225 | + ], |
| 226 | + default="multi-queries-paged-attn") |
| 227 | + parser.add_argument("--profile", action="store_true") |
| 228 | + args = parser.parse_args() |
| 229 | + benchmark(args) |
0 commit comments