from dataclasses import dataclass
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN

from kaiju import KaijuTextConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import KaijuRMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig

# from vllm.model_executor.layers.activation import GeluAndMul
# from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
#                                                QKVParallelLinear,
#                                                RowParallelLinear)
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.rotary_embedding import get_rope
# from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
# from vllm.model_executor.layers.vocab_parallel_embedding import (
#     VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import (
#     default_weight_loader, maybe_remap_kv_scale_name)
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors

# from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

logger = init_logger(__name__)

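
# NOTE: `clamp` is used throughout this file but is not defined or imported here. The
# helper below is a minimal placeholder sketch assuming a plain symmetric clamp to
# [-limit, limit]; the real Kaiju helper (e.g. a quantization-aware clamp) may differ
# and should replace this.
def clamp(x: torch.Tensor, limit: float) -> torch.Tensor:
    return torch.clamp(x, min=-limit, max=limit)

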
class KaijuMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        rms_norm_eps: float,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.residual_scale = nn.Parameter(torch.zeros(self.hidden_size),
                                           requires_grad=False)
        self.pre_ffn_norm = KaijuRMSNorm(self.hidden_size, eps=rms_norm_eps)

        # TODO: Megatron-style TP (MergedColumnParallelLinear then RowParallelLinear).
        self.W_in = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.W_out = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # WARNING: Whippet checkpoints carry an `args["quantize"]["ffn_clamp_middle_output"]`
        # flag. It is only used in the backward pass in specific circumstances, so it is
        # ignored here.
        hidden_states = x
        # Pre-norm on the FFN branch; the residual path keeps the unnormalized input.
        x = self.pre_ffn_norm(x)
        x = self.W_in(x)
        x = clamp(x, 4)
        x = self.act_fn(x)
        x = self.W_out(x)
        hidden_states = hidden_states * self.residual_scale
        return x + hidden_states


@dataclass
class KaijuCache:
    key_states: Optional[torch.Tensor] = None
    value_states: Optional[torch.Tensor] = None

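
# NOTE: `apply_rotary_pos_emb_kaiju` is not defined in this file. The sketch below is an
# assumed stand-in modeled on the standard Hugging Face rotary-embedding application
# (rotate_half plus elementwise cos/sin); the actual Kaiju variant may differ.
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_kaiju(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
                               sin: torch.Tensor, unsqueeze_dim: int = 2):
    # Broadcast cos/sin over the head dimension, then rotate queries and keys.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

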
class KaijuAttention(nn.Module):

    def __init__(
        self,
        config: KaijuTextConfig,
        max_position_embeddings: int,
        is_context_encoder: bool,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attn_logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.is_context_encoder = is_context_encoder
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # The number of KV heads is at least the TP size, so we partition
            # the KV heads across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # The number of KV heads is smaller than the TP size, so we
            # replicate the KV heads across the tensor-parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = getattr(config, "head_dim",
                                config.hidden_size // config.num_attention_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # Residual-stream scale used in forward(); mirrors KaijuMLP.residual_scale.
        self.residual_scale = nn.Parameter(torch.zeros(self.hidden_size),
                                           requires_grad=False)

        # TODO: Combine into a single projection matrix and use QKVParallelLinear.
        self.q_proj = nn.Linear(self.hidden_size, self.q_size, bias=False)
        if not self.is_context_encoder:
            self.k_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False)
            self.v_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False)

        # TODO: Use RowParallelLinear.
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size,
                                bias=False)

        self.pre_projection_norm = KaijuRMSNorm(self.config.hidden_size,
                                                eps=config.rms_norm_eps)

        layer_idx = extract_layer_index(prefix)
        self.is_sliding = layer_idx not in self.config.global_attention_layer_schedule
        if self.is_sliding:
            self.sliding_window = 1024
        else:
            self.sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=self.sliding_window,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        hidden_states: torch.Tensor,
        kv_cache: Optional[KaijuCache] = None,
    ) -> torch.Tensor:
        processed_hidden_states = self.pre_projection_norm(hidden_states)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings
        query_states = self.q_proj(processed_hidden_states).view(hidden_shape)

        if self.is_context_encoder:
            # Context-encoder layers do not own K/V projections; they read the
            # shared K/V states produced by their KV-group leader.
            assert kv_cache is not None
            key_states = kv_cache.key_states
            value_states = kv_cache.value_states
        else:
            key_states = self.k_proj(processed_hidden_states).view(hidden_shape)
            value_states = self.v_proj(processed_hidden_states).view(hidden_shape)

            if kv_cache is not None:
                key_states = kv_cache.key_states
                value_states = kv_cache.value_states

        # We should probably cache the clamped values.
        query_states = clamp(query_states, 4)
        key_states = clamp(key_states, 4)
        value_states = clamp(value_states, 4)

        # Should we cache post-RoPE?
        query_states, key_states = apply_rotary_pos_emb_kaiju(
            query_states, key_states, cos, sin, unsqueeze_dim=2)

        # TODO: attention masking
        attn_output = self.attn(query_states, key_states, value_states)

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        hidden_states = hidden_states * self.residual_scale
        hidden_states = hidden_states + attn_output

        return hidden_states

class KaijuDecoderLayer(nn.Module):

    def __init__(
        self,
        config: KaijuTextConfig,
        is_context_encoder: bool,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = KaijuAttention(
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            is_context_encoder=is_context_encoder,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=None,
            prefix=f"{prefix}.self_attn",
        )

        self.mlp = KaijuMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            rms_norm_eps=config.rms_norm_eps,
        )

    def forward(
        self,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        kv_cache: Optional[KaijuCache] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self attention; the attention module handles the residual-stream update.
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            kv_cache=kv_cache,
        )

        # Fully connected
        hidden_states = self.mlp(hidden_states)

        outputs = (hidden_states,)
        # Returning attention weights is not needed for inference; a slow attention
        # implementation could be added later for debugging purposes.
        assert not output_attentions, "TODO: Support this"

        return outputs

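
# NOTE: `make_layers_with_idx` (used by KaijuModel below) is not part of
# vllm.model_executor.models.utils. The sketch here is an assumed thin wrapper around
# the imported `make_layers` that also hands the layer index to the factory callback.
def make_layers_with_idx(num_hidden_layers: int, layer_fn, prefix: str):
    # Recover the layer index from the prefix generated by make_layers ("<prefix>.<idx>").
    return make_layers(
        num_hidden_layers,
        lambda prefix: layer_fn(prefix, extract_layer_index(prefix)),
        prefix=prefix,
    )

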
@support_torch_compile
class KaijuModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Map each layer to the first layer of its KV-sharing group; layers that are
        # not their own group leader reuse that leader's K/V states.
        self.layer_to_kv_group = list(range(config.num_hidden_layers))
        for layers in config.share_kv_schedule:
            for layer_idx in layers:
                self.layer_to_kv_group[layer_idx] = min(layers)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Vocab parallel embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        # TODO: Get rid of this scale by "compiling" it into the embedding weights; when
        # we convert the lm_head/etc. we can just adjust that scale.
        self.embedding_scale = nn.Parameter(torch.FloatTensor([0]),
                                            requires_grad=False)

        self.start_layer, self.end_layer, self.layers = make_layers_with_idx(
            config.num_hidden_layers,
            lambda prefix, idx: KaijuDecoderLayer(
                config,
                is_context_encoder=idx != self.layer_to_kv_group[idx],
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )