@@ -408,6 +408,42 @@ impl GTEModel {
408
408
ModelType :: Embedding ( pool) => ( pool, None ) ,
409
409
} ;
410
410
411
+ let ( word_embeddings, token_type_embeddings, encoder, embeddings_norm) =
412
+ Self :: inner_load ( vb. pp ( "new" ) , config)
413
+ . or_else ( |_| Self :: inner_load ( vb. clone ( ) , config) ) ?;
414
+
415
+ let rotary_dim = encoder. layers [ 0 ] . attention . attention_head_size ;
416
+ let inv_freqs = get_inv_freqs (
417
+ rotary_dim,
418
+ config. rope_theta ,
419
+ vb. device ( ) ,
420
+ config. rope_scaling . as_ref ( ) ,
421
+ ) ?;
422
+
423
+ let rotary_cache =
424
+ get_cos_sin ( config. max_position_embeddings , & inv_freqs, vb. dtype ( ) , true ) ?;
425
+
426
+ Ok ( Self {
427
+ word_embeddings,
428
+ token_type_embeddings,
429
+ encoder,
430
+ embeddings_norm,
431
+ rotary_cache,
432
+ classifier,
433
+ pool,
434
+ num_attention_heads : config. num_attention_heads ,
435
+ device : vb. device ( ) . clone ( ) ,
436
+ dtype : vb. dtype ( ) ,
437
+ span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
438
+ rotary_dim,
439
+ } )
440
+ }
441
+
442
+ fn inner_load (
443
+ vb : VarBuilder ,
444
+ config : & GTEConfig ,
445
+ ) -> Result < ( Embedding , Option < Embedding > , GTEEncoder , LayerNorm ) > {
446
+
411
447
let word_embeddings = Embedding :: new (
412
448
vb. pp ( "embeddings.word_embeddings" )
413
449
. get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?,
@@ -431,32 +467,12 @@ impl GTEModel {
431
467
config. hidden_size ,
432
468
config. layer_norm_eps ,
433
469
) ?;
434
-
435
- let rotary_dim = encoder. layers [ 0 ] . attention . attention_head_size ;
436
- let inv_freqs = get_inv_freqs (
437
- rotary_dim,
438
- config. rope_theta ,
439
- vb. device ( ) ,
440
- config. rope_scaling . as_ref ( ) ,
441
- ) ?;
442
-
443
- let rotary_cache =
444
- get_cos_sin ( config. max_position_embeddings , & inv_freqs, vb. dtype ( ) , true ) ?;
445
-
446
- Ok ( Self {
470
+ Ok ( (
447
471
word_embeddings,
448
472
token_type_embeddings,
449
473
encoder,
450
474
embeddings_norm,
451
- rotary_cache,
452
- classifier,
453
- pool,
454
- num_attention_heads : config. num_attention_heads ,
455
- device : vb. device ( ) . clone ( ) ,
456
- dtype : vb. dtype ( ) ,
457
- span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
458
- rotary_dim,
459
- } )
475
+ ) )
460
476
}
461
477
462
478
pub fn forward ( & self , batch : Batch ) -> Result < ( Option < Tensor > , Option < Tensor > ) > {
0 commit comments