
Commit c9cddf2

Add --dense-path argument (to be used within CandleBackend)
Without `--dense-path`, users would be unable to use other `Dense` layers when a model provides them, as in e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5, which contains a separate `2_Dense_<dims>/` directory per `Dense` layer, each with a different output vector dimensionality.
1 parent dcb3ee8 commit c9cddf2
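
For illustration, a minimal caller-side sketch of the new signature (hypothetical, not part of this commit): `CandleBackend`, `ModelType`, `Pool`, and `BackendError` are the crate items that appear in the diffs below, and `2_Dense_1024` assumes one of the `2_Dense_<dims>/` directories mentioned above.

use std::path::Path;

// Hypothetical call site: pick one of the model's `2_Dense_<dims>/` directories
// and hand it to the backend through the new `dense_path` parameter. Passing
// `None` skips the Dense projection entirely.
fn backend_with_dense(model_root: &Path) -> Result<CandleBackend, BackendError> {
    let dense_dir = model_root.join("2_Dense_1024"); // assumed dimensionality
    CandleBackend::new(
        model_root,
        "float32".to_string(),
        ModelType::Embedding(Pool::Mean),
        Some(&dense_dir),
    )
}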

21 files changed: +190 −42 lines

backends/candle/src/lib.rs

Lines changed: 37 additions & 19 deletions
@@ -123,6 +123,7 @@ impl CandleBackend {
         model_path: &Path,
         dtype: String,
         model_type: ModelType,
+        dense_path: Option<&Path>,
     ) -> Result<Self, BackendError> {
         // Default files
         let default_safetensors = model_path.join("model.safetensors");
@@ -470,27 +471,44 @@ impl CandleBackend {
             }
         };
 
-        // If `2_Dense/model.safetensors` is amongst the downloaded artifacts, then create a Dense
+        // If `2_Dense/model.safetensors` or `2_Dense/pytorch_model.bin` is amongst the downloaded artifacts, then create a Dense
         // block and provide it to the `CandleBackend`, otherwise, None
-        let dense = if model_path.join("2_Dense/model.safetensors").exists() {
-            let dense_config_path = model_path.join("2_Dense/config.json");
+        let dense = if let Some(dense_path) = dense_path {
+            let dense_safetensors = dense_path.join("model.safetensors");
+            let dense_pytorch = dense_path.join("pytorch_model.bin");
+
+            if dense_safetensors.exists() || dense_pytorch.exists() {
+                let dense_config_path = dense_path.join("config.json");
+
+                let dense_config_str =
+                    std::fs::read_to_string(&dense_config_path).map_err(|err| {
+                        BackendError::Start(format!(
+                            "Unable to read `{}/config.json` file: {err:?}",
+                            dense_path.display()
+                        ))
+                    })?;
+                let dense_config: DenseConfig =
+                    serde_json::from_str(&dense_config_str).map_err(|err| {
+                        BackendError::Start(format!(
+                            "Unable to parse `{}/config.json`: {err:?}",
+                            dense_path.display()
+                        ))
+                    })?;
+
+                let dense_vb = if dense_safetensors.exists() {
+                    unsafe {
+                        VarBuilder::from_mmaped_safetensors(&[dense_safetensors], dtype, &device)
+                    }
+                    .s()?
+                } else {
+                    VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
+                };
 
-        let dense_config_str = std::fs::read_to_string(&dense_config_path).map_err(|err| {
-            BackendError::Start(format!(
-                "Unable to read `2_Dense/config.json` file: {err:?}"
-            ))
-        })?;
-        let dense_config: DenseConfig =
-            serde_json::from_str(&dense_config_str).map_err(|err| {
-                BackendError::Start(format!("Unable to parse `2_Dense/config.json`: {err:?}"))
-            })?;
-
-        let dense_path = model_path.join("2_Dense/model.safetensors");
-        let dense_vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[dense_path], dtype, &device) }
-                .s()?;
-
-        Some(Box::new(Dense::load(dense_vb, &dense_config).s()?) as Box<dyn DenseLayer + Send>)
+                Some(Box::new(Dense::load(dense_vb, &dense_config).s()?)
+                    as Box<dyn DenseLayer + Send>)
+            } else {
+                None
+            }
         } else {
             None
         };
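
Note the weight-resolution order in the hunk above: `model.safetensors` is preferred (and memory-mapped), with `pytorch_model.bin` as the fallback, and no Dense layer is built when neither file exists. A standalone sketch of that rule, using a hypothetical helper name:

use std::path::{Path, PathBuf};

// Mirrors the resolution order in the hunk above: prefer `model.safetensors`,
// fall back to `pytorch_model.bin`, and report `None` when the directory
// contains neither weights file.
fn resolve_dense_weights(dense_path: &Path) -> Option<PathBuf> {
    let safetensors = dense_path.join("model.safetensors");
    if safetensors.exists() {
        return Some(safetensors);
    }
    let pytorch = dense_path.join("pytorch_model.bin");
    pytorch.exists().then_some(pytorch)
}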

backends/candle/tests/test_bert.rs

Lines changed: 14 additions & 2 deletions
@@ -16,6 +16,7 @@ fn test_bert() -> Result<()> {
         &model_root,
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(
@@ -76,6 +77,7 @@ fn test_bert_pooled_raw() -> Result<()> {
         &model_root,
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
+        None,
     )?;
 
     let input_batch = batch(
@@ -142,7 +144,12 @@ fn test_emotions() -> Result<()> {
     let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
     let tokenizer = load_tokenizer(&model_root)?;
 
-    let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
+    let backend = CandleBackend::new(
+        &model_root,
+        "float32".to_string(),
+        ModelType::Classifier,
+        None,
+    )?;
 
     let input_batch = batch(
         vec![
@@ -193,7 +200,12 @@ fn test_bert_classification() -> Result<()> {
         download_artifacts("ibm-research/re2g-reranker-nq", Some("refs/pr/3")).unwrap();
     let tokenizer = load_tokenizer(&model_root)?;
 
-    let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;
+    let backend = CandleBackend::new(
+        &model_root,
+        "float32".to_string(),
+        ModelType::Classifier,
+        None,
+    )?;
 
     let input_single = batch(
         vec![tokenizer

backends/candle/tests/test_flash_bert.rs

Lines changed: 14 additions & 2 deletions
@@ -22,6 +22,7 @@ fn test_flash_mini() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(
@@ -86,6 +87,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Cls),
+        None,
     )?;
 
     let input_batch = batch(
@@ -156,7 +158,12 @@ fn test_flash_emotions() -> Result<()> {
     let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
     let tokenizer = load_tokenizer(&model_root)?;
 
-    let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;
+    let backend = CandleBackend::new(
+        &model_root,
+        "float16".to_string(),
+        ModelType::Classifier,
+        None,
+    )?;
 
     let input_batch = batch(
         vec![
@@ -210,7 +217,12 @@ fn test_flash_bert_classification() -> Result<()> {
     let model_root = download_artifacts("ibm-research/re2g-reranker-nq", Some("refs/pr/3"))?;
     let tokenizer = load_tokenizer(&model_root)?;
 
-    let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;
+    let backend = CandleBackend::new(
+        &model_root,
+        "float16".to_string(),
+        ModelType::Classifier,
+        None,
+    )?;
 
     let input_single = batch(
         vec![tokenizer

backends/candle/tests/test_flash_gte.rs

Lines changed: 7 additions & 1 deletion
@@ -18,6 +18,7 @@ fn test_flash_gte() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Cls),
+        None,
     )?;
 
     let input_batch = batch(
@@ -62,7 +63,12 @@ fn test_flash_gte_classification() -> Result<()> {
     let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?;
     let tokenizer = load_tokenizer(&model_root)?;
 
-    let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;
+    let backend = CandleBackend::new(
+        &model_root,
+        "float16".to_string(),
+        ModelType::Classifier,
+        None,
+    )?;
 
     let input_single = batch(
         vec![tokenizer

backends/candle/tests/test_flash_jina.rs

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ fn test_flash_jina_small() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(

backends/candle/tests/test_flash_jina_code.rs

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ fn test_flash_jina_code_base() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(

backends/candle/tests/test_flash_mistral.rs

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ fn test_flash_mistral() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(

backends/candle/tests/test_flash_nomic.rs

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@ fn test_flash_nomic_small() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(
@@ -63,6 +64,7 @@ fn test_flash_nomic_moe() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
+        None,
     )?;
 
     let input_batch = batch(

backends/candle/tests/test_flash_qwen2.rs

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ fn test_flash_qwen2() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::LastToken),
+        None,
     )?;
 
     let input_batch = batch(

backends/candle/tests/test_flash_qwen3.rs

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ fn test_flash_qwen3() -> Result<()> {
         &model_root,
         "float16".to_string(),
         ModelType::Embedding(Pool::LastToken),
+        None,
     )?;
 
     let input_batch = batch(
