
Commit 9941fcc

style
1 parent cfbaf65 commit 9941fcc


21 files changed: +103 -197 lines changed


.github/workflows/build_rocm.yaml

Lines changed: 3 additions & 3 deletions
@@ -79,7 +79,7 @@
             type=semver,pattern=rocm-{{major}}.{{minor}}
             type=raw,value=rocm-latest
             type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}
-
+
       - name: Build and push Docker image
         id: build-and-push-rocm
         uses: docker/build-push-action@v4
@@ -98,7 +98,7 @@
           labels: ${{ steps.meta-rocm.outputs.labels }}
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
-
+
       - name: Extract metadata (tags, labels) for Docker
         id: meta-rocm-grpc
         uses: docker/metadata-action@v4.3.0
@@ -113,7 +113,7 @@
             type=semver,pattern=rocm-{{major}}.{{minor}}-grpc
             type=raw,value=rocm-latest-grpc
             type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc
-
+
       - name: Build and push Docker image
         id: build-and-push-rocm-grpc
         uses: docker/build-push-action@v4

backends/candle/src/lib.rs

Lines changed: 32 additions & 22 deletions
@@ -11,17 +11,17 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel,
-    Model, NomicBertModel, NomicConfig,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
+    JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
+    FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
 use candle_nn::VarBuilder;
-use models::BertConfig;
 use nohash_hasher::BuildNoHashHasher;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -133,7 +133,9 @@ impl CandleBackend {
             }
             (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+                Ok(Box::new(
+                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                ))
             }
             (
                 Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
@@ -171,8 +173,9 @@ impl CandleBackend {
                     Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
                 }
             }
-                #[cfg(feature = "cuda")]
-                (Config::JinaBert(config), Device::Cuda(_)) => {
+            }
+            #[cfg(feature = "cuda")]
+            (Config::JinaBert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                     && dtype == DType::F16
                     && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
@@ -181,25 +184,32 @@ impl CandleBackend {
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
                 {
                     tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                    Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
+                    Ok(Box::new(
+                        FlashJinaBertModel::load(vb, &config, model_type).s()?,
+                    ))
                 } else {
                     tracing::info!("Starting JinaBertModel model on {:?}", device);
                     Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
                 }
-                #[cfg(feature = "cuda")]
-                (Config::JinaCodeBert(config), Device::Cuda(_)) => {
-                    if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                        && dtype == DType::F16
-                        && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                        // Allow disabling because of flash attention v1 precision problems
-                        // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                        && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                    {
-                        tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
-                        Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
-                    } else {
-                        tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                        Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+            }
+            #[cfg(feature = "cuda")]
+            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
+                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                    && dtype == DType::F16
+                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
+                    // Allow disabling because of flash attention v1 precision problems
+                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                {
+                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
+                    Ok(Box::new(
+                        FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                } else {
+                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
+                    Ok(Box::new(
+                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
                 }
             }
             #[cfg(feature = "cuda")]
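
Every CUDA arm in this match gates the flash-attention path on the same condition: a flash-attn feature compiled in, f16 weights, absolute or ALiBi position embeddings, and the USE_FLASH_ATTENTION escape hatch left at its default. A minimal standalone sketch of that check, with a local stand-in for the backend's PositionEmbeddingType enum and an illustrative function name (neither is part of this commit):

use candle::DType;

// Stand-in for the backend's enum, for illustration only.
#[derive(PartialEq)]
enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// True when the flash-attention branch would be taken.
fn use_flash_attention(dtype: DType, pos: &PositionEmbeddingType) -> bool {
    // A flash-attn feature must be compiled in,
    cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
        // the model must run in f16,
        && dtype == DType::F16
        // with absolute or ALiBi position embeddings,
        && (*pos == PositionEmbeddingType::Absolute || *pos == PositionEmbeddingType::Alibi)
        // and the opt-out for flash-attn v1 precision issues must not be set.
        // See: https://github.com/huggingface/text-embeddings-inference/issues/37
        && std::env::var("USE_FLASH_ATTENTION")
            .unwrap_or_else(|_| "True".to_string())
            .to_lowercase()
            == "true"
}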

backends/candle/src/models.rs

Lines changed: 3 additions & 3 deletions
@@ -7,6 +7,7 @@ extern crate accelerate_src;
 mod bert;
 mod distilbert;
 mod jina;
+mod jina_code;
 mod nomic;
 
 #[cfg(feature = "cuda")]
@@ -27,8 +28,8 @@ mod flash_distilbert;
 pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
 use candle::{Result, Tensor};
 pub use distilbert::{DistilBertConfig, DistilBertModel};
-pub use jina::{JinaConfig, JinaBertModel};
-pub use jina_code::{JinaCodeConfig, JinaCodeBertModel};
+pub use jina::{JinaBertModel, JinaConfig};
+pub use jina_code::{JinaCodeBertModel, JinaCodeConfig};
 pub use nomic::{NomicBertModel, NomicConfig};
 use text_embeddings_backend_core::Batch;
 
@@ -41,7 +42,6 @@ pub use flash_jina::FlashJinaBertModel;
 #[cfg(feature = "cuda")]
 pub use flash_jina_code::FlashJinaCodeBertModel;
 
-
 #[cfg(feature = "cuda")]
 pub use flash_nomic::FlashNomicBertModel;
 

backends/candle/src/models/flash_jina.rs

Lines changed: 1 addition & 1 deletion
@@ -2,8 +2,8 @@ use crate::alibi::alibi_head_slopes;
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
 use crate::models::bert::PositionEmbeddingType;
-use crate::models::jina::{JinaConfig, BertEmbeddings};
 use crate::models::jina::BertEmbeddings;
+use crate::models::jina::{BertEmbeddings, JinaConfig};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;

backends/candle/src/models/flash_jina_code.rs

Lines changed: 23 additions & 8 deletions
@@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes;
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
 use crate::models::bert::PositionEmbeddingType;
-use crate::models::jina::{JinaCodeConfig, BertEmbeddings};
+use crate::models::jina::{BertEmbeddings, JinaCodeConfig};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
@@ -28,7 +28,11 @@ struct AlibiBertAttention {
 }
 
 impl AlibiBertAttention {
-    pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi_slopes: Option<Tensor>) -> Result<Self> {
+    pub fn load(
+        vb: VarBuilder,
+        config: &JinaCodeConfig,
+        alibi_slopes: Option<Tensor>,
+    ) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
         let hidden_size = config.hidden_size;
@@ -116,9 +120,15 @@ impl AlibiBertAttention {
         new_qkv_shape.push(self.num_attention_heads);
         new_qkv_shape.push(self.attention_head_size);
 
-        let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+        let query_layer = query_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let key_layer = key_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let value_layer = value_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
 
         let attention = flash_attn_varlen(
             query_layer,
@@ -135,7 +145,9 @@ impl AlibiBertAttention {
         let attention = attention.flatten_from(candle::D::Minus2)?;
 
         let hidden_states = self.dense.forward(&attention)?;
-        let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?;
+        let hidden_states = self
+            .layer_norm_out
+            .forward(&hidden_states, Some(&residual))?;
 
         Ok(hidden_states)
     }
@@ -168,7 +180,10 @@ impl JinaBertLayer {
             .pp("mlp")
             .pp("down_layer")
             .get((config.hidden_size, config.intermediate_size), "weight")?;
-        let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?;
+        let down_bias = vb
+            .pp("mlp")
+            .pp("down_layer")
+            .get(config.hidden_size, "bias")?;
         let down_layer = Linear::new(down_weight, Some(down_bias), None);
 
         let layer_norm_1 = LayerNorm::load(
@@ -455,4 +470,4 @@ impl Model for FlashJinaCodeBertModel {
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
-}
+}
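
The reshape-then-transpose chains above are the usual multi-head split: a packed [tokens, num_heads * head_size] projection is viewed as [tokens, num_heads, head_size] and the head axis is moved in front of the token axis (the model builds its own new_qkv_shape vector and therefore transposes axes 1 and 2; the exact indices depend on that layout). A small self-contained candle sketch of the same idea, with illustrative sizes rather than the model's real dimensions:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Illustrative sizes only: 6 tokens, 4 heads of size 8.
    let (tokens, num_heads, head_size) = (6usize, 4usize, 8usize);
    let q = Tensor::zeros((tokens, num_heads * head_size), DType::F32, &Device::Cpu)?;

    // [tokens, hidden] -> [tokens, heads, head_size] -> [heads, tokens, head_size]
    let q_per_head = q
        .reshape((tokens, num_heads, head_size))?
        .transpose(0, 1)?;

    assert_eq!(q_per_head.dims(), &[num_heads, tokens, head_size]);
    Ok(())
}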

backends/candle/src/models/jina.rs

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ pub struct JinaConfig {
     pub id2label: Option<HashMap<String, String>>,
 }
 
-
 #[derive(Debug)]
 pub struct BertEmbeddings {
     word_embeddings: Embedding,

backends/candle/src/models/jina_code.rs

Lines changed: 16 additions & 6 deletions
@@ -30,7 +30,6 @@ pub struct JinaCodeConfig {
     pub id2label: Option<HashMap<String, String>>,
 }
 
-
 #[derive(Debug)]
 pub struct BertEmbeddings {
     word_embeddings: Embedding,
@@ -201,9 +200,15 @@ impl BertAttention {
         new_qkv_shape.push(self.num_attention_heads);
         new_qkv_shape.push(self.attention_head_size);
 
-        let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+        let query_layer = query_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let key_layer = key_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let value_layer = value_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
 
         #[allow(unused_variables)]
         let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
@@ -276,7 +281,9 @@ impl BertAttention {
         let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
 
         let hidden_states = self.dense.forward(&context_layer)?;
-        let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?;
+        let hidden_states = self
+            .layer_norm_out
+            .forward(&hidden_states, Some(&residual))?;
 
         Ok(hidden_states)
     }
@@ -309,7 +316,10 @@ impl JinaCodeBertLayer {
             .pp("mlp")
             .pp("down_layer")
            .get((config.hidden_size, config.intermediate_size), "weight")?;
-        let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?;
+        let down_bias = vb
+            .pp("mlp")
+            .pp("down_layer")
+            .get(config.hidden_size, "bias")?;
         let down_layer = Linear::new(down_weight, Some(down_bias), None);
 
         let layer_norm_1 = LayerNorm::load(
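
For context, layer_norm_out.forward(&hidden_states, Some(&residual)) follows the standard post-attention pattern of adding the residual before normalizing, which is also what the Python backend's layernorm.py in this same commit does. A rough self-contained sketch of that pattern in candle; this is not the crate's LayerNorm implementation, just the general shape of it:

use candle::{Result, Tensor, D};

// Add the residual, then layer-normalize over the last axis.
// `weight`, `bias`, and `eps` play their usual LayerNorm roles.
fn layer_norm_with_residual(
    hidden: &Tensor,
    residual: &Tensor,
    weight: &Tensor,
    bias: &Tensor,
    eps: f64,
) -> Result<Tensor> {
    let x = (hidden + residual)?;
    let mean = x.mean_keepdim(D::Minus1)?;
    let centered = x.broadcast_sub(&mean)?;
    let var = centered.sqr()?.mean_keepdim(D::Minus1)?;
    let normed = centered.broadcast_div(&(var + eps)?.sqrt()?)?;
    normed.broadcast_mul(weight)?.broadcast_add(bias)
}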

backends/python/server/text_embeddings_server/layers/attention/rocm.py

Lines changed: 1 addition & 1 deletion
@@ -42,4 +42,4 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
         is_causal,
         False,
         None,
-    )
+    )

backends/python/server/text_embeddings_server/layers/layernorm.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
         self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
         self.variance_epsilon = config.layer_norm_eps
-
+
     def forward(self, hidden_states, residual=None):
         if residual is not None:
             hidden_states += residual
@@ -51,4 +51,4 @@ def forward(self, hidden_states, residual=None):
 
             return hidden_states, residual
         else:
-            raise ValueError("System not recognized")
+            raise ValueError("System not recognized")

backends/python/server/text_embeddings_server/layers/pooling.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ def mean_pooling(embedding, cu_seqlens, max_s):
     indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
 
     embedding_padded = pad_input(embedding, indices, batch_size, max_s)
-
+
     sum_embeddings = torch.sum(embedding_padded, 1)
 
-    return sum_embeddings / seqlens[:, None]
+    return sum_embeddings / seqlens[:, None]
