import torch_xla


def _next_power_of_2_bit_manipulation(x):
  """
  Finds the smallest power of 2 >= x using bit manipulation.
  Assumes x is an integer.

  Args:
    x: The input number (should be an integer).

  Returns:
    The smallest integer power of 2 that is >= x.
    Returns 1 if x <= 0.
  """
  if x <= 0:
    return 1
  if x == 1:
    return 1
  return 1 << (x - 1).bit_length()
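

# Worked examples for the helper above (illustrative, not part of the original
# module):
#   _next_power_of_2_bit_manipulation(5)    -> 8
#   _next_power_of_2_bit_manipulation(2048) -> 2048
#   _next_power_of_2_bit_manipulation(3000) -> 4096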


# Tuned block sizes for ragged_paged_attention, keyed by workload shape:
#   key:   (q_head_num, kv_head_num, token_num, max_model_len)
#   value: (num_kv_pages_per_block, num_queries_per_block)


def _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
                                         max_model_len):
  # Round the shape-dependent dims up to powers of 2 so that similar workloads
  # map to the same table entry.
  token_num = _next_power_of_2_bit_manipulation(token_num)
  max_model_len = _next_power_of_2_bit_manipulation(max_model_len)
  return q_head_num, kv_head_num, token_num, max_model_len
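

# Illustrative key simplification (hypothetical values, not from the original
# source):
#   _simplify_key_ragged_paged_attention(32, 8, 3000, 2048)
#   # -> (32, 8, 4096, 2048): only token_num and max_model_len are rounded
#   # up to the next power of two.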


# TODO: add more tuned block sizes in the table
_ragged_attention_table = {
    (32, 8, 4096, 2048): (128, 64),
    (4, 1, 4096, 2048): (128, 128),
    (32, 8, 2048, 2048): (128, 32),
    (4, 1, 2048, 2048): (128, 64),
    (32, 8, 1024, 2048): (64, 32),
    (1, 1, 1024, 2048): (64, 32),
    (32, 8, 4096, 4096): (128, 64),
    (4, 1, 4096, 4096): (128, 128),
    (32, 8, 2048, 4096): (128, 32),
    (4, 1, 2048, 4096): (128, 64),
    (32, 8, 1024, 4096): (64, 32),
    (1, 1, 1024, 4096): (64, 32),
    (32, 8, 4096, 64): (32, 32),
    (4, 1, 4096, 64): (32, 32),
    (32, 8, 2048, 64): (32, 32),
    (4, 1, 2048, 64): (32, 32),
    (32, 8, 1024, 64): (32, 32),
    (1, 1, 1024, 64): (32, 32),
    (32, 8, 4096, 128): (32, 32),
    (4, 1, 4096, 128): (32, 32),
    (32, 8, 2048, 128): (32, 32),
    (4, 1, 2048, 128): (32, 32),
    (32, 8, 1024, 128): (32, 32),
    (1, 1, 1024, 128): (32, 32),
}


def get_ragged_attention_tuned_block_size(q_head_num, kv_head_num, token_num,
                                          max_model_len):
  """Returns a tuned (num_kv_pages_per_block, num_queries_per_block) pair."""
  tpu_version = torch_xla.tpu.version()
  if tpu_version < 4:
    raise NotImplementedError("TPU version must be 4 or higher.")
  if tpu_version == 4:
    # This default block size is not tuned; it is only chosen so that the
    # kernel does not run out of VMEM.
    num_kv_pages_per_block = 16
    num_queries_per_block = 128
    return num_kv_pages_per_block, num_queries_per_block

  key = _simplify_key_ragged_paged_attention(q_head_num, kv_head_num,
                                             token_num, max_model_len)
  block_sizes = _ragged_attention_table.get(key, (128, 32))
  return block_sizes
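

# Illustrative usage (a sketch, not part of the original module). It assumes a
# TPU v4+ environment where torch_xla.tpu.version() succeeds; the shapes below
# are hypothetical.
if __name__ == "__main__":
  # 32 query heads / 8 KV heads, roughly 3000 tokens in the batch, and a
  # 2048 max model length simplify to the table key (32, 8, 4096, 2048).
  kv_pages, queries = get_ragged_attention_tuned_block_size(
      q_head_num=32, kv_head_num=8, token_num=3000, max_model_len=2048)
  # On TPU v5 or newer this hits the tuned entry (128, 64); on TPU v4 the
  # untuned default (16, 128) is returned instead.
  print(f"num_kv_pages_per_block={kv_pages}, num_queries_per_block={queries}")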