use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
-use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
+use crate::models::nomic::{NomicBertEmbeddings, NomicMLP};
use crate::models::{Model, NomicConfig};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
@@ -25,16 +25,25 @@ impl NomicAttention {
        let attention_head_size = config.n_embd / config.n_head;
        let hidden_size = config.n_embd;

-        let qkv_weight = vb.pp("Wqkv").get(
-            (3 * num_attention_heads * attention_head_size, hidden_size),
-            "weight",
-        )?;
-        let qkv_linear = Linear::new(qkv_weight, None, None);
+        let qkv_dim = 3 * num_attention_heads * attention_head_size;
+
+        let qkv_weight = vb.pp("Wqkv").get((qkv_dim, hidden_size), "weight")?;
+        let qkv_bias = if config.qkv_proj_bias {
+            Some(vb.pp("Wqkv").get((qkv_dim,), "bias")?)
+        } else {
+            None
+        };
+        let qkv_linear = Linear::new(qkv_weight, qkv_bias, None);

        let out_proj_weight = vb
            .pp("out_proj")
            .get((hidden_size, hidden_size), "weight")?;
-        let out_proj = Linear::new(out_proj_weight, None, None);
+        let out_proj_bias = if config.qkv_proj_bias {
+            Some(vb.pp("out_proj").get((hidden_size,), "bias")?)
+        } else {
+            None
+        };
+        let out_proj = Linear::new(out_proj_weight, out_proj_bias, None);

        let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;

@@ -93,17 +102,18 @@ impl NomicAttention {

struct NomicBertBlock {
    attention: NomicAttention,
-    mlp: NomicBertGatedMLP,
+    mlp: NomicMLP,
    post_attention_layer_norm: LayerNorm,
    output_layer_norm: LayerNorm,

    span: tracing::Span,
}

impl NomicBertBlock {
-    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
+    pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result<Self> {
        let attention = NomicAttention::load(vb.pp("attn"), config)?;
-        let mlp = NomicBertGatedMLP::load(vb.pp("mlp"), config)?;
+
+        let mlp = NomicMLP::load(vb.pp("mlp"), index, config)?;

        let post_attention_layer_norm =
            LayerNorm::load(vb.pp("norm1"), config.n_embd, config.layer_norm_epsilon)?;
@@ -132,6 +142,7 @@ impl NomicBertBlock {
        let attn_output = self
            .attention
            .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
+
        let hidden_states = self
            .post_attention_layer_norm
            .forward(&hidden_states, Some(&attn_output))?;
@@ -145,13 +156,14 @@ impl NomicBertBlock {

struct NomicBertEncoder {
    layers: Vec<NomicBertBlock>,
+
    span: tracing::Span,
}

impl NomicBertEncoder {
    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
        let layers = (0..config.n_layer)
-            .map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), config))
+            .map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), index, config))
            .collect::<Result<Vec<_>>>()?;

        let span = tracing::span!(tracing::Level::TRACE, "encoder");
@@ -170,7 +182,6 @@ impl NomicBertEncoder {

        let mut hidden_states = hidden_states.clone();

-        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?
        }
@@ -419,6 +430,7 @@ impl Model for FlashNomicBertModel {
    fn is_padded(&self) -> bool {
        false
    }
+
    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
        self.forward(batch)
    }