Add Dense layer in 2_Dense/ modules #660

Open · wants to merge 15 commits into base: main

Changes from 7 commits
10 changes: 7 additions & 3 deletions backends/candle/src/layers/linear.rs
@@ -68,15 +68,19 @@ impl Linear {
),
}
} else {
-let w = match x.dims() {
-    &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-    _ => self.weight.t()?,
+let (x, w) = match x.dims() {
+    &[bsize, _, _] => (x, self.weight.broadcast_left(bsize)?.t()?),
+    // Metal devices require contiguous tensors for 2D matrix multiplication apparently
+    _ if matches!(x.device(), Device::Metal(_)) => (&x.contiguous()?, self.weight.t()?),
+    _ => (x, self.weight.t()?),
};
let x = x.matmul(&w)?;

let x = match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}?;

if let Some(act) = &self.act {
match act {
HiddenAct::Gelu => x.gelu(),
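For context, the Metal constraint that the new match arm works around, as a minimal sketch (assumes `candle-core` built with the `metal` feature on macOS; the device setup and shapes are illustrative, not from this PR):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;

    // A transposed view is a zero-copy stride swap, i.e. non-contiguous.
    let x = Tensor::randn(0f32, 1f32, (8, 4), &device)?.t()?; // [4, 8] view
    let w = Tensor::randn(0f32, 1f32, (16, 8), &device)?;
    assert!(!x.is_contiguous());

    // Mirrors the patched `Linear::forward`: make the input contiguous
    // before the 2D matmul so the Metal kernel accepts it.
    let y = x.contiguous()?.matmul(&w.t()?)?;
    assert_eq!(y.dims(), &[4, 16]);
    Ok(())
}
```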
48 changes: 45 additions & 3 deletions backends/candle/src/lib.rs
@@ -11,9 +11,10 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
-    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
-    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
+    BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
+    GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig,
+    Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
+    Qwen3Config, Qwen3Model,
};
#[cfg(feature = "cuda")]
use crate::models::{
@@ -114,6 +115,7 @@ enum Config {
pub struct CandleBackend {
device: Device,
model: Box<dyn Model + Send>,
dense: Option<Box<dyn DenseLayer + Send>>,
}

impl CandleBackend {
@@ -468,9 +470,35 @@ impl CandleBackend {
}
};

// If `2_Dense/model.safetensors` is among the downloaded artifacts, build a Dense
// block and provide it to the `CandleBackend`; otherwise leave it as `None`
let dense = if model_path.join("2_Dense/model.safetensors").exists() {
let dense_config_path = model_path.join("2_Dense/config.json");

let dense_config_str = std::fs::read_to_string(&dense_config_path).map_err(|err| {
BackendError::Start(format!(
"Unable to read `2_Dense/config.json` file: {err:?}"
))
})?;
let dense_config: DenseConfig =
serde_json::from_str(&dense_config_str).map_err(|err| {
BackendError::Start(format!("Unable to parse `2_Dense/config.json`: {err:?}"))
})?;

let dense_path = model_path.join("2_Dense/model.safetensors");
let dense_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[dense_path], dtype, &device) }
.s()?;

Some(Box::new(Dense::load(dense_vb, &dense_config).s()?) as Box<dyn DenseLayer + Send>)
} else {
None
};

Ok(Self {
device,
model: model?,
dense,
})
}
}
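For reference, a self-contained sketch of what this deserialization step consumes (the struct mirrors `dense.rs` below; the JSON values follow the LaBSE `2_Dense/config.json` referenced there, quoted from memory, so treat them as an assumption):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct DenseConfig {
    in_features: usize,
    out_features: usize,
    bias: bool,
    activation_function: Option<String>,
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{
        "in_features": 768,
        "out_features": 768,
        "bias": true,
        "activation_function": "torch.nn.modules.activation.Tanh"
    }"#;
    let config: DenseConfig = serde_json::from_str(raw)?;
    assert_eq!((config.in_features, config.out_features), (768, 768));
    Ok(())
}
```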
@@ -507,6 +535,19 @@ impl Backend for CandleBackend {
// Run forward
let (pooled_embeddings, raw_embeddings) = self.model.embed(batch).e()?;

// Apply dense layer if available
let pooled_embeddings = match pooled_embeddings {
None => None,
Some(pooled_embeddings) => {
let pooled_embeddings = if let Some(ref dense) = self.dense {
dense.forward(&pooled_embeddings).e()?
} else {
pooled_embeddings
};
Some(pooled_embeddings)
}
};

// Device => Host data transfer
let pooled_embeddings = match pooled_embeddings {
None => vec![],
@@ -540,6 +581,7 @@ impl Backend for CandleBackend {
let batch_size = batch.len();

let results = self.model.predict(batch).e()?;

let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;

let mut predictions =
78 changes: 78 additions & 0 deletions backends/candle/src/models/dense.rs
@@ -0,0 +1,78 @@
use crate::layers::Linear;
use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use serde::Deserialize;

#[derive(Debug, Clone)]
pub enum DenseActivation {
Tanh,
Identity,
}

impl DenseActivation {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
match self {
Self::Tanh => x.tanh(),
Self::Identity => Ok(x.clone()),
}
}
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct DenseConfig {
in_features: usize,
out_features: usize,
bias: bool,
activation_function: Option<String>,
}

pub trait DenseLayer {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
}

#[derive(Debug)]
pub struct Dense {
linear: Linear,
activation: DenseActivation,
span: tracing::Span,
}

impl Dense {
pub fn load(vb: VarBuilder, config: &DenseConfig) -> Result<Self> {
let weight = vb.get((config.out_features, config.in_features), "linear.weight")?;
let bias = if config.bias {
Some(vb.get(config.out_features, "linear.bias")?)
} else {
None
};

// We cannot reuse `HiddenAct` here: in `2_Dense/config.json` the activation
// function is specified as a PyTorch import path, so it deserializes
// differently, and the set of commonly used activations is different too
// (mainly Tanh and Identity)
let activation = match config.activation_function {
// e.g. https://huggingface.co/sentence-transformers/LaBSE/blob/main/2_Dense/config.json
Some(ref act) if act == "torch.nn.modules.activation.Tanh" => DenseActivation::Tanh,
// e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5/blob/main/2_Dense/config.json
Some(ref act) if act == "torch.nn.modules.linear.Identity" => DenseActivation::Identity,
_ => DenseActivation::Identity,
};

let linear = Linear::new(weight, bias, None);

Ok(Self {
linear,
activation,
span: tracing::span!(tracing::Level::TRACE, "dense"),
})
}
}

impl DenseLayer for Dense {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let hidden_states = self.linear.forward(hidden_states)?;
self.activation.forward(&hidden_states)
}
}
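Numerically, `Dense::forward` amounts to `activation(x · Wᵀ + b)` applied to pooled embeddings. A small CPU sketch of that computation (assumes `candle-core`; shapes and values are invented for illustration):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (2, 4), &dev)?; // pooled: [batch=2, in_features=4]
    let w = Tensor::randn(0f32, 1f32, (3, 4), &dev)?; // weight: [out_features=3, in_features=4]
    let b = Tensor::zeros(3, DType::F32, &dev)?;      // bias:   [out_features=3]

    // Linear, then Tanh, as in the `DenseActivation::Tanh` branch above.
    let y = x.matmul(&w.t()?)?.broadcast_add(&b)?.tanh()?;
    assert_eq!(y.dims(), &[2, 3]);
    Ok(())
}
```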
2 changes: 2 additions & 0 deletions backends/candle/src/models/mod.rs
@@ -5,6 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

mod bert;
mod dense;
mod distilbert;
mod jina;
mod jina_code;
@@ -49,6 +50,7 @@ mod qwen3;

pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
pub use dense::{Dense, DenseConfig, DenseLayer};
pub use distilbert::{DistilBertConfig, DistilBertModel};
#[allow(unused_imports)]
pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};
7 changes: 5 additions & 2 deletions backends/ort/src/lib.rs
@@ -206,8 +206,11 @@ impl Backend for OrtBackend {
Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
Pool::LastToken => {
let axis_len = outputs.len_of(Axis(1));
-                    outputs.slice(s![.., axis_len - 1, ..]).into_owned().into_dyn()
-                },
+                    outputs
+                        .slice(s![.., axis_len - 1, ..])
+                        .into_owned()
+                        .into_dyn()
+                }
// Mean pooling
Pool::Mean => {
if masking {
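The reformatted `Pool::LastToken` arm just selects the final sequence position; a tiny standalone sketch of the same slice (assumes the `ndarray` crate):

```rust
use ndarray::{s, Array3, Axis};

fn main() {
    let outputs = Array3::<f32>::zeros((2, 5, 8)); // [batch, seq_len, hidden]
    let axis_len = outputs.len_of(Axis(1));

    // Take the last token along the sequence axis -> [batch, hidden].
    let last = outputs.slice(s![.., axis_len - 1, ..]).into_owned().into_dyn();
    assert_eq!(last.shape(), &[2, 8]);
}
```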
7 changes: 2 additions & 5 deletions backends/src/lib.rs
@@ -150,11 +150,8 @@ impl Backend {
}

max_input_length = std::cmp::min(max_input_length, max_warmup_length);
-        let mut seq_lengths: Vec<usize> = generate_bucket_sizes(
-            seq_bucket_size,
-            max_input_length,
-            seq_len_exp_base,
-        );
+        let mut seq_lengths: Vec<usize> =
+            generate_bucket_sizes(seq_bucket_size, max_input_length, seq_len_exp_base);
if let Some(&last) = seq_lengths.last() {
if last < max_input_length {
seq_lengths.push(max_input_length);
22 changes: 22 additions & 0 deletions core/src/download.rs
@@ -37,6 +37,14 @@ pub async fn download_artifacts(api: &ApiRepo, pool_config: bool) -> Result<Path
err
});

// Try to download the `2_Dense/config.json`
if download_dense_config(api).await.is_ok() {
// If `2_Dense/config.json` is there, then try to download the `model.safetensors`
if let Err(err) = download_dense_safetensors(api).await {
tracing::warn!("Failed to download dense safetensors: {err}");
}
}

tracing::info!("Downloading `config.json`");
api.get("config.json").await?;

@@ -55,6 +63,20 @@ pub async fn download_pool_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
Ok(pool_config_path)
}

#[instrument(skip_all)]
pub async fn download_dense_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
tracing::info!("Downloading `2_Dense/config.json`");
let dense_config_path = api.get("2_Dense/config.json").await?;
Ok(dense_config_path)
}

#[instrument(skip_all)]
pub async fn download_dense_safetensors(api: &ApiRepo) -> Result<PathBuf, ApiError> {
tracing::info!("Downloading `2_Dense/model.safetensors`");
let dense_safetensors_path = api.get("2_Dense/model.safetensors").await?;
Ok(dense_safetensors_path)
}

#[instrument(skip_all)]
pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
// Try default path
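A hedged usage sketch of the optional-download pattern these helpers enable (assumes the `hf-hub` crate with its `tokio` feature, plus `tokio` and `anyhow`; the model id is only an example):

```rust
use hf_hub::api::tokio::Api;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.model("sentence-transformers/LaBSE".to_string());

    // The 2_Dense files are optional: a miss simply means the model
    // ships no Dense head, which is not an error for the backend.
    match repo.get("2_Dense/config.json").await {
        Ok(path) => println!("dense config cached at {}", path.display()),
        Err(_) => println!("no 2_Dense module for this model"),
    }
    Ok(())
}
```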
1 change: 1 addition & 0 deletions router/src/lib.rs
@@ -405,6 +405,7 @@ fn get_backend_model_type(
}
}
};

Ok(text_embeddings_backend::ModelType::Embedding(pool))
}
