
Commit 9aa020e

feat: splade pooling (#174)
1 parent 337fbd6 commit 9aa020e

File tree: 19 files changed, +2448 / -348 lines changed


README.md

Lines changed: 6 additions & 3 deletions

@@ -152,13 +152,16 @@ Options:
      --pooling <POOLING>
          Optionally control the pooling method for embedding models.

-         If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
-         configuration.
+         If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` configuration.

          If `pooling` is set, it will override the model pooling configuration

          [env: POOLING=]
-         [possible values: cls, mean]
+
+         Possible values:
+         - cls:    Select the CLS token as embedding
+         - mean:   Apply Mean pooling to the model embeddings
+         - splade: Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. This option is only available if the loaded model is a `ForMaskedLM` Transformer model

      --max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
          The maximum amount of concurrent requests for this particular deployment.
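For reference, the three documented pooling choices map naturally onto a clap `ValueEnum`. The sketch below is hypothetical (the real enum lives elsewhere in the repository and its exact name and derives may differ), but it mirrors the help text above:

```rust
use clap::ValueEnum;

/// Hypothetical mirror of the documented `--pooling` values.
#[derive(Clone, Debug, ValueEnum)]
enum Pooling {
    /// Select the CLS token as embedding
    Cls,
    /// Apply Mean pooling to the model embeddings
    Mean,
    /// Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings;
    /// only available when the loaded model is a `ForMaskedLM` Transformer model
    Splade,
}
```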

backends/candle/src/layers/layer_norm.rs

Lines changed: 20 additions & 11 deletions

@@ -23,12 +23,15 @@ impl LayerNorm {
        })
    }

-    pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result<Tensor> {
+    pub fn forward(&self, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();

        match hidden_states.device() {
            Device::Cpu | Device::Metal(_) => {
-                let hidden_states = hidden_states.add(residual)?;
+                let mut hidden_states = hidden_states.clone();
+                if let Some(residual) = residual {
+                    hidden_states = hidden_states.add(residual)?;
+                }
                let hidden_states_dtype = hidden_states.dtype();
                let internal_dtype = match hidden_states_dtype {
                    DType::F16 | DType::BF16 => DType::F32,
@@ -51,19 +54,25 @@ impl LayerNorm {
            Device::Cuda(_) => {
                #[cfg(feature = "cuda")]
                {
-                    use candle_layer_norm::fused_add_layer_norm;
+                    use candle_layer_norm::{fused_add_layer_norm, layer_norm};

                    let original_shape = hidden_states.shape();
                    let hidden_states = hidden_states.flatten_to(D::Minus2)?;
-                    let residual = residual.flatten_to(D::Minus2)?;

-                    let (result, _) = fused_add_layer_norm(
-                        &hidden_states,
-                        &residual,
-                        &self.weight,
-                        Some(&self.bias),
-                        self.epsilon,
-                    )?;
+                    let result = if let Some(residual) = residual {
+                        let residual = residual.flatten_to(D::Minus2)?;
+
+                        let (result, _) = fused_add_layer_norm(
+                            &hidden_states,
+                            &residual,
+                            &self.weight,
+                            Some(&self.bias),
+                            self.epsilon,
+                        )?;
+                        Ok(result)
+                    } else {
+                        layer_norm(&hidden_states, &self.weight, Some(&self.bias), self.epsilon)
+                    }?;
                    result.reshape(original_shape)
                }
                #[cfg(not(feature = "cuda"))]
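In plain candle terms, the change makes the residual add opt-in: callers with a residual get the fused add-then-normalize behaviour, callers without one get an ordinary layer norm. A minimal sketch of the same pattern, using `candle_nn::LayerNorm` as a stand-in for the repository's custom layer:

```rust
use candle::{Result, Tensor};
use candle_nn::{LayerNorm, Module};

/// Layer norm with an optional residual add, mirroring the new signature.
fn forward(ln: &LayerNorm, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result<Tensor> {
    let hidden_states = match residual {
        Some(residual) => hidden_states.add(residual)?, // add-then-normalize path
        None => hidden_states.clone(),                  // plain layer norm path
    };
    ln.forward(&hidden_states)
}
```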

backends/candle/src/lib.rs

Lines changed: 33 additions & 2 deletions

@@ -11,10 +11,13 @@ use crate::compute_cap::{
    get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap,
};
use crate::models::{
-    BertModel, JinaBertModel, Model, NomicBertModel, NomicConfig, PositionEmbeddingType,
+    BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel,
+    NomicConfig, PositionEmbeddingType,
};
#[cfg(feature = "cuda")]
-use crate::models::{FlashBertModel, FlashJinaBertModel, FlashNomicBertModel};
+use crate::models::{
+    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel,
+};
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::BertConfig;
@@ -33,6 +36,8 @@ enum Config {
    XlmRoberta(BertConfig),
    Camembert(BertConfig),
    Roberta(BertConfig),
+    #[serde(rename(deserialize = "distilbert"))]
+    DistilBert(DistilBertConfig),
    #[serde(rename(deserialize = "nomic_bert"))]
    NomicBert(NomicConfig),
}
@@ -119,6 +124,12 @@ impl CandleBackend {
                BertModel::load_roberta(vb, &config, model_type).s()?,
            ))
        }
+        (Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
+            tracing::info!("Starting DistilBertModel model on {:?}", device);
+            Ok(Box::new(
+                DistilBertModel::load(vb, &config, model_type).s()?,
+            ))
+        }
        (Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
            tracing::info!("Starting NomicBertModel model on {:?}", device);
            Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
@@ -175,6 +186,26 @@ impl CandleBackend {
            }
        }
        #[cfg(feature = "cuda")]
+        (Config::DistilBert(config), Device::Cuda(_)) => {
+            if cfg!(feature = "flash-attn")
+                && dtype == DType::F16
+                && &std::env::var("USE_FLASH_ATTENTION")
+                    .unwrap_or("True".to_string())
+                    .to_lowercase()
+                    == "true"
+            {
+                tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
+                Ok(Box::new(
+                    FlashDistilBertModel::load(vb, &config, model_type).s()?,
+                ))
+            } else {
+                tracing::info!("Starting DistilBertModel model on {:?}", device);
+                Ok(Box::new(
+                    DistilBertModel::load(vb, &config, model_type).s()?,
+                ))
+            }
+        }
+        #[cfg(feature = "cuda")]
        (Config::NomicBert(config), Device::Cuda(_)) => {
            if cfg!(feature = "flash-attn")
                && dtype == DType::F16
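The `Config` enum is deserialized from the model's `config.json`, so the new `#[serde(rename(deserialize = "distilbert"))]` attribute is what routes `"model_type": "distilbert"` to the DistilBert branch above. A self-contained sketch of that mechanism; the struct fields and the `tag` attribute shown here are assumptions for illustration, not the repository's exact definitions:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct DistilBertConfig {
    dim: usize, // placeholder field for the sketch
}

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")]
enum Config {
    #[serde(rename(deserialize = "distilbert"))]
    DistilBert(DistilBertConfig),
}

fn main() -> Result<(), serde_json::Error> {
    // The `model_type` tag in config.json selects the enum variant.
    let config: Config = serde_json::from_str(r#"{"model_type": "distilbert", "dim": 768}"#)?;
    println!("{config:?}"); // Config::DistilBert(DistilBertConfig { dim: 768 })
    Ok(())
}
```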

backends/candle/src/models.rs

Lines changed: 10 additions & 2 deletions

@@ -5,20 +5,25 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

mod bert;
+mod distilbert;
+mod jina;
+mod nomic;

#[cfg(feature = "cuda")]
mod flash_bert;

#[cfg(feature = "cuda")]
mod flash_jina;
-mod jina;

#[cfg(feature = "cuda")]
mod flash_nomic;
-mod nomic;
+
+#[cfg(feature = "cuda")]
+mod flash_distilbert;

pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
+pub use distilbert::{DistilBertConfig, DistilBertModel};
pub use jina::JinaBertModel;
pub use nomic::{NomicBertModel, NomicConfig};
use text_embeddings_backend_core::Batch;
@@ -32,6 +37,9 @@ pub use flash_jina::FlashJinaBertModel;
#[cfg(feature = "cuda")]
pub use flash_nomic::FlashNomicBertModel;

+#[cfg(feature = "cuda")]
+pub use flash_distilbert::FlashDistilBertModel;
+
pub(crate) trait Model {
    fn is_padded(&self) -> bool;
backends/candle/src/models/bert.rs

Lines changed: 14 additions & 7 deletions

@@ -26,7 +26,6 @@ pub struct BertConfig {
    #[serde(default)]
    pub use_cache: bool,
    pub classifier_dropout: Option<f64>,
-    pub model_type: Option<String>,
    pub id2label: Option<HashMap<String, String>>,
}

@@ -39,7 +38,7 @@ pub enum PositionEmbeddingType {
}

#[derive(Debug)]
-struct BertEmbeddings {
+pub struct BertEmbeddings {
    word_embeddings: Embedding,
    token_type_embeddings: Embedding,
    position_embeddings: Embedding,
@@ -80,7 +79,7 @@ impl BertEmbeddings {
        })
    }

-    fn forward(
+    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: &Tensor,
@@ -93,7 +92,9 @@ impl BertEmbeddings {
        let position_embeddings = self.position_embeddings.forward(position_ids)?;

        let embeddings = input_embeddings.add(&token_type_embeddings)?;
-        let embeddings = self.layer_norm.forward(&embeddings, &position_embeddings)?;
+        let embeddings = self
+            .layer_norm
+            .forward(&embeddings, Some(&position_embeddings))?;

        Ok(embeddings)
    }
@@ -255,7 +256,7 @@ impl BertAttention {
        let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

        let hidden_states = self.dense.forward(&context_layer)?;
-        let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?;
+        let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?;

        Ok(hidden_states)
    }
@@ -324,7 +325,7 @@ impl BertLayer {

        let hidden_states = self.intermediate.forward(&hidden_states)?;
        let hidden_states = self.output.forward(&hidden_states)?;
-        let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?;
+        let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?;

        Ok(hidden_states)
    }
@@ -469,7 +470,12 @@ impl BertModel {
                Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
                (pool, Some(classifier))
            }
-            ModelType::Embedding(pool) => (pool, None),
+            ModelType::Embedding(pool) => {
+                if pool == Pool::Splade {
+                    candle::bail!("`splade` is not supported for Nomic")
+                }
+                (pool, None)
+            }
        };

        let (embeddings, encoder) = match (
@@ -724,6 +730,7 @@ impl BertModel {

                (outputs.sum(1)?.broadcast_div(&input_lengths))?
            }
+            Pool::Splade => unreachable!(),
        };
        Some(pooled_embeddings)
    } else {
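BertModel rejects `Pool::Splade` at load time (so the `unreachable!()` arm only completes the match), because SPLADE pooling needs the logits of a `ForMaskedLM` head: it applies log(1 + ReLU(logits)) and then max-pools over the sequence dimension to produce sparse lexical weights. A hypothetical candle-style sketch of that formula, with the function name assumed and attention masking omitted, which a real implementation must handle:

```rust
use candle::{Result, Tensor};

/// Hypothetical SPLADE pooling sketch. `mlm_logits` is assumed to have shape
/// [batch, seq_len, vocab_size]; padding/attention masking is not applied here.
fn splade_pool(mlm_logits: &Tensor) -> Result<Tensor> {
    // log(1 + relu(x)) over the masked-language-model logits
    let activated = mlm_logits.relu()?.affine(1.0, 1.0)?.log()?;
    // max over the sequence dimension -> [batch, vocab_size] sparse lexical weights
    activated.max(1)
}
```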
