Skip to content

Commit 9ca35a8

Browse files
fix: use st max_seq_length (#167)
1 parent 00a17ea commit 9ca35a8

File tree

4 files changed

+63
-7
lines changed

4 files changed

+63
-7
lines changed

core/src/download.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@ use hf_hub::api::tokio::{ApiError, ApiRepo};
22
use std::path::PathBuf;
33
use tracing::instrument;
44

5+
/// File names under which a Sentence Transformers config may be stored.
///
/// The first entry, `sentence_bert_config.json`, is the default name. The
/// remaining entries are the names used by older classes and are kept so
/// that their configs can still be found.
pub const ST_CONFIG_NAMES: [&str; 7] = [
    "sentence_bert_config.json",
    "sentence_roberta_config.json",
    "sentence_distilbert_config.json",
    "sentence_camembert_config.json",
    "sentence_albert_config.json",
    "sentence_xlm-roberta_config.json",
    "sentence_xlnet_config.json",
];
15+
516
#[instrument(skip_all)]
617
pub async fn download_artifacts(api: &ApiRepo) -> Result<PathBuf, ApiError> {
718
let start = std::time::Instant::now();
@@ -32,3 +43,20 @@ pub async fn download_pool_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
3243
let pool_config_path = api.get("1_Pooling/config.json").await?;
3344
Ok(pool_config_path)
3445
}
46+
47+
#[instrument(skip_all)]
48+
pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
49+
// Try default path
50+
let err = match api.get(ST_CONFIG_NAMES[0]).await {
51+
Ok(st_config_path) => return Ok(st_config_path),
52+
Err(err) => err,
53+
};
54+
55+
for name in &ST_CONFIG_NAMES[1..] {
56+
if let Ok(st_config_path) = api.get(name).await {
57+
return Ok(st_config_path);
58+
}
59+
}
60+
61+
Err(err)
62+
}

router/src/http/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ use crate::{
99
shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
1010
ResponseMetadata,
1111
};
12+
use ::http::HeaderMap;
1213
use anyhow::Context;
1314
use axum::extract::Extension;
1415
use axum::http::HeaderValue;
15-
use ::http::HeaderMap;
1616
use axum::http::{Method, StatusCode};
1717
use axum::routing::{get, post};
1818
use axum::{http, Json, Router};

router/src/lib.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ mod prometheus;
44

55
#[cfg(feature = "http")]
66
mod http;
7+
78
#[cfg(feature = "http")]
89
use ::http::HeaderMap;
910

1011
#[cfg(feature = "grpc")]
1112
mod grpc;
13+
1214
#[cfg(feature = "grpc")]
1315
use tonic::codegen::http::HeaderMap;
1416

@@ -25,14 +27,14 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
2527
use std::path::Path;
2628
use std::time::{Duration, Instant};
2729
use text_embeddings_backend::DType;
28-
use text_embeddings_core::download::{download_artifacts, download_pool_config};
30+
use text_embeddings_core::download::{
31+
download_artifacts, download_pool_config, download_st_config, ST_CONFIG_NAMES,
32+
};
2933
use text_embeddings_core::infer::Infer;
3034
use text_embeddings_core::queue::Queue;
3135
use text_embeddings_core::tokenization::Tokenization;
3236
use text_embeddings_core::TextEmbeddingsError;
33-
use tokenizers::decoders::metaspace::PrependScheme;
34-
use tokenizers::pre_tokenizers::sequence::Sequence;
35-
use tokenizers::{PreTokenizerWrapper, Tokenizer};
37+
use tokenizers::Tokenizer;
3638
use tracing::Span;
3739

3840
pub use logging::init_logging;
@@ -83,6 +85,9 @@ pub async fn run(
8385
let _ = download_pool_config(&api_repo).await;
8486
}
8587

88+
// Download sentence transformers config
89+
let _ = download_st_config(&api_repo).await;
90+
8691
// Download model from the Hub
8792
download_artifacts(&api_repo)
8893
.await
@@ -178,7 +183,25 @@ pub async fn run(
178183
} else {
179184
0
180185
};
181-
let max_input_length = config.max_position_embeddings - position_offset;
186+
187+
// Try to load ST Config
188+
let mut st_config: Option<STConfig> = None;
189+
for name in ST_CONFIG_NAMES {
190+
let config_path = model_root.join(name);
191+
if let Ok(config) = fs::read_to_string(config_path) {
192+
st_config =
193+
Some(serde_json::from_str(&config).context(format!("Failed to parse `{}`", name))?);
194+
break;
195+
}
196+
}
197+
let max_input_length = match st_config {
198+
Some(config) => config.max_seq_length,
199+
None => {
200+
tracing::warn!("Could not find a Sentence Transformers config");
201+
config.max_position_embeddings - position_offset
202+
}
203+
};
204+
tracing::info!("Maximum number of tokens per request: {max_input_length}");
182205

183206
let tokenization_workers = tokenization_workers.unwrap_or_else(num_cpus::get_physical);
184207

@@ -311,6 +334,11 @@ pub struct PoolConfig {
311334
pooling_mode_mean_sqrt_len_tokens: bool,
312335
}
313336

337+
#[derive(Debug, Deserialize)]
338+
pub struct STConfig {
339+
pub max_seq_length: usize,
340+
}
341+
314342
#[derive(Clone, Debug, Serialize)]
315343
#[cfg_attr(feature = "http", derive(utoipa::ToSchema))]
316344
pub struct EmbeddingModel {

router/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use anyhow::Result;
22
use clap::Parser;
3+
use mimalloc::MiMalloc;
34
use opentelemetry::global;
45
use text_embeddings_backend::DType;
56
use veil::Redact;
6-
use mimalloc::MiMalloc;
77

88
#[global_allocator]
99
static GLOBAL: MiMalloc = MiMalloc;

0 commit comments

Comments
 (0)