
Commit 7219560

clean up code
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent fcd59c3 commit 7219560

File tree

1 file changed: 0 additions & 232 deletions


vllm/model_executor/models/phi3samba.py

Lines changed: 0 additions & 232 deletions
@@ -167,161 +167,6 @@ def __init__(self,
     def lambda_init_fn(self, depth):
         return 0.8 - 0.6 * math.exp(-0.3 * depth)
 
-
-    def split_heads(self, x):
-        # split by num_heads, the stripe pattern is friendly to tensor parallel.
-        x = rearrange(x, "... (H two) D -> ... H two D", two=2)
-        x1 = x[..., 0, :]
-        x2 = x[..., 1, :]
-        return x1.contiguous(), x2.contiguous()
-
-    def split_kv_cache(self, x):
-        # split by num_heads, the stripe pattern is friendly to tensor parallel.
-        if x.numel() == 0:
-            return torch.empty(0), torch.empty(0)
-
-        x1, x2 = x[0], x[1]
-        return x1, x2
-
-    def forward_decode(
-        self,
-        query: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-    ):
-        if not attn_metadata.decode_metadata:
-            block_tables_arg = attn_metadata.cross_layer_shared_block_tables
-        else:
-            block_tables_arg = attn_metadata.block_tables
-
-        output = flash_attn_with_kvcache(
-            q=query.unsqueeze(1),
-            k_cache=k_cache,
-            v_cache=v_cache,
-            block_table=block_tables_arg,
-            cache_seqlens=attn_metadata.seq_lens_tensor,
-            softmax_scale=self.attn.impl.scale,
-            causal=True,
-            window_size=self.attn.impl.sliding_window,
-            alibi_slopes=self.attn.impl.alibi_slopes,
-            softcap=self.attn.impl.logits_soft_cap,
-        ).squeeze(1)
-        return output
-
-    def populate_kv_cache(self,
-                          key,
-                          value,
-                          kv_cache,
-                          attn_metadata):
-        if (kv_cache.numel() > 0):
-            if (key is not None) and (value is not None):
-                updated_slot_mapping = attn_metadata.slot_mapping
-                # previous_key_cache_sum = key_cache.sum()
-                # previous_value_cache_sum = value_cache.sum()
-
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    kv_cache[0],
-                    kv_cache[1],
-                    updated_slot_mapping.flatten(),
-                    self.attn.impl.kv_cache_dtype,
-                    self._k_scale,
-                    self._v_scale,
-                )
-                # assert key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch"
-                # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch"
-                # if key_cache.sum() - previous_key_cache_sum != key.sum():
-                #     print("key_cache sum mismatch")
-                # if value_cache.sum() - previous_value_cache_sum != value.sum():
-                #     print("value_cache sum mismatch")
-
-    def forward_customized(
-        self,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor],
-        value: Optional[torch.Tensor],
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata
-    ) -> torch.Tensor:
-
-        head_size = self.head_dim
-        num_heads = self.num_heads // 2
-        num_kv_heads = self.num_key_value_heads // 2
-
-        query = query.view(-1, num_heads, head_size)
-        if key is not None:
-            assert value is not None
-            key = key.view(-1, num_kv_heads, head_size)
-            value = value.view(-1, num_kv_heads, head_size)
-        else:
-            assert value is None
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch"
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch"
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        if key is not None and value is not None:
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
-        assert decode_query.shape[0] == num_decode_tokens, "decode query shape mismatch"
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
-                # normal attention
-                prefill_output = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    softmax_scale=self.attn.impl.scale,
-                    causal=True,
-                    window_size=self.attn.impl.sliding_window,
-                    alibi_slopes=self.attn.impl.alibi_slopes,
-                    softcap=self.attn.impl.logits_soft_cap,
-                )
-                assert prefill_output.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = prefill_output
-            else:
-                raise Exception("prefix caching not supported")
-
-        if decode_meta := attn_metadata.decode_metadata:
-            block_tables_arg = decode_meta.block_tables
-            try:
-                output[num_prefill_tokens:] = flash_attn_with_kvcache(
-                    q=decode_query.unsqueeze(1),
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    block_table=block_tables_arg,
-                    cache_seqlens=decode_meta.seq_lens_tensor,
-                    softmax_scale=self.attn.impl.scale,
-                    causal=True,
-                    window_size=self.attn.impl.sliding_window,
-                    alibi_slopes=self.attn.impl.alibi_slopes,
-                    softcap=self.attn.impl.logits_soft_cap,
-                ).squeeze(1)
-            except Exception as e:
-                logger.error(
-                    f"Error in PagedAttention.forward_decode: {str(e)}")
-                raise e
-
-        # Reshape the output tensor.
-        return output.view(-1, num_heads, head_size)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
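The hunk above removes the unused forward_customized helper, which split the flattened token batch into prefill and decode segments, ran flash_attn_varlen_func over the prefill tokens and flash_attn_with_kvcache over the decode tokens, and wrote both results into one output buffer. Below is a minimal, standalone sketch of that routing only; torch.nn.functional.scaled_dot_product_attention stands in for the FlashAttention kernels, a single sequence with matching query/key head counts is assumed, and none of the names here are the vLLM API.

import torch
import torch.nn.functional as F

def route_prefill_and_decode(query: torch.Tensor,
                             key: torch.Tensor,
                             value: torch.Tensor,
                             num_prefill_tokens: int,
                             num_decode_tokens: int) -> torch.Tensor:
    # query/key/value: [num_tokens, num_heads, head_size], prefill tokens first.
    assert query.shape[0] == num_prefill_tokens + num_decode_tokens
    output = torch.empty_like(query)

    # Prefill tokens: causal attention over the prompt
    # (stand-in for flash_attn_varlen_func).
    q_p = query[:num_prefill_tokens].transpose(0, 1)  # [H, Tp, D]
    k_p = key[:num_prefill_tokens].transpose(0, 1)
    v_p = value[:num_prefill_tokens].transpose(0, 1)
    output[:num_prefill_tokens] = F.scaled_dot_product_attention(
        q_p, k_p, v_p, is_causal=True).transpose(0, 1)

    # Decode tokens: each new query attends to everything seen so far
    # (stand-in for flash_attn_with_kvcache; the real kernel restricts each
    # query to its own sequence via block tables and cache_seqlens).
    q_d = query[num_prefill_tokens:].transpose(0, 1)  # [H, Td, D]
    k_all = key.transpose(0, 1)
    v_all = value.transpose(0, 1)
    output[num_prefill_tokens:] = F.scaled_dot_product_attention(
        q_d, k_all, v_all).transpose(0, 1)
    return output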
@@ -333,86 +178,9 @@ def forward(
         if not self.yoco_cross: # need to generate kv-cache
             qkv = self.Wqkv(hidden_states)
             q, k, v = qkv.split([self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], dim=-1)
-            reference_attn_output = self.attn(q, k, v)
-            # # q, k = self.rotary_emb(positions, q, k)
-            # # reshape
-            # q = q.view(-1, self.num_heads, self.head_dim)
-            # k = k.view(-1, self.num_key_value_heads, self.head_dim)
-            # v = v.view(-1, self.num_key_value_heads, self.head_dim)
-
-            # q1, q2 = self.split_heads(q)
-            # k1, k2 = self.split_heads(k)
-            # v1, v2 = self.split_heads(v)
-
-            # # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size)
-            # # Split by half along the first dimension.
-            # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
-            # assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
-            # assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"
-
-            # if kv_cache1.numel() != 0:
-            #     self.populate_kv_cache(k1, v1, kv_cache1, attn_metadata)
-            #     self.populate_kv_cache(k2, v2, kv_cache2, attn_metadata)
-
-            #     key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
-            #     key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
-            # else:
-            #     key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
-            #     key_cache2, value_cache2 = torch.empty(0), torch.empty(0)
-            # attn11 = self.forward_customized(q1, k1, v1, key_cache1, value_cache1, attn_metadata)
-            # attn12 = self.forward_customized(q1, k1, v2, key_cache1, value_cache2, attn_metadata)
-            # attn11 = attn11.view(q1.shape)
-            # attn12 = attn12.view(q1.shape)
-            # attn1 = torch.cat([attn11, attn12], dim=-1)
-
-            # attn21 = self.forward_customized(q2, k2, v1, key_cache2, value_cache1, attn_metadata)
-            # attn22 = self.forward_customized(q2, k2, v2, key_cache2, value_cache2, attn_metadata)
-            # attn21 = attn21.view(q2.shape)
-            # attn22 = attn22.view(q2.shape)
-            # attn2 = torch.cat([attn21, attn22], dim=-1)
-
-            # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
-            # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
-            # lambda_full = lambda_1 - lambda_2 + self.lambda_init
-
-            # attn = attn1 - lambda_full * attn2
-            # # attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
-            # attn = self.subln(attn)
-            # attn = attn * (1 - self.lambda_init)
-            # # reshape back to 2 * num_head
-            # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
             attn_output = self.attn(q, k, v)
         else: # re-use the kv cache, full attention
             q = self.Wqkv(hidden_states)
-            # q = q.view(-1, self.num_heads, self.head_dim)
-            # q1, q2 = self.split_heads(q)
-            # # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size)
-            # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
-            # key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
-            # key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
-
-            # attn11 = self.forward_decode(q1, key_cache1, value_cache1, attn_metadata)
-            # attn12 = self.forward_decode(q1, key_cache1, value_cache2, attn_metadata)
-            # attn11 = attn11.view(q1.shape)
-            # attn12 = attn12.view(q1.shape)
-            # attn1 = torch.cat([attn11, attn12], dim=-1)
-
-            # attn21 = self.forward_decode(q2, key_cache2, value_cache1, attn_metadata)
-            # attn22 = self.forward_decode(q2, key_cache2, value_cache2, attn_metadata)
-            # attn21 = attn21.view(q2.shape)
-            # attn22 = attn22.view(q2.shape)
-            # attn2 = torch.cat([attn21, attn22], dim=-1)
-
-            # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
-            # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
-            # lambda_full = lambda_1 - lambda_2 + self.lambda_init
-            # attn = attn1 - lambda_full * attn2
-            # attn = self.subln(attn)
-            # attn = attn * (1 - self.lambda_init)
-            # # reshape back to 2 * num_head
-            # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
-
-
             if self.attn.kv_cache[0].numel() == 0:
                 self.attn.kv_cache = [kv_cache]
             attn_output = self.attn(q, None, None)
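The commented-out blocks removed from forward() describe a differential-attention path: Q/K/V are split into two head groups, attention is taken against each value group, and the two maps are recombined as attn1 - lambda_full * attn2, normalized by subln, and scaled by (1 - lambda_init). The sketch below reconstructs only that recombination step from the deleted comments; the LayerNorm stand-in for self.subln, the demo tensor sizes, and the name differential_recombine are illustrative assumptions, not the path this commit keeps (the kept path simply calls self.attn).

import math

import torch
from einops import rearrange

def lambda_init_fn(depth: int) -> float:
    # Same schedule as the lambda_init_fn the commit keeps.
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

def differential_recombine(attn1: torch.Tensor,
                           attn2: torch.Tensor,
                           lambda_q1: torch.Tensor,
                           lambda_k1: torch.Tensor,
                           lambda_q2: torch.Tensor,
                           lambda_k2: torch.Tensor,
                           lambda_init: float,
                           subln: torch.nn.Module) -> torch.Tensor:
    # attn1/attn2: [num_tokens, num_heads // 2, 2 * head_dim], each the concat
    # of two half-head attention outputs, as in the commented-out code.
    lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(attn1)
    lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(attn1)
    lambda_full = lambda_1 - lambda_2 + lambda_init

    attn = attn1 - lambda_full * attn2
    attn = subln(attn)              # sublayer norm over the doubled head_dim
    attn = attn * (1 - lambda_init)
    # Reshape back to the full head count: (H, two*D) -> (H*two, D).
    return rearrange(attn, "... H (two D) -> ... (H two) D", two=2)

if __name__ == "__main__":
    T, H, D = 4, 8, 64                   # tokens, num_heads // 2, head_dim (illustrative)
    subln = torch.nn.LayerNorm(2 * D)    # stand-in for self.subln
    out = differential_recombine(
        torch.randn(T, H, 2 * D), torch.randn(T, H, 2 * D),
        torch.randn(D), torch.randn(D), torch.randn(D), torch.randn(D),
        lambda_init_fn(depth=3), subln)
    print(out.shape)                     # torch.Size([4, 16, 64])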

0 commit comments
