Skip to content

Commit 7f7832e

Browse files
authored
Use custom serde deserializer for JinaBERT models (#559)
1 parent f2c308e commit 7f7832e

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backends/candle/src/lib.rs

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use anyhow::Context;
2424
use candle::{DType, Device};
2525
use candle_nn::VarBuilder;
2626
use nohash_hasher::BuildNoHashHasher;
27-
use serde::Deserialize;
27+
use serde::{de::Deserializer, Deserialize};
2828
use std::collections::HashMap;
2929
use std::path::Path;
3030
use text_embeddings_backend_core::{
@@ -33,19 +33,58 @@ use text_embeddings_backend_core::{
3333

3434
/// This enum is needed to be able to differentiate between jina models that also use
3535
/// the `bert` model type and valid Bert models.
36-
/// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
37-
/// run but is still better than the other options...
38-
#[derive(Debug, Clone, PartialEq, Deserialize)]
39-
#[serde(tag = "_name_or_path")]
36+
#[derive(Debug, Clone, PartialEq)]
4037
pub enum BertConfigWrapper {
41-
#[serde(rename = "jinaai/jina-bert-implementation")]
4238
JinaBert(BertConfig),
43-
#[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
4439
JinaCodeBert(BertConfig),
45-
#[serde(untagged)]
4640
Bert(BertConfig),
4741
}
4842

43+
/// Custom deserializer is required as we need to capture both whether the `_name_or_path` value
44+
/// is any of the JinaBERT alternatives, or alternatively to also support fine-tunes and re-uploads
45+
/// with Sentence Transformers, we also need to check the value for the `auto_map.AutoConfig`
46+
/// configuration file, and see if that points to the relevant remote code repositories on the Hub
47+
impl<'de> Deserialize<'de> for BertConfigWrapper {
48+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
49+
where
50+
D: Deserializer<'de>,
51+
{
52+
use serde::de::Error;
53+
54+
#[allow(unused_mut)]
55+
let mut value = serde_json::Value::deserialize(deserializer)?;
56+
57+
let name_or_path = value
58+
.get("_name_or_path")
59+
.and_then(|v| v.as_str())
60+
.map(ToString::to_string)
61+
.unwrap_or_default();
62+
63+
let auto_config = value
64+
.get("auto_map")
65+
.and_then(|v| v.get("AutoConfig"))
66+
.and_then(|v| v.as_str())
67+
.map(ToString::to_string)
68+
.unwrap_or_default();
69+
70+
let config = BertConfig::deserialize(value).map_err(Error::custom)?;
71+
72+
if name_or_path == "jinaai/jina-bert-implementation"
73+
|| auto_config.contains("jinaai/jina-bert-implementation")
74+
{
75+
// https://huggingface.co/jinaai/jina-bert-implementation
76+
Ok(Self::JinaBert(config))
77+
} else if name_or_path == "jinaai/jina-bert-v2-qk-post-norm"
78+
|| auto_config.contains("jinaai/jina-bert-v2-qk-post-norm")
79+
{
80+
// https://huggingface.co/jinaai/jina-bert-v2-qk-post-norm
81+
Ok(Self::JinaCodeBert(config))
82+
} else {
83+
Ok(Self::Bert(config))
84+
}
85+
}
86+
}
87+
4988
#[derive(Deserialize)]
5089
#[serde(tag = "model_type", rename_all = "kebab-case")]
5190
enum Config {

0 commit comments

Comments
 (0)