
Commit 7af5bd6

feat: add --pooling arg (#14)
1 parent 63cc6d4 · commit 7af5bd6

File tree

17 files changed: +257 -163 lines


Cargo.lock

Lines changed: 11 additions & 10 deletions
Some generated files are not rendered by default.

README.md

Lines changed: 27 additions & 7 deletions
@@ -102,18 +102,22 @@ Usage: text-embeddings-router [OPTIONS]
 
 Options:
       --model-id <MODEL_ID>
-          The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `thenlper/gte-base`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers
+          The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `thenlper/gte-base`.
+          Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of
+          transformers
 
           [env: MODEL_ID=]
           [default: thenlper/gte-base]
 
       --revision <REVISION>
-          The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2`
+          The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id
+          or a branch like `refs/pr/2`
 
           [env: REVISION=]
 
       --tokenization-workers <TOKENIZATION_WORKERS>
-          Optionally control the number of tokenizer workers used for payload tokenization, validation and truncation. Default to the number of CPU cores on the machine
+          Optionally control the number of tokenizer workers used for payload tokenization, validation and truncation.
+          Default to the number of CPU cores on the machine
 
           [env: TOKENIZATION_WORKERS=]
 
@@ -124,8 +128,21 @@
           [default: float16]
           [possible values: float16, float32]
 
+      --pooling <POOLING>
+          Optionally control the pooling method.
+
+          If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
+          configuration.
+
+          If `pooling` is set, it will override the model pooling configuration
+
+          [env: POOLING=]
+          [possible values: cls, mean]
+
       --max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
-          The maximum amount of concurrent requests for this particular deployment. Having a low limit will refuse clients requests instead of having them wait for too long and is usually good to handle backpressure correctly
+          The maximum amount of concurrent requests for this particular deployment.
+          Having a low limit will refuse clients requests instead of having them wait for too long and is usually good
+          to handle backpressure correctly
 
           [env: MAX_CONCURRENT_REQUESTS=]
           [default: 512]
@@ -137,7 +154,8 @@ Options:
 
           For `max_batch_tokens=1000`, you could fit `10` queries of `total_tokens=100` or a single query of `1000` tokens.
 
-          Overall this number should be the largest possible until the model is compute bound. Since the actual memory overhead depends on the model implementation, text-embeddings-inference cannot infer this number automatically.
+          Overall this number should be the largest possible until the model is compute bound. Since the actual memory
+          overhead depends on the model implementation, text-embeddings-inference cannot infer this number automatically.
 
           [env: MAX_BATCH_TOKENS=]
           [default: 16384]
@@ -171,13 +189,15 @@ Options:
           [default: 3000]
 
       --uds-path <UDS_PATH>
-          The name of the unix socket some text-embeddings-inference backends will use as they communicate internally with gRPC
+          The name of the unix socket some text-embeddings-inference backends will use as they communicate internally
+          with gRPC
 
           [env: UDS_PATH=]
           [default: /tmp/text-embeddings-inference-server]
 
       --huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>
-          The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance
+          The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk
+          for instance
 
           [env: HUGGINGFACE_HUB_CACHE=/data]
 
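A quick illustration of the new option (hypothetical command lines; the flag names, values, and default model id are taken from the help text above):

    # Override the model's pooling configuration with mean pooling
    text-embeddings-router --model-id thenlper/gte-base --pooling mean

    # Same override through the documented environment variable
    POOLING=cls text-embeddings-router --model-id thenlper/gte-base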

backends/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ tokio = { version = "^1.25", features = ["sync"] }
 tracing = "^0.1"
 
 [features]
-clap = [ "dep:clap" ]
+clap = ["dep:clap", "text-embeddings-backend-core/clap"]
 python = ["dep:text-embeddings-backend-python"]
 candle = ["dep:text-embeddings-backend-candle"]
 mkl = ["text-embeddings-backend-candle?/mkl"]
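Forwarding `clap` to `text-embeddings-backend-core` suggests the shared `Pool` type gates a `clap::ValueEnum` derive behind that feature, which is what lets `--pooling` parse straight into the enum. A speculative sketch (the derive list and doc comments are assumptions; only the `cls`/`mean` values are confirmed by the README help text):

    // Hypothetical declaration in text-embeddings-backend-core.
    // The ValueEnum derive compiles only with the `clap` feature enabled,
    // hence the forwarded feature in backends/Cargo.toml above.
    #[derive(Debug, Clone, Copy, PartialEq)]
    #[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
    pub enum Pool {
        /// Select the CLS token embedding
        Cls,
        /// Average all token embeddings
        Mean,
    }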

backends/candle/src/lib.rs

Lines changed: 16 additions & 27 deletions
@@ -9,29 +9,23 @@ use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_
 use crate::models::{BertModel, EmbeddingModel, QuantBertModel};
 use candle::{DType, Device};
 use candle_nn::VarBuilder;
-use models::{Config, PoolConfig};
+use models::Config;
 use std::path::PathBuf;
-use text_embeddings_backend_core::{BackendError, Batch, Embedding, EmbeddingBackend};
+use text_embeddings_backend_core::{BackendError, Batch, Embedding, EmbeddingBackend, Pool};
 
 pub struct CandleBackend {
     model: Box<dyn EmbeddingModel + Send>,
     device: Device,
 }
 
 impl CandleBackend {
-    pub fn new(model_path: PathBuf, dtype: String) -> Result<Self, BackendError> {
+    pub fn new(model_path: PathBuf, dtype: String, pool: Pool) -> Result<Self, BackendError> {
         // Load config
         let config: String = std::fs::read_to_string(model_path.join("config.json"))
             .map_err(|err| BackendError::Start(err.to_string()))?;
         let config: Config =
             serde_json::from_str(&config).map_err(|err| BackendError::Start(err.to_string()))?;
 
-        // Load pooling config
-        let pool_config: String = std::fs::read_to_string(model_path.join("1_Pooling/config.json"))
-            .map_err(|err| BackendError::Start(err.to_string()))?;
-        let pool_config: PoolConfig = serde_json::from_str(&pool_config)
-            .map_err(|err| BackendError::Start(err.to_string()))?;
-
         // Get candle device
         let device = match Device::cuda_if_available(0) {
             Ok(device) => device,
@@ -71,23 +65,17 @@
             } else {
                 VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
             }
-            .map_err(|err| BackendError::Start(err.to_string()))?;
+            .s()?;
 
-            Box::new(
-                BertModel::load(vb, &config, pool_config.into())
-                    .map_err(|err| BackendError::Start(err.to_string()))?,
-            )
+            Box::new(BertModel::load(vb, &config, pool).s()?)
         } else if &dtype == "q6k" {
             let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
                 model_path.join("ggml-model-q6k.bin"),
             )
             .map_err(|err| BackendError::Start(err.to_string()))?;
             tracing::info!("vb");
 
-            Box::new(
-                QuantBertModel::load(vb, &config, pool_config.into())
-                    .map_err(|err| BackendError::Start(err.to_string()))?,
-            )
+            Box::new(QuantBertModel::load(vb, &config, pool).s()?)
         } else {
             return Err(BackendError::Start(format!(
                 "dtype {dtype} is not supported"
@@ -126,17 +114,14 @@
             } else {
                 VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
             }
-            .map_err(|err| BackendError::Start(err.to_string()))?;
+            .s()?;
 
             if incompatible_compute_cap() {
                 return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP)));
             }
 
             tracing::info!("Starting FlashBert model on Cuda");
-            Box::new(
-                FlashBertModel::load(vb, &config, pool_config.into())
-                    .map_err(|err| BackendError::Start(err.to_string()))?,
-            )
+            Box::new(FlashBertModel::load(vb, &config, pool).s()?)
         }
     }
 };
@@ -151,8 +136,8 @@ impl EmbeddingBackend for CandleBackend {
     }
 
     fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError> {
-        let results = self.model.embed(batch).w()?;
-        let results = results.to_dtype(DType::F32).w()?.to_vec2().w()?;
+        let results = self.model.embed(batch).e()?;
+        let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
         Ok(results)
     }
 
@@ -165,11 +150,15 @@
 }
 
 pub trait WrapErr<O> {
-    fn w(self) -> Result<O, BackendError>;
+    fn s(self) -> Result<O, BackendError>;
+    fn e(self) -> Result<O, BackendError>;
 }
 
 impl<O> WrapErr<O> for Result<O, candle::Error> {
-    fn w(self) -> Result<O, BackendError> {
+    fn s(self) -> Result<O, BackendError> {
+        self.map_err(|e| BackendError::Start(e.to_string()))
+    }
+    fn e(self) -> Result<O, BackendError> {
        self.map_err(|e| BackendError::Inference(e.to_string()))
    }
}
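With pooling selection hoisted out of the backend, `CandleBackend::new` now takes the resolved `Pool` as a third argument instead of reading `1_Pooling/config.json` itself. A minimal sketch of a call site (the path and pool value are illustrative; the router resolves them from `--model-id` and `--pooling`, falling back to the model's pooling config):

    use std::path::PathBuf;
    use text_embeddings_backend_core::Pool;

    // Hypothetical caller of the new signature.
    let backend = CandleBackend::new(
        PathBuf::from("/data/thenlper--gte-base"), // local model directory
        "float16".to_string(),                     // dtype
        Pool::Mean,                                // pooling resolved by the router
    )?;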

backends/candle/src/models.rs

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ extern crate accelerate_src;
 mod bert;
 mod bert_quant;
 
-pub use bert::{BertModel, Config, PoolConfig};
+pub use bert::{BertModel, Config};
 pub use bert_quant::QuantBertModel;
 use candle::{Result, Tensor};
 use text_embeddings_backend_core::Batch;

backends/candle/src/models/bert.rs

Lines changed: 19 additions & 43 deletions
@@ -4,7 +4,7 @@ use candle_nn::ops::softmax;
 use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
-use text_embeddings_backend_core::Batch;
+use text_embeddings_backend_core::{Batch, Pool};
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
 #[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -37,43 +37,13 @@
     Relu,
 }
 
-#[derive(Debug, PartialEq)]
-pub enum Pool {
-    Cls,
-    Mean,
-    Max,
-    MeanSqrt,
-}
-
-impl From<PoolConfig> for Pool {
-    fn from(value: PoolConfig) -> Self {
-        if value.pooling_mode_cls_token {
-            Pool::Cls
-        } else if value.pooling_mode_mean_tokens {
-            Pool::Mean
-        } else if value.pooling_mode_max_tokens {
-            Pool::Max
-        } else {
-            Pool::MeanSqrt
-        }
-    }
-}
-
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
 
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct PoolConfig {
-    pooling_mode_cls_token: bool,
-    pooling_mode_mean_tokens: bool,
-    pooling_mode_max_tokens: bool,
-    pooling_mode_mean_sqrt_len_tokens: bool,
-}
-
 #[derive(Debug)]
 struct LayerNorm {
     weight: Tensor,
@@ -85,8 +55,12 @@ struct LayerNorm {
 impl LayerNorm {
     pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         Ok(Self {
-            weight: vb.get(config.hidden_size, "weight")?,
-            bias: vb.get(config.hidden_size, "bias")?,
+            weight: vb
+                .get(config.hidden_size, "weight")
+                .or_else(|_| vb.get(config.hidden_size, "gamma"))?,
+            bias: vb
+                .get(config.hidden_size, "bias")
+                .or_else(|_| vb.get(config.hidden_size, "beta"))?,
             epsilon: config.layer_norm_eps,
             span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
         })
@@ -435,15 +409,18 @@
         ) {
             (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
             (Err(err), _) | (_, Err(err)) => {
-                if let Some(model_type) = &config.model_type {
-                    if let (Ok(embeddings), Ok(encoder)) = (
-                        BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
-                        BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
-                    ) {
-                        (embeddings, encoder)
-                    } else {
-                        return Err(err);
-                    }
+                let model_type = config.model_type.clone().unwrap_or("bert".to_string());
+
+                if let (Ok(embeddings), Ok(encoder)) = (
+                    BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
+                    BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
+                ) {
+                    (embeddings, encoder)
+                } else if let (Ok(embeddings), Ok(encoder)) = (
+                    BertEmbeddings::load(vb.pp("bert.embeddings"), config),
+                    BertEncoder::load(vb.pp("bert.encoder"), config),
+                ) {
+                    (embeddings, encoder)
                 } else {
                     return Err(err);
                 }
@@ -484,7 +461,6 @@
             Pool::Cls => outputs.i(0..1)?,
             // Mean pooling
             Pool::Mean => (outputs.sum_keepdim(0)? / (batch.max_length as f64))?,
-            _ => candle::bail!("Pool type {:?} is not supported", self.pool),
         };
 
         // Normalize
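With `Max` and `MeanSqrt` gone, the pooling match above is exhaustive over `Cls` and `Mean`: CLS keeps the first token's hidden state, mean averages over the sequence. A self-contained candle sketch of the same two reductions on dummy data (not the crate's code, just the arithmetic it performs):

    use candle::{Device, IndexOp, Tensor};

    fn main() -> candle::Result<()> {
        // Dummy [seq_len = 4, hidden = 3] encoder output.
        let outputs = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((4, 3))?;

        // CLS pooling: slice out the first token, shape [1, 3].
        let cls = outputs.i(0..1)?;

        // Mean pooling: sum over the sequence dim, divide by its length.
        let mean = (outputs.sum_keepdim(0)? / 4f64)?;

        println!("cls: {cls}\nmean: {mean}");
        Ok(())
    }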
