@@ -4,11 +4,13 @@ mod prometheus;
 
 #[cfg(feature = "http")]
 mod http;
+
 #[cfg(feature = "http")]
 use ::http::HeaderMap;
 
 #[cfg(feature = "grpc")]
 mod grpc;
+
 #[cfg(feature = "grpc")]
 use tonic::codegen::http::HeaderMap;
 
@@ -25,14 +27,14 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::DType;
-use text_embeddings_core::download::{download_artifacts, download_pool_config};
+use text_embeddings_core::download::{
+    download_artifacts, download_pool_config, download_st_config, ST_CONFIG_NAMES,
+};
 use text_embeddings_core::infer::Infer;
 use text_embeddings_core::queue::Queue;
 use text_embeddings_core::tokenization::Tokenization;
 use text_embeddings_core::TextEmbeddingsError;
-use tokenizers::decoders::metaspace::PrependScheme;
-use tokenizers::pre_tokenizers::sequence::Sequence;
-use tokenizers::{PreTokenizerWrapper, Tokenizer};
+use tokenizers::Tokenizer;
 use tracing::Span;
 
 pub use logging::init_logging;
@@ -83,6 +85,9 @@ pub async fn run(
         let _ = download_pool_config(&api_repo).await;
     }
 
+    // Download sentence transformers config
+    let _ = download_st_config(&api_repo).await;
+
     // Download model from the Hub
     download_artifacts(&api_repo)
         .await
@@ -178,7 +183,25 @@ pub async fn run(
     } else {
         0
     };
-    let max_input_length = config.max_position_embeddings - position_offset;
+
+    // Try to load ST Config
+    let mut st_config: Option<STConfig> = None;
+    for name in ST_CONFIG_NAMES {
+        let config_path = model_root.join(name);
+        if let Ok(config) = fs::read_to_string(config_path) {
+            st_config =
+                Some(serde_json::from_str(&config).context(format!("Failed to parse `{}`", name))?);
+            break;
+        }
+    }
+    let max_input_length = match st_config {
+        Some(config) => config.max_seq_length,
+        None => {
+            tracing::warn!("Could not find a Sentence Transformers config");
+            config.max_position_embeddings - position_offset
+        }
+    };
+    tracing::info!("Maximum number of tokens per request: {max_input_length}");
 
     let tokenization_workers = tokenization_workers.unwrap_or_else(num_cpus::get_physical);
 
@@ -311,6 +334,11 @@ pub struct PoolConfig {
     pooling_mode_mean_sqrt_len_tokens: bool,
 }
 
+#[derive(Debug, Deserialize)]
+pub struct STConfig {
+    pub max_seq_length: usize,
+}
+
 #[derive(Clone, Debug, Serialize)]
 #[cfg_attr(feature = "http", derive(utoipa::ToSchema))]
 pub struct EmbeddingModel {
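
For context, here is a minimal sketch of what the new `STConfig` parsing amounts to, assuming the model repository ships a Sentence Transformers config file such as `sentence_bert_config.json` (the actual file names probed come from `ST_CONFIG_NAMES`); the JSON contents below are illustrative, not taken from a specific model:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
pub struct STConfig {
    pub max_seq_length: usize,
}

fn main() -> Result<(), serde_json::Error> {
    // Typical shape of a Sentence Transformers config file (values are illustrative).
    let raw = r#"{ "max_seq_length": 256, "do_lower_case": false }"#;
    // serde ignores unknown fields such as `do_lower_case` by default,
    // so only `max_seq_length` needs to be declared on the struct.
    let st: STConfig = serde_json::from_str(raw)?;
    assert_eq!(st.max_seq_length, 256);
    println!("max_seq_length = {}", st.max_seq_length);
    Ok(())
}
```

When such a file is found, its `max_seq_length` takes precedence over `max_position_embeddings - position_offset` as the per-request token limit.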