1
1
use std:: collections:: HashMap ;
2
2
3
3
use crate :: flash_attn:: flash_attn_varlen;
4
- use crate :: layers:: { apply_rotary , get_cos_sin, get_inv_freqs, LayerNorm , Linear } ;
4
+ use crate :: layers:: { get_cos_sin, get_inv_freqs, LayerNormNoBias , Linear } ;
5
5
use crate :: models:: modernbert:: {
6
6
ClassificationHead , ModernBertClassificationHead , ModernBertConfig , ModernBertEmbeddings ,
7
7
ModernBertMLP ,
8
8
} ;
9
9
use crate :: models:: Model ;
10
10
use candle:: { DType , Device , IndexOp , Result , Tensor } ;
11
11
use candle_nn:: VarBuilder ;
12
+ use candle_rotary:: apply_rotary_inplace;
12
13
use text_embeddings_backend_core:: { Batch , ModelType , Pool } ;
13
14
14
15
struct ModernBertAttention {
@@ -79,35 +80,34 @@ impl ModernBertAttention {
79
80
new_qkv_shape. pop ( ) ;
80
81
new_qkv_shape. push ( self . num_attention_heads * 3 ) ;
81
82
new_qkv_shape. push ( self . attention_head_size ) ;
82
- let qkv = qkv. reshape ( new_qkv_shape. as_slice ( ) ) ?. transpose ( 1 , 2 ) ? ;
83
+ let qkv = qkv. reshape ( new_qkv_shape. as_slice ( ) ) ?;
83
84
84
- let qkv = qkv. chunk ( 3 , 1 ) ? ;
85
- let query_layer = & qkv[ 0 ] . contiguous ( ) ?;
86
- let key_layer = & qkv[ 1 ] . contiguous ( ) ?;
87
- let value_layer = & qkv[ 2 ] ;
85
+ // Split qkv tensor
86
+ let q = qkv. narrow ( 1 , 0 , self . num_attention_heads ) ?;
87
+ let k = qkv. narrow ( 1 , self . num_attention_heads , self . num_attention_heads ) ?;
88
+ let v = qkv. narrow ( 1 , self . num_attention_heads * 2 , self . num_attention_heads ) ? ;
88
89
89
- let query_layer = apply_rotary ( query_layer, cos, sin, self . attention_head_size ) ?;
90
- let key_layer = apply_rotary ( key_layer, cos, sin, self . attention_head_size ) ?;
90
+ apply_rotary_inplace ( & q, & k, & cos, & sin, true ) ?;
91
91
92
- let attention_size = if self . use_local_attention {
92
+ let window_size = if self . use_local_attention {
93
93
Some ( self . local_attention )
94
94
} else {
95
95
None
96
96
} ;
97
97
98
98
let attention = flash_attn_varlen (
99
- & query_layer ,
100
- & key_layer ,
101
- & value_layer ,
99
+ & q ,
100
+ & k ,
101
+ & v ,
102
102
None ,
103
103
cu_seqlens,
104
104
cu_seqlens,
105
105
max_s,
106
106
max_s,
107
107
self . softmax_scale ,
108
108
false ,
109
- attention_size ,
110
- attention_size ,
109
+ window_size ,
110
+ window_size ,
111
111
) ?;
112
112
let attention = attention. flatten_from ( candle:: D :: Minus2 ) ?;
113
113
@@ -118,9 +118,9 @@ impl ModernBertAttention {
118
118
}
119
119
120
120
struct ModernBertEncoderLayer {
121
- attn_norm : Option < LayerNorm > ,
121
+ attn_norm : Option < LayerNormNoBias > ,
122
122
attn : ModernBertAttention ,
123
- mlp_norm : LayerNorm ,
123
+ mlp_norm : LayerNormNoBias ,
124
124
mlp : ModernBertMLP ,
125
125
126
126
span : tracing:: Span ,
@@ -129,7 +129,7 @@ struct ModernBertEncoderLayer {
129
129
impl ModernBertEncoderLayer {
130
130
pub fn load ( vb : VarBuilder , index : usize , config : & ModernBertConfig ) -> Result < Self > {
131
131
let attn_norm = if index != 0 {
132
- Some ( LayerNorm :: load (
132
+ Some ( LayerNormNoBias :: load (
133
133
vb. pp ( "attn_norm" ) ,
134
134
config. hidden_size ,
135
135
config. norm_eps as f32 ,
@@ -140,7 +140,7 @@ impl ModernBertEncoderLayer {
140
140
141
141
let attn = ModernBertAttention :: load ( vb. pp ( "attn" ) , index, config) ?;
142
142
143
- let mlp_norm = LayerNorm :: load (
143
+ let mlp_norm = LayerNormNoBias :: load (
144
144
vb. pp ( "mlp_norm" ) ,
145
145
config. hidden_size ,
146
146
config. norm_eps as f32 ,
@@ -236,11 +236,10 @@ impl ModernBertEncoder {
236
236
pub struct FlashModernBertModel {
237
237
embeddings : ModernBertEmbeddings ,
238
238
encoder : ModernBertEncoder ,
239
- final_norm : LayerNorm ,
239
+ final_norm : LayerNormNoBias ,
240
240
pool : Pool ,
241
241
classifier : Option < Box < dyn ClassificationHead + Send > > ,
242
242
243
- rotary_dim : usize ,
244
243
rotary_cache : HashMap < bool , ( Tensor , Tensor ) > ,
245
244
246
245
device : Device ,
@@ -277,13 +276,22 @@ impl FlashModernBertModel {
277
276
}
278
277
} ;
279
278
280
- let embeddings = ModernBertEmbeddings :: load ( vb. pp ( "model.embeddings" ) , config) ?;
281
- let encoder = ModernBertEncoder :: load ( vb. pp ( "model.layers" ) , config) ?;
282
- let final_norm = LayerNorm :: load (
279
+ let embeddings = ModernBertEmbeddings :: load ( vb. pp ( "model.embeddings" ) , config)
280
+ . or_else ( |_| ModernBertEmbeddings :: load ( vb. pp ( "embeddings" ) , config) ) ?;
281
+ let encoder = ModernBertEncoder :: load ( vb. pp ( "model.layers" ) , config)
282
+ . or_else ( |_| ModernBertEncoder :: load ( vb. pp ( "layers" ) , config) ) ?;
283
+ let final_norm = LayerNormNoBias :: load (
283
284
vb. pp ( "model.final_norm" ) ,
284
285
config. hidden_size ,
285
286
config. norm_eps as f32 ,
286
- ) ?;
287
+ )
288
+ . or_else ( |_| {
289
+ LayerNormNoBias :: load (
290
+ vb. pp ( "final_norm" ) ,
291
+ config. hidden_size ,
292
+ config. norm_eps as f32 ,
293
+ )
294
+ } ) ?;
287
295
288
296
let rotary_dim = config. hidden_size / config. num_attention_heads ;
289
297
let mut rotary_cache: HashMap < bool , ( Tensor , Tensor ) > = HashMap :: new ( ) ;
@@ -295,15 +303,11 @@ impl FlashModernBertModel {
295
303
config. global_rope_theta
296
304
} ;
297
305
298
- let max_position_embeddings = if use_local_attention {
299
- config. max_position_embeddings
300
- } else {
301
- config. local_attention
302
- } ;
306
+ let max_position_embeddings = config. max_position_embeddings ;
303
307
304
308
let inv_freqs = get_inv_freqs ( rotary_dim, rope_theta as f32 , vb. device ( ) , None ) ?;
305
309
306
- let ( cos, sin) = get_cos_sin ( max_position_embeddings, & inv_freqs, vb. dtype ( ) , true ) ?;
310
+ let ( cos, sin) = get_cos_sin ( max_position_embeddings, & inv_freqs, vb. dtype ( ) , false ) ?;
307
311
308
312
rotary_cache. insert ( use_local_attention, ( cos, sin) ) ;
309
313
}
@@ -314,7 +318,6 @@ impl FlashModernBertModel {
314
318
final_norm,
315
319
pool,
316
320
classifier,
317
- rotary_dim,
318
321
rotary_cache,
319
322
device : vb. device ( ) . clone ( ) ,
320
323
span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
@@ -343,9 +346,6 @@ impl FlashModernBertModel {
343
346
let cos = cos. index_select ( & position_ids, 0 ) ?;
344
347
let sin = sin. index_select ( & position_ids, 0 ) ?;
345
348
346
- let cos = cos. reshape ( ( batch_size, 1 , max_length, self . rotary_dim ) ) ?;
347
- let sin = sin. reshape ( ( batch_size, 1 , max_length, self . rotary_dim ) ) ?;
348
-
349
349
rotary_cache. insert ( use_local_attention, ( cos, sin) ) ;
350
350
}
351
351
0 commit comments