
Commit f99ce07

Narsil and kozistr authored

Fixing FlashAttention ModernBert. (#560)
Co-authored-by: Hyeongchan Kim <kozistr@gmail.com>
1 parent 5104236 · commit f99ce07

5 files changed: 66 additions, 38 deletions


backends/candle/src/lib.rs

Lines changed: 24 additions & 2 deletions
@@ -18,7 +18,8 @@ use crate::models::{
 #[cfg(feature = "cuda")]
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
-    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel, FlashQwen2Model,
+    FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
+    FlashQwen2Model,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -276,7 +277,7 @@ impl CandleBackend {
                 tracing::info!("Starting MPNet model on {:?}", device);
                 Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
             }
-            (Config::ModernBert(config), _) => match device {
+            (Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => match device {
                 Device::Metal(_) => {
                     return Err(BackendError::Start(
                         "ModernBert is not currently supported on MPS device".to_string(),
@@ -357,6 +358,27 @@ impl CandleBackend {
                 }
             }
             #[cfg(feature = "cuda")]
+            (Config::ModernBert(config), Device::Cuda(_)) => {
+                if cfg!(feature = "flash-attn")
+                    && dtype == DType::F16
+                    // Allow disabling because of flash attention v1 precision problems
+                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                {
+                    tracing::info!("Starting FlashModernBert model on {:?}", device);
+                    Ok(Box::new(
+                        FlashModernBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                } else {
+                    #[cfg(feature = "flash-attn-v1")]
+                    tracing::warn!("Flash attention V1 cannot be used with ModernBert because it lacks windowing support.");
+                    tracing::info!("Starting ModernBert model on {:?}", device);
+                    Ok(Box::new(
+                        ModernBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                }
+            }
+            #[cfg(feature = "cuda")]
             (Config::DistilBert(config), Device::Cuda(_)) => {
                 if cfg!(feature = "flash-attn")
                     && dtype == DType::F16
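
The new CUDA arm only takes the FlashModernBert path when three conditions hold: the "flash-attn" feature is compiled in, the model runs in f16, and the USE_FLASH_ATTENTION environment variable (defaulting to "True") has not been set to disable it. A minimal standalone sketch of that gate; the function name and its bool parameter are illustrative, not the crate's API:

fn flash_attention_enabled(is_f16: bool) -> bool {
    // USE_FLASH_ATTENTION defaults to "True"; any value other than "true"
    // (case-insensitive) falls back to the non-flash ModernBertModel path.
    let env_allows = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "True".to_string())
        .to_lowercase()
        == "true";
    cfg!(feature = "flash-attn") && is_f16 && env_allows
}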

backends/candle/src/models/flash_modernbert.rs

Lines changed: 34 additions & 34 deletions
@@ -1,14 +1,15 @@
 use std::collections::HashMap;

 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNormNoBias, Linear};
 use crate::models::modernbert::{
     ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
     ModernBertMLP,
 };
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};

 struct ModernBertAttention {
@@ -79,35 +80,34 @@ impl ModernBertAttention {
         new_qkv_shape.pop();
         new_qkv_shape.push(self.num_attention_heads * 3);
         new_qkv_shape.push(self.attention_head_size);
-        let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+        let qkv = qkv.reshape(new_qkv_shape.as_slice())?;

-        let qkv = qkv.chunk(3, 1)?;
-        let query_layer = &qkv[0].contiguous()?;
-        let key_layer = &qkv[1].contiguous()?;
-        let value_layer = &qkv[2];
+        // Split qkv tensor
+        let q = qkv.narrow(1, 0, self.num_attention_heads)?;
+        let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
+        let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

-        let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
-        let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

-        let attention_size = if self.use_local_attention {
+        let window_size = if self.use_local_attention {
             Some(self.local_attention)
         } else {
             None
         };

         let attention = flash_attn_varlen(
-            &query_layer,
-            &key_layer,
-            &value_layer,
+            &q,
+            &k,
+            &v,
             None,
             cu_seqlens,
             cu_seqlens,
             max_s,
             max_s,
             self.softmax_scale,
             false,
-            attention_size,
-            attention_size,
+            window_size,
+            window_size,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;

@@ -118,9 +118,9 @@ impl ModernBertAttention {
 }

 struct ModernBertEncoderLayer {
-    attn_norm: Option<LayerNorm>,
+    attn_norm: Option<LayerNormNoBias>,
     attn: ModernBertAttention,
-    mlp_norm: LayerNorm,
+    mlp_norm: LayerNormNoBias,
     mlp: ModernBertMLP,

     span: tracing::Span,
@@ -129,7 +129,7 @@ struct ModernBertEncoderLayer {
 impl ModernBertEncoderLayer {
     pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
         let attn_norm = if index != 0 {
-            Some(LayerNorm::load(
+            Some(LayerNormNoBias::load(
                 vb.pp("attn_norm"),
                 config.hidden_size,
                 config.norm_eps as f32,
@@ -140,7 +140,7 @@ impl ModernBertEncoderLayer {

         let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;

-        let mlp_norm = LayerNorm::load(
+        let mlp_norm = LayerNormNoBias::load(
             vb.pp("mlp_norm"),
             config.hidden_size,
             config.norm_eps as f32,
@@ -236,11 +236,10 @@ impl ModernBertEncoder {
 pub struct FlashModernBertModel {
     embeddings: ModernBertEmbeddings,
     encoder: ModernBertEncoder,
-    final_norm: LayerNorm,
+    final_norm: LayerNormNoBias,
     pool: Pool,
     classifier: Option<Box<dyn ClassificationHead + Send>>,

-    rotary_dim: usize,
     rotary_cache: HashMap<bool, (Tensor, Tensor)>,

     device: Device,
@@ -277,13 +276,22 @@ impl FlashModernBertModel {
             }
         };

-        let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
-        let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
-        let final_norm = LayerNorm::load(
+        let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)
+            .or_else(|_| ModernBertEmbeddings::load(vb.pp("embeddings"), config))?;
+        let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)
+            .or_else(|_| ModernBertEncoder::load(vb.pp("layers"), config))?;
+        let final_norm = LayerNormNoBias::load(
             vb.pp("model.final_norm"),
             config.hidden_size,
             config.norm_eps as f32,
-        )?;
+        )
+        .or_else(|_| {
+            LayerNormNoBias::load(
+                vb.pp("final_norm"),
+                config.hidden_size,
+                config.norm_eps as f32,
+            )
+        })?;

         let rotary_dim = config.hidden_size / config.num_attention_heads;
         let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
@@ -295,15 +303,11 @@ impl FlashModernBertModel {
                 config.global_rope_theta
             };

-            let max_position_embeddings = if use_local_attention {
-                config.max_position_embeddings
-            } else {
-                config.local_attention
-            };
+            let max_position_embeddings = config.max_position_embeddings;

             let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

-            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
+            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), false)?;

             rotary_cache.insert(use_local_attention, (cos, sin));
         }
@@ -314,7 +318,6 @@ impl FlashModernBertModel {
             final_norm,
             pool,
             classifier,
-            rotary_dim,
             rotary_cache,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -343,9 +346,6 @@ impl FlashModernBertModel {
             let cos = cos.index_select(&position_ids, 0)?;
             let sin = sin.index_select(&position_ids, 0)?;

-            let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
-            let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
-
             rotary_cache.insert(use_local_attention, (cos, sin));
         }

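The loader above also gains a fallback: weights are first looked up under the "model." prefix and, failing that, under the bare prefix, so checkpoints exported either way still load. A small sketch of the same or_else pattern, with a plain candle_nn::linear layer and a hypothetical load_with_fallback helper standing in for the real ModernBert sub-modules:

use candle::Result;
use candle_nn::{linear, Linear, VarBuilder};

// Try the "model."-prefixed tensor names first, then fall back to the
// un-prefixed names, as the ModernBert loader above does for its sub-modules.
fn load_with_fallback(vb: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
    linear(in_dim, out_dim, vb.pp("model.dense"))
        .or_else(|_| linear(in_dim, out_dim, vb.pp("dense")))
}
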
backends/candle/src/models/modernbert.rs

Lines changed: 5 additions & 1 deletion
@@ -578,7 +578,10 @@ impl ModernBertModel {
     }

     fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
-        let attention_mask = attention_mask.to_dtype(DType::U8)?;
+        let dev = attention_mask.device();
+        let attention_mask = attention_mask
+            .to_device(&Device::Cpu)?
+            .to_dtype(DType::U8)?;

         let mask_shape = attention_mask.shape();
         let (_, _, seq_len, _) = mask_shape.dims4()?;
@@ -597,6 +600,7 @@ impl ModernBertModel {

         let zero_tensor = Tensor::zeros_like(&attention_mask)?;
         let local_attention_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;
+        let local_attention_mask = local_attention_mask.to_device(dev)?;

         Ok(local_attention_mask)
     }
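
The windowed-mask fix above builds the local attention mask on the CPU and only ships the finished tensor back to the input's device. A minimal sketch of that round trip, assuming a hypothetical precomputed window_mask of the same shape as the attention mask:

use candle::{DType, Device, Result, Tensor};

// Build the local mask on CPU, then return it on the input's original device.
fn apply_window_on_cpu(attention_mask: &Tensor, window_mask: &Tensor) -> Result<Tensor> {
    let dev = attention_mask.device();

    let mask = attention_mask.to_device(&Device::Cpu)?.to_dtype(DType::U8)?;
    let window = window_mask.to_device(&Device::Cpu)?.to_dtype(DType::U8)?;

    // Keep mask values where the window allows them, zero elsewhere.
    let zeros = Tensor::zeros_like(&mask)?;
    let local = mask.where_cond(&window, &zeros)?;

    // Ship only the final tensor back to where the input lives.
    local.to_device(dev)
}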

backends/grpc-client/build.rs

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/embed.proto");
+    println!("cargo:rerun-if-changed=../proto/embed.proto");
     fs::create_dir("src/pb").unwrap_or(());

     let mut config = prost_build::Config::new();

flake.nix

Lines changed: 2 additions & 0 deletions
@@ -220,6 +220,8 @@
             LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";
             LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";
             CUDA_ROOT = "${pkgs.cudaPackages.cudatoolkit}";
+            CANDLE_FLASH_ATTN_BUILD_DIR = "./kernels";
+            CANDLE_LAYER_NORM_BUILD_DIR = "./kernels";
           };
         }
       );
