
Commit 531f610

clean up code
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent 7219560 commit 531f610

2 files changed: +19 −93 lines changed


vllm/attention/backends/abstract.py

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ class AttentionType:
     ENCODER_ONLY = "encoder_only"
     # Attention between dec. Q and enc. K/V for encoder-decoder
     ENCODER_DECODER = "encoder_decoder"
-    DECODER_DECODER = "decoder_decoder"  # Attention layer that reuse kv cache
+    # Attention layer that reuse kv cache
+    DECODER_DECODER = "decoder_decoder"


 class AttentionBackend(ABC):

vllm/model_executor/models/phi3samba.py

Lines changed: 17 additions & 92 deletions
@@ -1,6 +1,5 @@
-from typing import List, Optional, Tuple, Union, Iterable, Dict
+from typing import List, Optional, Tuple, Union, Iterable
 import math
-import copy

 import torch
 import torch.nn as nn
@@ -17,7 +16,7 @@
                                                RowParallelLinear,
                                                ColumnParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -30,10 +29,7 @@
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.attention.backends.abstract import (AttentionMetadata, AttentionType)

 from vllm.logger import init_logger
 from .utils import (maybe_prefix, make_layers)
@@ -52,6 +48,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
         # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}")
         return x1 * nn.functional.silu(x2)

+
 class SambaMLP(nn.Module):
     """Gated Linear Unit.
@@ -77,34 +74,28 @@ def forward(self, hidden_states):
         return self.fc2(y)


-class SambaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
+def get_virtual_engine():
+    forward_context: ForwardContext = get_forward_context()
+    return forward_context.virtual_engine

+class SambaAttention(nn.Module):
     def __init__(self,
                  config,
                  layer_idx: Optional[int] = None,
                  yoco_cross: bool = False,
                  cache_config: Optional[CacheConfig] = None,
                  prefix: str = ""):
         super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
                 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
-
-        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.is_causal = True
         self.yoco_cross = yoco_cross

         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -120,8 +111,6 @@ def __init__(self,
         else:
             self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)

-        assert self.config.attention_dropout == 0.0, 'Attention dropout is not supported for now'
-
         # disable sliding window for the second half of the model
         sliding_window = config.interleaved_sliding_window[layer_idx]
         if layer_idx >= config.num_hidden_layers // 2 or layer_idx % 2 == 0:
@@ -161,9 +150,6 @@ def __init__(self,
             **params
         )

-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-
     def lambda_init_fn(self, depth):
         return 0.8 - 0.6 * math.exp(-0.3 * depth)

@@ -181,8 +167,9 @@ def forward(
             attn_output = self.attn(q, k, v)
         else:  # re-use the kv cache, full attention
             q = self.Wqkv(hidden_states)
-            if self.attn.kv_cache[0].numel() == 0:
-                self.attn.kv_cache = [kv_cache]
+            virtual_engine = get_virtual_engine()
+            if self.attn.kv_cache[virtual_engine].numel() == 0:
+                self.attn.kv_cache[virtual_engine] = kv_cache
             attn_output = self.attn(q, None, None)
         attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
         return self.out_proj(attn_output)
@@ -227,16 +214,6 @@ def __init__(
             self.in_proj = MergedColumnParallelLinear(self.d_model, [self.d_inner], bias=bias, **factory_kwargs)
             self.out_proj = RowParallelLinear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
             return
-        # self.conv1d = nn.Conv1d(
-        #     in_channels=self.d_inner,
-        #     out_channels=self.d_inner,
-        #     bias=conv_bias,
-        #     kernel_size=d_conv,
-        #     groups=self.d_inner,
-        #     padding=d_conv - 1,
-        #     **factory_kwargs,
-        # )
-
         self.conv1d = ColumnParallelLinear(
             input_size=d_conv,
             output_size=self.d_inner,
@@ -249,16 +226,12 @@ def __init__(
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        # self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
         self.in_proj = MergedColumnParallelLinear(self.d_model,
                                                   [self.d_inner] * 2,
                                                   bias=bias,
                                                   params_dtype=dtype,
                                                   )

-        # self.x_proj = nn.Linear(
-        #     self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
-        # )
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             self.d_inner,
@@ -267,7 +240,6 @@ def __init__(
             params_dtype=dtype,
         )

-        # self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
         # time step projection (discretization) -
         # In the forward we need to apply dt_proj without the bias,
         # as the bias is added in the selective scan kernel.
@@ -297,15 +269,13 @@ def __init__(
         ))
         self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))

-        # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
         self.out_proj = RowParallelLinear(
             self.d_inner,
             self.d_model,
             bias=bias,
             input_is_parallel=True,
             params_dtype=dtype,
         )
-        print(f"-------- layer_idx {layer_idx}")
         self.activation = "silu"

     def forward(
@@ -451,9 +421,6 @@ def __init__(self,
                                    yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
         else:
             self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
-
-        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
-        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(
@@ -488,21 +455,11 @@ def forward(
             kv_cache,
             attn_metadata,
         )
-        try:
-            hidden_states = residual + self.resid_attn_dropout(attn_outputs)
-        except Exception as e:
-            print('>>> exception: ', e)
-            print('>>>', hidden_states.shape)
-            print('>>>', self.layer_idx)
-            print('>>>', residual.shape)
-            print('>>>', self.resid_attn_dropout)
-            print('>>>', attn_outputs)
-            raise
-
+        hidden_states = residual + attn_outputs
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + self.resid_mlp_dropout(hidden_states)
+        hidden_states = residual + hidden_states

         return hidden_states, ssm_output

@@ -523,19 +480,14 @@ def __init__(
         prefix: str = ""
     ) -> None:
         super().__init__()
-
         self.config = config
-
-        self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-
-        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.embed_dropout = nn.Dropout(config.embd_pdrop)
+
         # Pipeline parallel is not supported since the second half of the layers share the kv cache.
         if get_pp_group().world_size != 1:
             raise ValueError("Pipeline Parallel not supported")
@@ -591,10 +543,6 @@ def forward(
                 hidden_states = hidden_states.index_select(0, selected_token_indices)
                 ssm_output = ssm_output.index_select(0, selected_token_indices)

-
-            # start_env = torch.cuda.Event(enable_timing=True)
-            # end_env = torch.cuda.Event(enable_timing=True)
-            # start_env.record()
             if layer.use_mamba:
                 if i < self.config.num_hidden_layers // 2:
                     mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx)
@@ -637,9 +585,6 @@ def forward(
                     None, # mamba_cache_params
                     ssm_output = ssm_output
                 )
-            # end_env.record()
-            # torch.cuda.synchronize()
-            # print('>>> layer', i, 'time', start_env.elapsed_time(end_env))

         hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype))
         return hidden_states
@@ -690,7 +635,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size,
                                                 logits_as_input=False)
-        # self.sampler = Sampler()
         self.sampler = get_sampler()

     def forward(
@@ -767,7 +711,6 @@ def load_weights(
         weights: Iterable[Tuple[str, torch.Tensor]],
     ):
         weights = {name: weight for name, weight in weights}
-        print(f"--------- num of keys: {len(weights.keys())}")
         adjusted_weights = {}
         for name, weight in weights.items():
             if "A_log" in name:
@@ -777,31 +720,13 @@ def load_weights(
                 name = name.replace("inner_cross_attn.", "")
             adjusted_weights[name] = weight
         adjusted_weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        for name, loaded_weight in adjusted_weights.items():
-            print(name, loaded_weight.shape)
-
-        params_dict = dict(self.named_parameters())
-
-        print(f"{adjusted_weights.keys() - params_dict.keys()} not in model")
-        print(f"{params_dict.keys() - adjusted_weights.keys()} not in weights")
-
         loaded_params: Set[str] = set()
-
         for name, param in self.named_parameters():
             weight = adjusted_weights.get(name, None)
             if weight is not None and weight.shape != param.shape:
-                print(f"Shape mismatch: {name} {weight.shape} {param.shape}")
+                logger.warning(f"Shape mismatch: {name} {weight.shape} {param.shape}")
             loaded_params.add(name)
         missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, strict=False)
-        print(f"--------------- missing keys {missing_keys}")
-        print("--------------- unexpected keys ---------------")
-        for key in unexpected_keys:
-            print(key)
-            if not key.endswith("bias"):
-                print("------- not bias -------")
-        # assert missing_keys == ['embedding_bias', 'lm_head.weight',], f"Missing keys: {missing_keys}"
-        # assert unexpected_keys == ['lm_head.bias',], f"Unexpected keys: {unexpected_keys}"
-        # self.lm_head.weight.data.copy_(adjusted_weights['model.embed_tokens.weight'])
-        # self.embedding_bias.data.copy_(adjusted_weights['lm_head.bias'])
-        # self.embedding_bias = None
+        assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
+        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
         return loaded_params
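
For reference, a minimal standalone sketch of the kv-cache binding pattern the diff introduces in SambaAttention.forward: the YOCO cross-attention layers reuse an existing kv cache, and the cache slot is now selected by the active virtual engine rather than the hard-coded index 0. Only get_virtual_engine and the slot-binding check come from the diff; the import path vllm.forward_context and the bind_shared_kv_cache helper name are assumptions for illustration, not part of the commit.

import torch

# Assumed import path for the forward-context helpers referenced in the diff.
from vllm.forward_context import ForwardContext, get_forward_context


def get_virtual_engine() -> int:
    # The forward context records which virtual engine is currently executing;
    # the attention layer keeps one kv-cache tensor per virtual engine.
    forward_context: ForwardContext = get_forward_context()
    return forward_context.virtual_engine


def bind_shared_kv_cache(attn, kv_cache: torch.Tensor) -> None:
    # Hypothetical helper: install the shared kv cache into the slot of the
    # active virtual engine, but only if that slot is still an empty tensor.
    virtual_engine = get_virtual_engine()
    if attn.kv_cache[virtual_engine].numel() == 0:
        attn.kv_cache[virtual_engine] = kv_cache

In the model code itself the same check is performed inline in SambaAttention.forward before calling self.attn(q, None, None).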
