@@ -1,6 +1,5 @@
-from typing import List, Optional, Tuple, Union, Iterable, Dict
+from typing import List, Optional, Tuple, Union, Iterable
 import math
-import copy

 import torch
 import torch.nn as nn
@@ -17,7 +16,7 @@
                               RowParallelLinear,
                               ColumnParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -30,10 +29,7 @@
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.attention.backends.abstract import (AttentionMetadata, AttentionType)

 from vllm.logger import init_logger
 from .utils import (maybe_prefix, make_layers)
@@ -52,6 +48,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
         # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}")
         return x1 * nn.functional.silu(x2)

+
 class SambaMLP(nn.Module):
     """Gated Linear Unit.

@@ -77,34 +74,28 @@ def forward(self, hidden_states):
         return self.fc2(y)


-class SambaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
+def get_virtual_engine():
+    forward_context: ForwardContext = get_forward_context()
+    return forward_context.virtual_engine

+class SambaAttention(nn.Module):
     def __init__(self,
                  config,
                  layer_idx: Optional[int] = None,
                  yoco_cross: bool = False,
                  cache_config: Optional[CacheConfig] = None,
                  prefix: str = ""):
         super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
                 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
-
-        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.is_causal = True
         self.yoco_cross = yoco_cross

         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -120,8 +111,6 @@ def __init__(self,
         else:
             self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)

-        assert self.config.attention_dropout == 0.0, 'Attention dropout is not supported for now'
-
         # disable sliding window for the second half of the model
         sliding_window = config.interleaved_sliding_window[layer_idx]
         if layer_idx >= config.num_hidden_layers // 2 or layer_idx % 2 == 0:
@@ -161,9 +150,6 @@ def __init__(self,
             **params
         )

-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-
     def lambda_init_fn(self, depth):
         return 0.8 - 0.6 * math.exp(-0.3 * depth)

@@ -181,8 +167,9 @@ def forward(
             attn_output = self.attn(q, k, v)
         else:  # re-use the kv cache, full attention
             q = self.Wqkv(hidden_states)
-            if self.attn.kv_cache[0].numel() == 0:
-                self.attn.kv_cache = [kv_cache]
+            virtual_engine = get_virtual_engine()
+            if self.attn.kv_cache[virtual_engine].numel() == 0:
+                self.attn.kv_cache[virtual_engine] = kv_cache
             attn_output = self.attn(q, None, None)
         attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
         return self.out_proj(attn_output)
@@ -227,16 +214,6 @@ def __init__(
             self.in_proj = MergedColumnParallelLinear(self.d_model, [self.d_inner], bias=bias, **factory_kwargs)
             self.out_proj = RowParallelLinear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
             return
-        # self.conv1d = nn.Conv1d(
-        #     in_channels=self.d_inner,
-        #     out_channels=self.d_inner,
-        #     bias=conv_bias,
-        #     kernel_size=d_conv,
-        #     groups=self.d_inner,
-        #     padding=d_conv - 1,
-        #     **factory_kwargs,
-        # )
-
         self.conv1d = ColumnParallelLinear(
             input_size=d_conv,
             output_size=self.d_inner,
@@ -249,16 +226,12 @@ def __init__(
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        # self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
         self.in_proj = MergedColumnParallelLinear(self.d_model,
                                                   [self.d_inner] * 2,
                                                   bias=bias,
                                                   params_dtype=dtype,
                                                   )

-        # self.x_proj = nn.Linear(
-        #     self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
-        # )
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             self.d_inner,
@@ -267,7 +240,6 @@ def __init__(
             params_dtype=dtype,
         )

-        # self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
         # time step projection (discretization) -
         # In the forward we need to apply dt_proj without the bias,
         # as the bias is added in the selective scan kernel.
@@ -297,15 +269,13 @@ def __init__(
         ))
         self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))

-        # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
         self.out_proj = RowParallelLinear(
             self.d_inner,
             self.d_model,
             bias=bias,
             input_is_parallel=True,
             params_dtype=dtype,
         )
-        print(f"-------- layer_idx {layer_idx}")
         self.activation = "silu"

     def forward(
@@ -451,9 +421,6 @@ def __init__(self,
                                    yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
         else:
             self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn")
-
-        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
-        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(
@@ -488,21 +455,11 @@ def forward(
                 kv_cache,
                 attn_metadata,
             )
-            try:
-                hidden_states = residual + self.resid_attn_dropout(attn_outputs)
-            except Exception as e:
-                print('>>> exception: ', e)
-                print('>>>', hidden_states.shape)
-                print('>>>', self.layer_idx)
-                print('>>>', residual.shape)
-                print('>>>', self.resid_attn_dropout)
-                print('>>>', attn_outputs)
-                raise
-
+            hidden_states = residual + attn_outputs
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + self.resid_mlp_dropout(hidden_states)
+        hidden_states = residual + hidden_states

         return hidden_states, ssm_output

@@ -523,19 +480,14 @@ def __init__(
         prefix: str = ""
     ) -> None:
         super().__init__()
-
         self.config = config
-
-        self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-
-        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.embed_dropout = nn.Dropout(config.embd_pdrop)
+
         # Pipeline parallel is not supported since the second half of the layers share the kv cache.
         if get_pp_group().world_size != 1:
             raise ValueError("Pipeline Parallel not supported")
@@ -591,10 +543,6 @@ def forward(
                 hidden_states = hidden_states.index_select(0, selected_token_indices)
                 ssm_output = ssm_output.index_select(0, selected_token_indices)

-
-            # start_env = torch.cuda.Event(enable_timing=True)
-            # end_env = torch.cuda.Event(enable_timing=True)
-            # start_env.record()
             if layer.use_mamba:
                 if i < self.config.num_hidden_layers // 2:
                     mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx)
@@ -637,9 +585,6 @@ def forward(
                     None,  # mamba_cache_params
                     ssm_output=ssm_output
                 )
-            # end_env.record()
-            # torch.cuda.synchronize()
-            # print('>>> layer', i, 'time', start_env.elapsed_time(end_env))

         hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype))
         return hidden_states
@@ -690,7 +635,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size,
                                                 logits_as_input=False)
-        # self.sampler = Sampler()
         self.sampler = get_sampler()

     def forward(
@@ -767,7 +711,6 @@ def load_weights(
         weights: Iterable[Tuple[str, torch.Tensor]],
     ):
         weights = {name: weight for name, weight in weights}
-        print(f"--------- num of keys: {len(weights.keys())}")
         adjusted_weights = {}
         for name, weight in weights.items():
             if "A_log" in name:
@@ -777,31 +720,13 @@ def load_weights(
             name = name.replace("inner_cross_attn.", "")
             adjusted_weights[name] = weight
         adjusted_weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        for name, loaded_weight in adjusted_weights.items():
-            print(name, loaded_weight.shape)
-
-        params_dict = dict(self.named_parameters())
-
-        print(f"{adjusted_weights.keys() - params_dict.keys()} not in model")
-        print(f"{params_dict.keys() - adjusted_weights.keys()} not in weights")
-
         loaded_params: Set[str] = set()
-
         for name, param in self.named_parameters():
             weight = adjusted_weights.get(name, None)
             if weight is not None and weight.shape != param.shape:
-                print(f"Shape mismatch: {name} {weight.shape} {param.shape}")
+                logger.warning(f"Shape mismatch: {name} {weight.shape} {param.shape}")
             loaded_params.add(name)
         missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, strict=False)
-        print(f"--------------- missing keys {missing_keys}")
-        print("--------------- unexpected keys ---------------")
-        for key in unexpected_keys:
-            print(key)
-            if not key.endswith("bias"):
-                print("------- not bias -------")
-        # assert missing_keys == ['embedding_bias', 'lm_head.weight',], f"Missing keys: {missing_keys}"
-        # assert unexpected_keys == ['lm_head.bias',], f"Unexpected keys: {unexpected_keys}"
-        # self.lm_head.weight.data.copy_(adjusted_weights['model.embed_tokens.weight'])
-        # self.embedding_bias.data.copy_(adjusted_weights['lm_head.bias'])
-        # self.embedding_bias = None
+        assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
+        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
         return loaded_params