Skip to content

Commit 69714ef

Browse files
authored
Fusing both Gte Configs. (#530)
1 parent e50e195 commit 69714ef

File tree

3 files changed

+3
-6
lines changed

3 files changed

+3
-6
lines changed

backends/candle/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -59,9 +59,8 @@ enum Config {
5959
NomicBert(NomicConfig),
6060
#[allow(dead_code)]
6161
Mistral(MistralConfig),
62+
#[serde(alias = "new")]
6263
Gte(GTEConfig),
63-
#[serde(rename = "new")]
64-
GteAlibaba(GTEConfig),
6564
#[allow(dead_code)]
6665
Qwen2(Qwen2Config),
6766
#[serde(rename = "mpnet")]
@@ -224,7 +223,7 @@ impl CandleBackend {
224223
"Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
225224
.to_string(),
226225
)),
227-
(Config::Gte(config) | Config::GteAlibaba(config), Device::Cpu | Device::Metal(_)) => {
226+
(Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
228227
tracing::info!("Starting GTE model on {:?}", device);
229228
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
230229
}
@@ -355,7 +354,7 @@ impl CandleBackend {
355354
))
356355
}
357356
#[cfg(feature = "cuda")]
358-
(Config::Gte(config) | Config::GteAlibaba(config), Device::Cuda(_)) => {
357+
(Config::Gte(config), Device::Cuda(_)) => {
359358
if dtype != DType::F16
360359
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
361360
{

backends/candle/tests/snapshots/test_gte__snowflake_gte_batch.snap

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2307,4 +2307,3 @@ expression: embeddings_batch
23072307
- -0.16524515
23082308
- -0.100704014
23092309
- 0.3677737
2310-

backends/candle/tests/snapshots/test_gte__snowflake_gte_single.snap

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -771,4 +771,3 @@ expression: embeddings_single
771771
- -0.16524515
772772
- -0.100704014
773773
- 0.3677737
774-

0 commit comments

Comments
 (0)