
Commit 697c279

yaochengji and mgoin authored
[TPU] support attention head dim smaller than 128 (vllm-project#19620)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent 987870c commit 697c279

File tree

2 files changed: +65 −7 lines changed

tests/v1/tpu/test_basic.py
vllm/v1/attention/backends/pallas.py

tests/v1/tpu/test_basic.py

Lines changed: 37 additions & 0 deletions
@@ -67,6 +67,43 @@ def test_basic(
     assert "1024" in output or "0, 1" in output
 
 
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("max_num_seqs", [16])
+def test_phi3(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    max_tokens: int,
+    max_num_seqs: int,
+) -> None:
+    prompts = [
+        "A robot may not injure a human being",
+        "It is only with the heart that one can see rightly;",
+        "The greatest glory in living lies not in never falling,",
+    ]
+    answers = [
+        " or, by violating privacy",
+        " what is essential is love.",
+        " but in rising every time we fall.",
+    ]
+    # test head dim = 96
+    model = "microsoft/Phi-3-mini-128k-instruct"
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(model,
+                         max_num_batched_tokens=256,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
+        # vllm_outputs is a list of tuples whose first element is the token id
+        # and the second element is the output (including the prompt).
+        for output, answer in zip(vllm_outputs, answers):
+            generated_text = output[1]
+            assert answer in generated_text
+
+
 TP_SIZE_8 = 8
 
 
vllm/v1/attention/backends/pallas.py

Lines changed: 28 additions & 7 deletions
@@ -17,6 +17,9 @@
 
 logger = init_logger(__name__)
 
+# TPU requires the head size to be a multiple of 128.
+TPU_HEAD_SIZE_ALIGNMENT = 128
+
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -43,6 +46,14 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
+        padded_head_size = cdiv(
+            head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+        num_blocks = num_blocks * head_size // padded_head_size
+        if padded_head_size != head_size:
+            logger.warning_once(
+                "head size is padded to %d, and num_blocks is adjusted to %d"
+                " accordingly", padded_head_size, num_blocks)
+        head_size = padded_head_size
         return (num_blocks, block_size, num_kv_heads * 2, head_size)
 
     @staticmethod
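
To make the new shape logic concrete, here is a small worked sketch of the padding arithmetic in get_kv_cache_shape (cdiv is assumed to be vLLM's ceiling-division helper; the block count is illustrative, not taken from the commit):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division: cdiv(96, 128) == 1, cdiv(256, 128) == 2.
        return -(a // -b)

    TPU_HEAD_SIZE_ALIGNMENT = 128
    head_size = 96       # e.g. Phi-3-mini's head dim, exercised by the new test
    num_blocks = 1024    # hypothetical block budget before padding

    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    num_blocks = num_blocks * head_size // padded_head_size
    print(padded_head_size, num_blocks)  # 128 768

Scaling num_blocks down by head_size / padded_head_size keeps the total KV-cache footprint roughly what it would have been without padding, and the warning_once call makes the adjustment visible in the logs.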
@@ -132,8 +143,6 @@ def __init__(
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if kv_cache_dtype != "auto":
@@ -187,6 +196,18 @@ def forward(
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            padded_head_size = cdiv(
+                self.head_size,
+                TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            query = torch.nn.functional.pad(
+                query, (0, padded_head_size - self.head_size), value=0.0)
+            key = torch.nn.functional.pad(
+                key, (0, padded_head_size - self.head_size), value=0.0)
+            value = torch.nn.functional.pad(
+                value, (0, padded_head_size - self.head_size), value=0.0)
 
         if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
             # Write input keys and values to the KV cache.
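
For readers less familiar with torch.nn.functional.pad: a two-element pad tuple pads only the last dimension, which here is the head dimension, so only trailing head-dim columns are zero-filled. A minimal sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    # (num_tokens, num_heads, head_size) with head_size = 96; illustrative values only
    query = torch.randn(4, 8, 96)
    padded = F.pad(query, (0, 128 - 96), value=0.0)
    print(padded.shape)  # torch.Size([4, 8, 128]); columns 96..127 are zeros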
@@ -213,6 +234,9 @@
             soft_cap=self.logits_soft_cap,
         )
 
+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            output = output[:, :, :self.head_size]
+
         return output.reshape(num_tokens, hidden_size)
 
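Slicing the output back to self.head_size loses nothing: the zero columns appended to queries and keys contribute nothing to any attention score, and the zero columns appended to values can only produce zero output columns, which the slice drops. A quick sketch of the first point (shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    q = torch.randn(96)
    k = torch.randn(96)
    q_pad = F.pad(q, (0, 32))   # 96 -> 128 with trailing zeros
    k_pad = F.pad(k, (0, 32))
    assert torch.allclose(q @ k, q_pad @ k_pad)  # padded dims add 0 to the dot product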

@@ -231,11 +255,8 @@ def write_to_kv_cache(
     """
     _, _, num_combined_kv_heads, head_size = kv_cache.shape
-    num_kv_heads = num_combined_kv_heads // 2
-
-    key = key.view(-1, num_kv_heads, head_size)
-    value = value.view(-1, num_kv_heads, head_size)
-
+    head_size = cdiv(head_size,
+                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
 