Skip to content

Commit b645118

Browse files
Narsil and alvarobartt
authored
Fix VarBuilder handling in GTE e.g. gte-multilingual-reranker-base (#538)
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent b4dc0da commit b645118

File tree

5 files changed

+3170
-22
lines changed

5 files changed

+3170
-22
lines changed

backends/candle/src/models/gte.rs

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,42 @@ impl GTEModel {
408408
ModelType::Embedding(pool) => (pool, None),
409409
};
410410

411+
let (word_embeddings, token_type_embeddings, encoder, embeddings_norm) =
412+
Self::inner_load(vb.pp("new"), config)
413+
.or_else(|_| Self::inner_load(vb.clone(), config))?;
414+
415+
let rotary_dim = encoder.layers[0].attention.attention_head_size;
416+
let inv_freqs = get_inv_freqs(
417+
rotary_dim,
418+
config.rope_theta,
419+
vb.device(),
420+
config.rope_scaling.as_ref(),
421+
)?;
422+
423+
let rotary_cache =
424+
get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
425+
426+
Ok(Self {
427+
word_embeddings,
428+
token_type_embeddings,
429+
encoder,
430+
embeddings_norm,
431+
rotary_cache,
432+
classifier,
433+
pool,
434+
num_attention_heads: config.num_attention_heads,
435+
device: vb.device().clone(),
436+
dtype: vb.dtype(),
437+
span: tracing::span!(tracing::Level::TRACE, "model"),
438+
rotary_dim,
439+
})
440+
}
441+
442+
fn inner_load(
443+
vb: VarBuilder,
444+
config: &GTEConfig,
445+
) -> Result<(Embedding, Option<Embedding>, GTEEncoder, LayerNorm)> {
446+
411447
let word_embeddings = Embedding::new(
412448
vb.pp("embeddings.word_embeddings")
413449
.get((config.vocab_size, config.hidden_size), "weight")?,
@@ -431,32 +467,12 @@ impl GTEModel {
431467
config.hidden_size,
432468
config.layer_norm_eps,
433469
)?;
434-
435-
let rotary_dim = encoder.layers[0].attention.attention_head_size;
436-
let inv_freqs = get_inv_freqs(
437-
rotary_dim,
438-
config.rope_theta,
439-
vb.device(),
440-
config.rope_scaling.as_ref(),
441-
)?;
442-
443-
let rotary_cache =
444-
get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
445-
446-
Ok(Self {
470+
Ok((
447471
word_embeddings,
448472
token_type_embeddings,
449473
encoder,
450474
embeddings_norm,
451-
rotary_cache,
452-
classifier,
453-
pool,
454-
num_attention_heads: config.num_attention_heads,
455-
device: vb.device().clone(),
456-
dtype: vb.dtype(),
457-
span: tracing::span!(tracing::Level::TRACE, "model"),
458-
rotary_dim,
459-
})
475+
))
460476
}
461477

462478
pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {

0 commit comments

Comments
 (0)