Skip to content

Commit 01d0fbd

Browse files
feat: Implement MPNet model (#363) (#447)
Co-authored-by: Hyeongchan Kim <kozistr@gmail.com>
1 parent 0462171 commit 01d0fbd

11 files changed

+29281
-3
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Ember, GTE and E5. TEI implements many features such as:
6565
#### Text Embeddings
6666

6767
Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
68 - model with Alibi positions and Mistral, Alibaba GTE and Qwen2 models with Rope positions.
68 + model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, and MPNet.
6969

7070
Below are some examples of the currently supported models:
7171

@@ -81,7 +81,7 @@ Below are some examples of the currently supported models:
8181
| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
8282
| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
8383
| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
84 -
84 + | N/A | 0.1B | MPNet | [sentence-transformers/all-mpnet-base-v2](https://hf.co/sentence-transformers/all-mpnet-base-v2) |
8585

8686
To explore the list of best performing text embeddings models, visit the
8787
[Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).

backends/candle/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use crate::compute_cap::{
1212
};
1313
use crate::models::{
1414
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
15 - JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
15 + JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
16 + Qwen2Config,
1617
};
1718
#[cfg(feature = "cuda")]
1819
use crate::models::{
@@ -60,6 +61,8 @@ enum Config {
6061
#[serde(rename = "new")]
6162
Gte(GTEConfig),
6263
Qwen2(Qwen2Config),
64 + #[serde(rename = "mpnet")]
65 + MPNet(MPNetConfig),
6366
}
6467

6568
pub struct CandleBackend {
@@ -226,6 +229,10 @@ impl CandleBackend {
226229
"Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
227230
.to_string(),
228231
)),
232 + (Config::MPNet(config), _) => {
233 +     tracing::info!("Starting MPNet model on {:?}", device);
234 +     Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
235 + }
229236
#[cfg(feature = "cuda")]
230237
(Config::Bert(config), Device::Cuda(_)) => {
231238
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))

backends/candle/src/models/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod flash_mistral;
3434
#[cfg(feature = "cuda")]
3535
mod flash_qwen2;
3636
mod gte;
37 + mod mpnet;
3738
mod qwen2;
3839

3940
pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
@@ -44,6 +45,7 @@ pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};
4445
pub use jina::JinaBertModel;
4546
pub use jina_code::JinaCodeBertModel;
4647
pub use mistral::MistralConfig;
48 + pub use mpnet::{MPNetConfig, MPNetModel};
4749
pub use nomic::{NomicBertModel, NomicConfig};
4850
pub use qwen2::Qwen2Config;
4951
use text_embeddings_backend_core::Batch;

0 commit comments

Comments (0)