Skip to content

Commit 337fbd6

Browse files
fix: fix loading of bert classification models (#173)
1 parent bd7f8eb commit 337fbd6

File tree

3 files changed

+165
-42
lines changed

3 files changed

+165
-42
lines changed

backends/candle/src/lib.rs

Lines changed: 36 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -101,13 +101,7 @@ impl CandleBackend {
101101
(_, Device::Cuda(_)) => Err(BackendError::Start(
102102
"`cuda` feature is not enabled".to_string(),
103103
)),
104-
(
105-
Config::Bert(config)
106-
| Config::XlmRoberta(config)
107-
| Config::Camembert(config)
108-
| Config::Roberta(config),
109-
Device::Cpu | Device::Metal(_),
110-
) => {
104+
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
111105
if config.position_embedding_type == PositionEmbeddingType::Alibi {
112106
tracing::info!("Starting JinaBertModel model on {:?}", device);
113107
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
@@ -116,14 +110,21 @@ impl CandleBackend {
116110
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
117111
}
118112
}
119-
#[cfg(feature = "cuda")]
120113
(
121-
Config::Bert(config)
122-
| Config::XlmRoberta(config)
123-
| Config::Camembert(config)
124-
| Config::Roberta(config),
125-
Device::Cuda(_),
114+
Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
115+
Device::Cpu | Device::Metal(_),
126116
) => {
117+
tracing::info!("Starting Bert model on {:?}", device);
118+
Ok(Box::new(
119+
BertModel::load_roberta(vb, &config, model_type).s()?,
120+
))
121+
}
122+
(Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
123+
tracing::info!("Starting NomicBertModel model on {:?}", device);
124+
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
125+
}
126+
#[cfg(feature = "cuda")]
127+
(Config::Bert(config), Device::Cuda(_)) => {
127128
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
128129
&& dtype == DType::F16
129130
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
@@ -150,9 +151,28 @@ impl CandleBackend {
150151
}
151152
}
152153
}
153-
(Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
154-
tracing::info!("Starting NomicBertModel model on {:?}", device);
155-
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
154+
#[cfg(feature = "cuda")]
155+
(
156+
Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
157+
Device::Cuda(_),
158+
) => {
159+
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
160+
&& dtype == DType::F16
161+
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
162+
// Allow disabling because of flash attention v1 precision problems
163+
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
164+
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
165+
{
166+
tracing::info!("Starting FlashBert model on {:?}", device);
167+
Ok(Box::new(
168+
FlashBertModel::load_roberta(vb, &config, model_type).s()?,
169+
))
170+
} else {
171+
tracing::info!("Starting Bert model on {:?}", device);
172+
Ok(Box::new(
173+
BertModel::load_roberta(vb, &config, model_type).s()?,
174+
))
175+
}
156176
}
157177
#[cfg(feature = "cuda")]
158178
(Config::NomicBert(config), Device::Cuda(_)) => {

backends/candle/src/models/bert.rs

Lines changed: 61 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -466,14 +466,7 @@ impl BertModel {
466466
let pool = Pool::Cls;
467467

468468
let classifier: Box<dyn ClassificationHead + Send> =
469-
if config.model_type == Some("bert".to_string()) {
470-
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?)
471-
} else {
472-
Box::new(RobertaClassificationHead::load(
473-
vb.pp("classifier"),
474-
config,
475-
)?)
476-
};
469+
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
477470
(pool, Some(classifier))
478471
}
479472
ModelType::Embedding(pool) => (pool, None),
@@ -485,16 +478,71 @@ impl BertModel {
485478
) {
486479
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
487480
(Err(err), _) | (_, Err(err)) => {
488-
let model_type = config.model_type.clone().unwrap_or("bert".to_string());
481+
if let (Ok(embeddings), Ok(encoder)) = (
482+
BertEmbeddings::load(vb.pp("bert.embeddings".to_string()), config),
483+
BertEncoder::load(vb.pp("bert.encoder".to_string()), config),
484+
) {
485+
(embeddings, encoder)
486+
} else {
487+
return Err(err);
488+
}
489+
}
490+
};
491+
492+
Ok(Self {
493+
embeddings,
494+
encoder,
495+
pool,
496+
classifier,
497+
num_attention_heads: config.num_attention_heads,
498+
device: vb.device().clone(),
499+
dtype: vb.dtype(),
500+
span: tracing::span!(tracing::Level::TRACE, "model"),
501+
})
502+
}
503+
504+
pub fn load_roberta(
505+
vb: VarBuilder,
506+
config: &BertConfig,
507+
model_type: ModelType,
508+
) -> Result<Self> {
509+
// Check position embedding type
510+
if config.position_embedding_type != PositionEmbeddingType::Absolute {
511+
candle::bail!("Bert only supports absolute position embeddings")
512+
}
513+
514+
let (pool, classifier) = match model_type {
515+
// Classifier models always use CLS pooling
516+
ModelType::Classifier => {
517+
let pool = Pool::Cls;
518+
519+
let classifier: Box<dyn ClassificationHead + Send> = Box::new(
520+
RobertaClassificationHead::load(vb.pp("classifier"), config)?,
521+
);
522+
(pool, Some(classifier))
523+
}
524+
ModelType::Embedding(pool) => (pool, None),
525+
};
489526

527+
let (embeddings, encoder) = match (
528+
BertEmbeddings::load(vb.pp("embeddings"), config),
529+
BertEncoder::load(vb.pp("encoder"), config),
530+
) {
531+
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
532+
(Err(err), _) | (_, Err(err)) => {
490533
if let (Ok(embeddings), Ok(encoder)) = (
491-
BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
492-
BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
534+
BertEmbeddings::load(vb.pp("roberta.embeddings".to_string()), config),
535+
BertEncoder::load(vb.pp("roberta.encoder".to_string()), config),
536+
) {
537+
(embeddings, encoder)
538+
} else if let (Ok(embeddings), Ok(encoder)) = (
539+
BertEmbeddings::load(vb.pp("xlm-roberta.embeddings".to_string()), config),
540+
BertEncoder::load(vb.pp("xlm-roberta.encoder".to_string()), config),
493541
) {
494542
(embeddings, encoder)
495543
} else if let (Ok(embeddings), Ok(encoder)) = (
496-
BertEmbeddings::load(vb.pp("roberta.embeddings"), config),
497-
BertEncoder::load(vb.pp("roberta.encoder"), config),
544+
BertEmbeddings::load(vb.pp("camembert.embeddings".to_string()), config),
545+
BertEncoder::load(vb.pp("camembert.encoder".to_string()), config),
498546
) {
499547
(embeddings, encoder)
500548
} else {

backends/candle/src/models/flash_bert.rs

Lines changed: 68 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -306,14 +306,7 @@ impl FlashBertModel {
306306
let pool = Pool::Cls;
307307

308308
let classifier: Box<dyn ClassificationHead + Send> =
309-
if config.model_type == Some("bert".to_string()) {
310-
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?)
311-
} else {
312-
Box::new(RobertaClassificationHead::load(
313-
vb.pp("classifier"),
314-
config,
315-
)?)
316-
};
309+
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
317310
(pool, Some(classifier))
318311
}
319312
ModelType::Embedding(pool) => (pool, None),
@@ -325,16 +318,78 @@ impl FlashBertModel {
325318
) {
326319
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
327320
(Err(err), _) | (_, Err(err)) => {
328-
let model_type = config.model_type.clone().unwrap_or("bert".to_string());
321+
if let (Ok(embeddings), Ok(encoder)) = (
322+
BertEmbeddings::load(vb.pp("bert.embeddings".to_string()), config),
323+
BertEncoder::load(vb.pp("bert.encoder".to_string()), config),
324+
) {
325+
(embeddings, encoder)
326+
} else {
327+
return Err(err);
328+
}
329+
}
330+
};
331+
332+
Ok(Self {
333+
embeddings,
334+
encoder,
335+
pool,
336+
classifier,
337+
device: vb.device().clone(),
338+
span: tracing::span!(tracing::Level::TRACE, "model"),
339+
})
340+
}
341+
342+
pub fn load_roberta(
343+
vb: VarBuilder,
344+
config: &BertConfig,
345+
model_type: ModelType,
346+
) -> Result<Self> {
347+
match vb.device() {
348+
Device::Cuda(_) => {}
349+
_ => candle::bail!("FlashBert requires Cuda"),
350+
}
351+
352+
if vb.dtype() != DType::F16 {
353+
candle::bail!("FlashBert requires DType::F16")
354+
}
355+
356+
// Check position embedding type
357+
if config.position_embedding_type != PositionEmbeddingType::Absolute {
358+
candle::bail!("FlashBert only supports absolute position embeddings")
359+
}
360+
361+
let (pool, classifier) = match model_type {
362+
// Classifier models always use CLS pooling
363+
ModelType::Classifier => {
364+
let pool = Pool::Cls;
365+
366+
let classifier: Box<dyn ClassificationHead + Send> = Box::new(
367+
RobertaClassificationHead::load(vb.pp("classifier"), config)?,
368+
);
369+
(pool, Some(classifier))
370+
}
371+
ModelType::Embedding(pool) => (pool, None),
372+
};
329373

374+
let (embeddings, encoder) = match (
375+
BertEmbeddings::load(vb.pp("embeddings"), config),
376+
BertEncoder::load(vb.pp("encoder"), config),
377+
) {
378+
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
379+
(Err(err), _) | (_, Err(err)) => {
330380
if let (Ok(embeddings), Ok(encoder)) = (
331-
BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
332-
BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
381+
BertEmbeddings::load(vb.pp("roberta.embeddings".to_string()), config),
382+
BertEncoder::load(vb.pp("roberta.encoder".to_string()), config),
383+
) {
384+
(embeddings, encoder)
385+
} else if let (Ok(embeddings), Ok(encoder)) = (
386+
BertEmbeddings::load(vb.pp("xlm-roberta.embeddings".to_string()), config),
387+
BertEncoder::load(vb.pp("xlm-roberta.encoder".to_string()), config),
333388
) {
334389
(embeddings, encoder)
335390
} else if let (Ok(embeddings), Ok(encoder)) = (
336-
BertEmbeddings::load(vb.pp("roberta.embeddings"), config),
337-
BertEncoder::load(vb.pp("roberta.encoder"), config),
391+
BertEmbeddings::load(vb.pp("camembert.embeddings".to_string()), config),
392+
BertEncoder::load(vb.pp("camembert.encoder".to_string()), config),
338393
) {
339394
(embeddings, encoder)
340395
} else {

0 commit comments

Comments (0)