Commit 0462171

feat: Implement GTE model to support the non-flash-attn version (#446)
Co-authored-by: Hyeongchan Kim <kozistr@gmail.com>
1 parent e27a4fb commit 0462171

13 files changed: +3964 -203 lines


backends/candle/src/layers/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -4,9 +4,11 @@ mod layer_norm;
 mod linear;
 #[allow(dead_code, unused)]
 mod rms_norm;
+mod rotary;
 
 pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::LayerNorm;
 pub use linear::{HiddenAct, Linear};
 #[allow(unused_imports)]
 pub use rms_norm::RMSNorm;
+pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};

backends/candle/src/layers/rotary.rs

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+use candle::{DType, Device, Result, Tensor, D};
+use serde::Deserialize;
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct NTKScaling {
+    pub factor: f32,
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[serde(tag = "type", rename_all = "kebab-case")]
+pub enum RopeScaling {
+    Ntk(NTKScaling),
+}
+
+pub fn get_inv_freqs(
+    dim: usize,
+    base: f32,
+    device: &Device,
+    rope_scaling: Option<&RopeScaling>,
+) -> Result<Tensor> {
+    let get_inv_freqs_inner = |dim: usize, base: f32, device: &Device| {
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / base.powf(i as f32 / dim as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
+    };
+
+    if let Some(rope_scaling) = rope_scaling {
+        match rope_scaling {
+            RopeScaling::Ntk(ntk_scaling) => {
+                let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?;
+                let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64;
+                return inv_freqs / s;
+            }
+        }
+    }
+    get_inv_freqs_inner(dim, base, device)
+}
+
+pub fn get_cos_sin(
+    length: usize,
+    inv_freqs: &Tensor,
+    dtype: DType,
+    repeat_freqs: bool,
+) -> Result<(Tensor, Tensor)> {
+    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
+        .to_dtype(DType::F32)?
+        .reshape((length, 1))?;
+    let mut freqs = t.matmul(inv_freqs)?;
+    if repeat_freqs {
+        freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
+    }
+
+    let cos = freqs.cos()?.to_dtype(dtype)?;
+    let sin = freqs.sin()?.to_dtype(dtype)?;
+    Ok((cos, sin))
+}
+
+pub fn apply_rotary(
+    x: &Tensor,
+    cos: &Tensor,
+    sin: &Tensor,
+    attention_head_size: usize,
+) -> Result<Tensor> {
+    let dim = attention_head_size / 2;
+    let x1 = x.narrow(D::Minus1, 0, dim)?;
+    let x2 = x.narrow(D::Minus1, dim, dim)?;
+    let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+    let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
+    Ok(rope)
+}
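Note: the non-flash GTE path introduced in this commit consumes these helpers through apply_rotary (rotate-half) rather than the in-place flash kernel. A minimal sketch of how the three functions compose; the sizes, the standalone function, and the use of repeat_freqs = true are illustrative assumptions, not code from this commit:

use candle::{DType, Device, Result, Tensor};
// Assumed to be in scope inside the crate:
// use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs};

fn rotary_sketch() -> Result<Tensor> {
    let device = Device::Cpu;
    let head_size = 64; // made-up head size
    let seq_len = 8;    // made-up sequence length

    // No rope_scaling -> plain 1 / base^(2i/d) frequencies, shape (1, head_size / 2).
    let inv_freqs = get_inv_freqs(head_size, 10_000.0, &device, None)?;

    // repeat_freqs = true duplicates the frequencies so cos/sin have width head_size,
    // which is what the rotate-half apply_rotary expects when broadcasting against x.
    let (cos, sin) = get_cos_sin(seq_len, &inv_freqs, DType::F32, true)?;

    // A fake (seq_len, head_size) query slice.
    let x = Tensor::randn(0f32, 1f32, (seq_len, head_size), &device)?;
    apply_rotary(&x, &cos, &sin, head_size)
}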

backends/candle/src/lib.rs

Lines changed: 10 additions & 8 deletions
@@ -11,7 +11,7 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
     JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
@@ -218,10 +218,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
-            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
-                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
-                    .to_string(),
-            )),
+            (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting GTE model on {:?}", device);
+                Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+            }
             (Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
@@ -349,10 +349,12 @@ impl CandleBackend {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
+                    tracing::info!("Starting GTE model on {:?}", device);
+                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+                } else {
+                    tracing::info!("Starting FlashGTE model on {:?}", device);
+                    Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
                 }
-                tracing::info!("Starting FlashGTE model on {:?}", device);
-                Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
             }
             #[cfg(feature = "cuda")]
             (Config::Qwen2(config), Device::Cuda(_)) => {
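Note: the net effect of these two hunks is that CPU and Metal devices now load the plain GTEModel instead of returning a start-up error, and the CUDA arm only keeps FlashGTEModel when the backend was built with a flash-attn feature and runs in fp16. A condensed, illustrative restatement of that selection rule (not the actual backend code):

// Illustrative only: collapses the device / dtype / feature checks into booleans.
fn selects_flash_gte(on_cuda: bool, is_f16: bool, flash_attn_built: bool) -> bool {
    on_cuda && is_f16 && flash_attn_built
}

fn main() {
    // CPU/Metal (or a CUDA build without flash-attn, or a non-fp16 dtype)
    // now falls back to the plain GTEModel instead of erroring.
    assert!(!selects_flash_gte(false, true, true));
    // CUDA + fp16 + flash-attn feature keeps the FlashGTEModel path.
    assert!(selects_flash_gte(true, true, true));
}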

backends/candle/src/models/flash_gte.rs

Lines changed: 18 additions & 127 deletions
@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct GTEAttention {
@@ -72,7 +73,7 @@ impl GTEAttention {
         let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
         let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -93,60 +94,7 @@ impl GTEAttention {
     }
 }
 
-struct GTEMLP {
-    up_gate_proj: Linear,
-    down_proj: Linear,
-
-    act: HiddenAct,
-    intermediate_size: usize,
-
-    span: tracing::Span,
-}
-
-impl GTEMLP {
-    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let intermediate_size = config.intermediate_size;
-
-        let up_gate_proj_weight = vb
-            .pp("up_gate_proj")
-            .get((intermediate_size * 2, config.hidden_size), "weight")?;
-
-        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
-
-        let down_proj_weight = vb
-            .pp("down_proj")
-            .get((config.hidden_size, intermediate_size), "weight")?;
-        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
-        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
-
-        Ok(Self {
-            up_gate_proj,
-            down_proj,
-            intermediate_size,
-            act: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "mlp"),
-        })
-    }
-
-    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
-        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
-        let gate_states =
-            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
-    }
-}
-
-struct GTELayer {
+pub struct GTELayer {
     attention: GTEAttention,
     mlp: GTEMLP,
     attention_layer_norm: LayerNorm,
@@ -198,58 +146,6 @@ impl GTELayer {
     }
 }
 
-pub struct GTEClassificationHead {
-    pooler: Option<Linear>,
-    classifier: Linear,
-    span: tracing::Span,
-}
-
-impl GTEClassificationHead {
-    #[allow(dead_code)]
-    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let n_classes = match &config.id2label {
-            None => candle::bail!("`id2label` must be set for classifier models"),
-            Some(id2label) => id2label.len(),
-        };
-
-        let pooler = if let Ok(pooler_weight) = vb
-            .pp("pooler.dense")
-            .get((config.hidden_size, config.hidden_size), "weight")
-        {
-            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
-            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
-        } else {
-            None
-        };
-
-        let classifier_weight = vb
-            .pp("classifier")
-            .get((n_classes, config.hidden_size), "weight")?;
-        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
-        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
-
-        Ok(Self {
-            classifier,
-            pooler,
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-
-    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let mut hidden_states = hidden_states.unsqueeze(1)?;
-        if let Some(pooler) = self.pooler.as_ref() {
-            hidden_states = pooler.forward(&hidden_states)?;
-            hidden_states = hidden_states.tanh()?;
-        }
-
-        let hidden_states = self.classifier.forward(&hidden_states)?;
-        let hidden_states = hidden_states.squeeze(1)?;
-        Ok(hidden_states)
-    }
-}
-
 pub struct FlashGTEModel {
     word_embeddings: Embedding,
     token_type_embeddings: Option<Embedding>,
@@ -322,24 +218,19 @@ impl FlashGTEModel {
             config.layer_norm_eps,
         )?;
 
-        let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
-            let inv_freqs = candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta * factor,
-                vb.device(),
-            )?;
-            let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
-            inv_freqs / s
-        } else {
-            candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta,
-                vb.device(),
-            )
-        }?;
-
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(
+            layers[0].attention.attention_head_size,
+            config.rope_theta,
+            vb.device(),
+            config.rope_scaling.as_ref(),
+        )?;
+
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
+        )?;
 
         Ok(Self {
             word_embeddings,

backends/candle/src/models/flash_mistral.rs

Lines changed: 11 additions & 5 deletions
@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, Linear, RMSNorm};
+use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
 use crate::models::{MistralConfig, Model};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct MistralAttention {
@@ -90,7 +91,7 @@ impl MistralAttention {
             self.num_key_value_heads,
         )?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -267,13 +268,18 @@ impl FlashMistralModel {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
-        let inv_freqs = candle_rotary::inv_freqs(
+        let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
             config.rope_theta,
            vb.device(),
+            None,
+        )?;
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
         )?;
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
 
         Ok(Self {
             embeddings,

backends/candle/src/models/flash_nomic.rs

Lines changed: 8 additions & 6 deletions
@@ -1,9 +1,10 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{LayerNorm, Linear};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
 use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
 use crate::models::{Model, NomicConfig};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::VarBuilder;
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct NomicAttention {
@@ -68,7 +69,7 @@ impl NomicAttention {
         let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
         let qkv = qkv.chunk(3, 1)?;
 
-        candle_rotary::apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;
+        apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &qkv[0],
@@ -221,20 +222,21 @@ impl FlashNomicBertModel {
         let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?;
 
         let rotary_dim = encoder.layers[0].attention.attention_head_size;
-        let inv_freqs = candle_rotary::inv_freqs(rotary_dim, config.rotary_emb_base, vb.device())?;
-        let rotary_cache = candle_rotary::cos_sin(config.n_positions, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
+        let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs, vb.dtype(), false)?;
 
         let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
             let new_base = (config.rotary_emb_base
                 * ((scaling_factor * config.n_positions as f32
                     / config.max_trained_positions as f32)
                     - (scaling_factor - 1.0)))
                 .powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
-            let inv_freqs = candle_rotary::inv_freqs(rotary_dim, new_base, vb.device())?;
-            Some(candle_rotary::cos_sin(
+            let inv_freqs = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
+            Some(get_cos_sin(
                 config.n_positions,
                 &inv_freqs,
                 vb.dtype(),
+                false,
             )?)
         } else {
            None