@@ -24,7 +24,7 @@ use anyhow::Context;
24
24
use candle:: { DType , Device } ;
25
25
use candle_nn:: VarBuilder ;
26
26
use nohash_hasher:: BuildNoHashHasher ;
27
- use serde:: Deserialize ;
27
+ use serde:: { de :: Deserializer , Deserialize } ;
28
28
use std:: collections:: HashMap ;
29
29
use std:: path:: Path ;
30
30
use text_embeddings_backend_core:: {
@@ -33,19 +33,58 @@ use text_embeddings_backend_core::{
33
33
34
34
/// This enum is needed to be able to differentiate between jina models that also use
35
35
/// the `bert` model type and valid Bert models.
36
- /// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
37
- /// run but is still better than the other options...
38
- #[ derive( Debug , Clone , PartialEq , Deserialize ) ]
39
- #[ serde( tag = "_name_or_path" ) ]
36
+ #[ derive( Debug , Clone , PartialEq ) ]
40
37
pub enum BertConfigWrapper {
41
- #[ serde( rename = "jinaai/jina-bert-implementation" ) ]
42
38
JinaBert ( BertConfig ) ,
43
- #[ serde( rename = "jinaai/jina-bert-v2-qk-post-norm" ) ]
44
39
JinaCodeBert ( BertConfig ) ,
45
- #[ serde( untagged) ]
46
40
Bert ( BertConfig ) ,
47
41
}
48
42
43
+ /// Custom deserializer is required as we need to capture both whether the `_name_or_path` value
44
+ /// is any of the JinaBERT alternatives, or alternatively to also support fine-tunes and re-uploads
45
+ /// with Sentence Transformers, we also need to check the value for the `auto_map.AutoConfig`
46
+ /// configuration file, and see if that points to the relevant remote code repositories on the Hub
47
+ impl < ' de > Deserialize < ' de > for BertConfigWrapper {
48
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
49
+ where
50
+ D : Deserializer < ' de > ,
51
+ {
52
+ use serde:: de:: Error ;
53
+
54
+ #[ allow( unused_mut) ]
55
+ let mut value = serde_json:: Value :: deserialize ( deserializer) ?;
56
+
57
+ let name_or_path = value
58
+ . get ( "_name_or_path" )
59
+ . and_then ( |v| v. as_str ( ) )
60
+ . map ( ToString :: to_string)
61
+ . unwrap_or_default ( ) ;
62
+
63
+ let auto_config = value
64
+ . get ( "auto_map" )
65
+ . and_then ( |v| v. get ( "AutoConfig" ) )
66
+ . and_then ( |v| v. as_str ( ) )
67
+ . map ( ToString :: to_string)
68
+ . unwrap_or_default ( ) ;
69
+
70
+ let config = BertConfig :: deserialize ( value) . map_err ( Error :: custom) ?;
71
+
72
+ if name_or_path == "jinaai/jina-bert-implementation"
73
+ || auto_config. contains ( "jinaai/jina-bert-implementation" )
74
+ {
75
+ // https://huggingface.co/jinaai/jina-bert-implementation
76
+ Ok ( Self :: JinaBert ( config) )
77
+ } else if name_or_path == "jinaai/jina-bert-v2-qk-post-norm"
78
+ || auto_config. contains ( "jinaai/jina-bert-v2-qk-post-norm" )
79
+ {
80
+ // https://huggingface.co/jinaai/jina-bert-v2-qk-post-norm
81
+ Ok ( Self :: JinaCodeBert ( config) )
82
+ } else {
83
+ Ok ( Self :: Bert ( config) )
84
+ }
85
+ }
86
+ }
87
+
49
88
#[ derive( Deserialize ) ]
50
89
#[ serde( tag = "model_type" , rename_all = "kebab-case" ) ]
51
90
enum Config {
0 commit comments