Skip to content

Commit 4adfe71

Browse files
feat: support jinaAI variant (#48)
1 parent e07f68a commit 4adfe71

File tree

5 files changed

+806
-92
lines changed

5 files changed

+806
-92
lines changed

backends/candle/src/alibi.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// coding=utf-8
2+
// Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3+
// Copyright (c) 2023 Jina AI GmbH. All rights reserved.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
use candle::{DType, Device, Result, Tensor};
17+
18+
/// Slopes for `n` ALiBi attention heads when `n` is a power of two.
///
/// Returns the geometric sequence `start, start^2, ..., start^n` with
/// `start = 2^(-2^(-(log2(n) - 3)))`, matching the reference ALiBi
/// implementation.
fn get_slopes_power_of_2(n: usize) -> Vec<f64> {
    let exponent = -2_f64.powf(-((n as f64).log2() - 3_f64));
    let start = 2_f64.powf(exponent);

    let mut slopes = Vec::with_capacity(n);
    for i in 0..n {
        slopes.push(start * start.powi(i as i32));
    }
    slopes
}

/// Per-head ALiBi slopes for an arbitrary head count.
///
/// For a power-of-two head count the closed-form sequence is used directly.
/// Otherwise, the slopes of the nearest smaller power of two are extended
/// with every other slope from the next larger power of two, exactly as in
/// the reference implementation.
fn alibi_head_slopes(num_attention_heads: usize) -> Vec<f64> {
    let log_heads = (num_attention_heads as f64).log2();

    if log_heads.fract() == 0.0 {
        // `num_attention_heads` is a power of 2.
        return get_slopes_power_of_2(num_attention_heads);
    }

    let closest_power_of_2 = 2_f64.powi(log_heads.floor() as i32) as usize;
    let mut slopes = get_slopes_power_of_2(closest_power_of_2);

    // Number of extra heads beyond the closest smaller power of two; always
    // fewer than `closest_power_of_2`, so `take` below never runs short.
    let extra_needed = num_attention_heads - closest_power_of_2;
    let extra = get_slopes_power_of_2(2 * closest_power_of_2)
        .into_iter()
        // Keep only even-indexed slopes of the doubled sequence.
        .step_by(2)
        .take(extra_needed);

    slopes.extend(extra);
    slopes
}
48+
49+
pub fn build_alibi_tensor(
50+
num_positions: usize,
51+
num_heads: usize,
52+
device: &Device,
53+
dtype: DType,
54+
) -> Result<Tensor> {
55+
let context_positions = Tensor::arange(0.0, num_positions as f64, device)?.unsqueeze(1)?;
56+
let memory_positions = Tensor::arange(0.0, num_positions as f64, device)?.unsqueeze(0)?;
57+
58+
let relative_positions = memory_positions.broadcast_sub(&context_positions)?.abs()?;
59+
// [num_heads, num_positions, num_positions]
60+
let relative_positions =
61+
relative_positions
62+
.unsqueeze(0)?
63+
.expand((num_heads, num_positions, num_positions))?;
64+
65+
// [num_heads, 1, 1]
66+
let slopes =
67+
(Tensor::from_vec(alibi_head_slopes(num_heads), (num_heads, 1, 1), device)? * -1_f64)?;
68+
69+
// [num_heads, num_positions, num_positions]
70+
let alibi = relative_positions.broadcast_mul(&slopes)?;
71+
72+
alibi
73+
.reshape((1, num_heads, num_positions, num_positions))?
74+
.to_dtype(dtype)
75+
}

backends/candle/src/lib.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod alibi;
12
#[cfg(feature = "cuda")]
23
mod compute_cap;
34
#[cfg(feature = "cuda")]
@@ -9,7 +10,9 @@ mod models;
910
use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP};
1011
#[cfg(feature = "cuda")]
1112
use crate::models::FlashBertModel;
12-
use crate::models::{BertModel, EmbeddingModel, PositionEmbeddingType, QuantBertModel};
13+
use crate::models::{
14+
BertModel, EmbeddingModel, JinaBertModel, PositionEmbeddingType, QuantBertModel,
15+
};
1316
use candle::{DType, Device};
1417
use candle_nn::VarBuilder;
1518
use models::Config;
@@ -47,8 +50,6 @@ impl CandleBackend {
4750

4851
let model: Box<dyn EmbeddingModel + Send> = match device {
4952
Device::Cpu => {
50-
tracing::info!("Starting Bert model on CPU");
51-
5253
if &dtype == "float32" || &dtype == "float16" {
5354
let dtype = if &dtype == "float32" {
5455
DType::F32
@@ -70,14 +71,21 @@ impl CandleBackend {
7071
}
7172
.s()?;
7273

73-
Box::new(BertModel::load(vb, &config, pool).s()?)
74+
if config.position_embedding_type == PositionEmbeddingType::Alibi {
75+
tracing::info!("Starting JinaBert model on CPU");
76+
Box::new(JinaBertModel::load(vb, &config, pool).s()?)
77+
} else {
78+
tracing::info!("Starting Bert model on CPU");
79+
Box::new(BertModel::load(vb, &config, pool).s()?)
80+
}
7481
} else if &dtype == "q6k" {
7582
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
7683
model_path.join("ggml-model-q6k.bin"),
7784
)
7885
.map_err(|err| BackendError::Start(err.to_string()))?;
7986
tracing::info!("vb");
8087

88+
tracing::info!("Starting QuantBert model on CPU");
8189
Box::new(QuantBertModel::load(vb, &config, pool).s()?)
8290
} else {
8391
return Err(BackendError::Start(format!(
@@ -130,6 +138,9 @@ impl CandleBackend {
130138
{
131139
tracing::info!("Starting FlashBert model on Cuda");
132140
Box::new(FlashBertModel::load(vb, &config, pool).s()?)
141+
} else if config.position_embedding_type == PositionEmbeddingType::Alibi {
142+
tracing::info!("Starting JinaBert model on Cuda");
143+
Box::new(JinaBertModel::load(vb, &config, pool).s()?)
133144
} else {
134145
tracing::info!("Starting Bert model on Cuda");
135146
Box::new(BertModel::load(vb, &config, pool).s()?)

backends/candle/src/models.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ mod bert_quant;
1010
pub use bert::{BertModel, Config, PositionEmbeddingType};
1111
pub use bert_quant::QuantBertModel;
1212
use candle::{Result, Tensor};
13+
pub use jina::JinaBertModel;
1314
use text_embeddings_backend_core::Batch;
1415

1516
#[cfg(feature = "cuda")]
1617
mod flash_bert;
18+
mod jina;
19+
1720
#[cfg(feature = "cuda")]
1821
pub use flash_bert::FlashBertModel;
1922

0 commit comments

Comments
 (0)