Skip to content

Commit 750898d

Browse files
fix: add cls pooling as default for BERT variants (#426)
1 parent 205f96c commit 750898d

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

router/src/lib.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,20 @@ pub async fn run(
9191
// Optionally download the pooling config.
9292
if pooling.is_none() {
9393
// If a pooling config exist, download it
94-
let _ = download_pool_config(&api_repo).await;
94+
let _ = download_pool_config(&api_repo).await.map_err(|err| {
95+
tracing::warn!("Download failed: {err}");
96+
err
97+
});
9598
}
9699

97-
// Download sentence transformers config
100+
// Download legacy sentence transformers config
101+
// We don't warn on failure as it is a legacy file
98102
let _ = download_st_config(&api_repo).await;
99103
// Download new sentence transformers config
100-
let _ = download_new_st_config(&api_repo).await;
104+
let _ = download_new_st_config(&api_repo).await.map_err(|err| {
105+
tracing::warn!("Download failed: {err}");
106+
err
107+
});
101108

102109
// Download model from the Hub
103110
download_artifacts(&api_repo)
@@ -387,10 +394,21 @@ fn get_backend_model_type(
387394
None => {
388395
// Load pooling config
389396
let config_path = model_root.join("1_Pooling/config.json");
390-
let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
391-
let config: PoolConfig =
392-
serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?;
393-
Pool::try_from(config)?
397+
398+
match fs::read_to_string(config_path) {
399+
Ok(config) => {
400+
let config: PoolConfig = serde_json::from_str(&config)
401+
.context("Failed to parse `1_Pooling/config.json`")?;
402+
Pool::try_from(config)?
403+
}
404+
Err(err) => {
405+
if !config.model_type.to_lowercase().contains("bert") {
406+
return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.");
407+
}
408+
tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling.");
409+
text_embeddings_backend::Pool::Cls
410+
}
411+
}
394412
}
395413
};
396414
Ok(text_embeddings_backend::ModelType::Embedding(pool))

0 commit comments

Comments
 (0)