Skip to content

Commit 1a59eaf

Browse files
committed
Add Dense, DenseLayer and DenseConfig to handle 2_Dense/
Required for some models, e.g. https://huggingface.co/sentence-transformers/LaBSE
1 parent 73935bc commit 1a59eaf

File tree

5 files changed

+118
-3
lines changed

5 files changed

+118
-3
lines changed

backends/candle/src/lib.rs

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ use crate::compute_cap::{
1111
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
1212
};
1313
use crate::models::{
14-
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
15-
JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
16-
ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
14+
BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
15+
GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig,
16+
Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
17+
Qwen3Config, Qwen3Model,
1718
};
1819
#[cfg(feature = "cuda")]
1920
use crate::models::{
@@ -114,6 +115,7 @@ enum Config {
114115
pub struct CandleBackend {
115116
device: Device,
116117
model: Box<dyn Model + Send>,
118+
dense: Option<Box<dyn DenseLayer + Send>>,
117119
}
118120

119121
impl CandleBackend {
@@ -468,9 +470,35 @@ impl CandleBackend {
468470
}
469471
};
470472

473+
// If `2_Dense/model.safetensors` is amongst the downloaded artifacts, then create a Linear
474+
// layer from the VarBuilder using `candle` to provide it as an extra `Dense` layer to the
475+
// `CandleBackend`, otherwise leave it as None
476+
let dense = if model_path.join("2_Dense/model.safetensors").exists() {
477+
let dense_config_path = model_path.join("2_Dense/config.json");
478+
479+
// Load dense config
480+
let dense_config_str = std::fs::read_to_string(&dense_config_path).map_err(|err| {
481+
BackendError::Start(format!("Unable to read dense config file: {err:?}"))
482+
})?;
483+
let dense_config: DenseConfig =
484+
serde_json::from_str(&dense_config_str).map_err(|err| {
485+
BackendError::Start(format!("Unable to parse dense config: {err:?}"))
486+
})?;
487+
488+
let dense_path = model_path.join("2_Dense/model.safetensors");
489+
let dense_vb =
490+
unsafe { VarBuilder::from_mmaped_safetensors(&[dense_path], dtype, &device) }
491+
.s()?;
492+
493+
Some(Box::new(Dense::load(dense_vb, &dense_config).s()?) as Box<dyn DenseLayer + Send>)
494+
} else {
495+
None
496+
};
497+
471498
Ok(Self {
472499
device,
473500
model: model?,
501+
dense: dense,
474502
})
475503
}
476504
}
@@ -507,6 +535,19 @@ impl Backend for CandleBackend {
507535
// Run forward
508536
let (pooled_embeddings, raw_embeddings) = self.model.embed(batch).e()?;
509537

538+
// Apply dense layer if available
539+
let pooled_embeddings = match pooled_embeddings {
540+
None => None,
541+
Some(pooled_embeddings) => {
542+
let pooled_embeddings = if let Some(ref dense) = self.dense {
543+
dense.forward(&pooled_embeddings).e()?
544+
} else {
545+
pooled_embeddings
546+
};
547+
Some(pooled_embeddings)
548+
}
549+
};
550+
510551
// Device => Host data transfer
511552
let pooled_embeddings = match pooled_embeddings {
512553
None => vec![],
@@ -540,6 +581,7 @@ impl Backend for CandleBackend {
540581
let batch_size = batch.len();
541582

542583
let results = self.model.predict(batch).e()?;
584+
543585
let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
544586

545587
let mut predictions =

backends/candle/src/models/dense.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
use crate::layers::Linear;
2+
use candle::{Result, Tensor};
3+
use candle_nn::VarBuilder;
4+
use serde::Deserialize;
5+
6+
#[derive(Debug, Clone, PartialEq, Deserialize)]
7+
pub struct DenseConfig {
8+
in_features: usize,
9+
out_features: usize,
10+
bias: bool,
11+
#[allow(unused)]
12+
activation_function: Option<String>,
13+
}
14+
15+
pub trait DenseLayer {
16+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
17+
}
18+
19+
#[derive(Debug)]
20+
pub struct Dense {
21+
linear: Linear,
22+
span: tracing::Span,
23+
}
24+
25+
impl Dense {
26+
pub fn load(vb: VarBuilder, config: &DenseConfig) -> Result<Self> {
27+
let dense_weight = vb.get((config.out_features, config.in_features), "linear.weight")?;
28+
let dense_bias = if config.bias {
29+
Some(vb.get(config.out_features, "linear.bias")?)
30+
} else {
31+
None
32+
};
33+
34+
let linear = Linear::new(dense_weight, dense_bias, None);
35+
36+
Ok(Self {
37+
linear,
38+
span: tracing::span!(tracing::Level::TRACE, "dense"),
39+
})
40+
}
41+
}
42+
43+
impl DenseLayer for Dense {
44+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
45+
let _enter = self.span.enter();
46+
self.linear.forward(hidden_states)?.tanh()
47+
}
48+
}

backends/candle/src/models/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ extern crate intel_mkl_src;
55
extern crate accelerate_src;
66

77
mod bert;
8+
mod dense;
89
mod distilbert;
910
mod jina;
1011
mod jina_code;
@@ -49,6 +50,7 @@ mod qwen3;
4950

5051
pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
5152
use candle::{Result, Tensor};
53+
pub use dense::{Dense, DenseConfig, DenseLayer};
5254
pub use distilbert::{DistilBertConfig, DistilBertModel};
5355
#[allow(unused_imports)]
5456
pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};

core/src/download.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ pub async fn download_artifacts(api: &ApiRepo, pool_config: bool) -> Result<Path
3737
err
3838
});
3939

40+
// Try to download the `2_Dense/config.json`
41+
if let Ok(_) = download_dense_config(api).await {
42+
// If `2_Dense/config.json` is there, then try to download the `model.safetensors`
43+
if let Err(err) = download_dense_safetensors(api).await {
44+
tracing::warn!("Failed to download dense safetensors: {err}");
45+
}
46+
}
47+
4048
tracing::info!("Downloading `config.json`");
4149
api.get("config.json").await?;
4250

@@ -55,6 +63,20 @@ pub async fn download_pool_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
5563
Ok(pool_config_path)
5664
}
5765

66+
#[instrument(skip_all)]
67+
pub async fn download_dense_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
68+
tracing::info!("Downloading `2_Dense/config.json`");
69+
let dense_config_path = api.get("2_Dense/config.json").await?;
70+
Ok(dense_config_path)
71+
}
72+
73+
#[instrument(skip_all)]
74+
pub async fn download_dense_safetensors(api: &ApiRepo) -> Result<PathBuf, ApiError> {
75+
tracing::info!("Downloading `2_Dense/model.safetensors`");
76+
let dense_safetensors_path = api.get("2_Dense/model.safetensors").await?;
77+
Ok(dense_safetensors_path)
78+
}
79+
5880
#[instrument(skip_all)]
5981
pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
6082
// Try default path

router/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ fn get_backend_model_type(
405405
}
406406
}
407407
};
408+
408409
Ok(text_embeddings_backend::ModelType::Embedding(pool))
409410
}
410411

0 commit comments

Comments
 (0)