
Commit 7c4f67e

feat: support multiple backends at the same time (#440)
1 parent 76b29f1 · commit 7c4f67e

File tree

9 files changed: +1085 −739 lines changed


Cargo.lock

Lines changed: 921 additions & 639 deletions
Some generated files are not rendered by default.

Dockerfile

Lines changed: 33 additions & 4 deletions
@@ -28,9 +28,22 @@ ARG ACTIONS_CACHE_URL
 ARG ACTIONS_RUNTIME_TOKEN
 ARG SCCACHE_GHA_ENABLED

+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
+    echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
+    tee /etc/apt/sources.list.d/oneAPI.list
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    intel-oneapi-mkl-devel=2024.0.0-49656 \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
+    gcc -shared -fPIC -o libfakeintel.so fakeintel.c
+
 COPY --from=planner /usr/src/recipe.json recipe.json

-RUN cargo chef cook --release --features ort --no-default-features --recipe-path recipe.json && sccache -s
+RUN cargo chef cook --release --features ort --features candle --features mkl-dynamic --no-default-features --recipe-path recipe.json && sccache -s

 COPY backends backends
 COPY core core
@@ -40,7 +53,7 @@ COPY Cargo.lock ./

 FROM builder AS http-builder

-RUN cargo build --release --bin text-embeddings-router -F ort -F http --no-default-features && sccache -s
+RUN cargo build --release --bin text-embeddings-router -F ort -F candle -F mkl-dynamic -F http --no-default-features && sccache -s

 FROM builder AS grpc-builder

@@ -52,19 +65,35 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \

 COPY proto proto

-RUN cargo build --release --bin text-embeddings-router -F grpc -F ort --no-default-features && sccache -s
+RUN cargo build --release --bin text-embeddings-router -F grpc -F ort -F candle -F mkl-dynamic --no-default-features && sccache -s

 FROM debian:bookworm-slim AS base

 ENV HUGGINGFACE_HUB_CACHE=/data \
-    PORT=80
+    PORT=80 \
+    MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
+    RAYON_NUM_THREADS=8 \
+    LD_PRELOAD=/usr/local/libfakeintel.so \
+    LD_LIBRARY_PATH=/usr/local/lib

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    libomp-dev \
     ca-certificates \
     libssl-dev \
     curl \
     && rm -rf /var/lib/apt/lists/*

+# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
+COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
+COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so

 FROM base AS grpc
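
A note on the shim built above: MKL's runtime dispatcher falls back to slower generic kernels when it does not detect a GenuineIntel CPU. The one-line `mkl_serv_intel_cpu_true()` override, injected through the `LD_PRELOAD` set in the `base` stage, forces the optimized code paths on non-Intel hardware; that is also why the matching MKL shared objects have to be copied into the runtime image by hand.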

backends/candle/src/lib.rs

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ use candle_nn::VarBuilder;
 use nohash_hasher::BuildNoHashHasher;
 use serde::Deserialize;
 use std::collections::HashMap;
-use std::path::PathBuf;
+use std::path::Path;
 use text_embeddings_backend_core::{
     Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
@@ -69,7 +69,7 @@ pub struct CandleBackend {

 impl CandleBackend {
     pub fn new(
-        model_path: PathBuf,
+        model_path: &Path,
         dtype: String,
         model_type: ModelType,
     ) -> Result<Self, BackendError> {
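
Taking `&Path` instead of an owned `PathBuf` is what lets the router offer the same model directory to several backends in turn. A minimal sketch of a caller, assuming the workspace's `text-embeddings-backend-candle` and `text-embeddings-backend-core` crates are in scope (the helper function itself is hypothetical):

use std::path::Path;

use text_embeddings_backend_candle::CandleBackend;
use text_embeddings_backend_core::{BackendError, ModelType, Pool};

// Hypothetical helper: `CandleBackend::new` only borrows the path, so the
// caller keeps ownership and can hand the same `&Path` to another backend
// if this one fails to start.
fn try_candle(model_path: &Path) -> Result<CandleBackend, BackendError> {
    CandleBackend::new(
        model_path,
        "float32".to_string(),
        ModelType::Embedding(Pool::Mean),
    )
}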

backends/core/src/lib.rs

Lines changed: 2 additions & 0 deletions
@@ -88,4 +88,6 @@ pub enum BackendError {
     Inference(String),
     #[error("Backend is unhealthy")]
     Unhealthy,
+    #[error("Weights not found: {0}")]
+    WeightsNotFound(String),
 }
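
The new `WeightsNotFound` variant gives failed weight downloads a typed error instead of the panic the removed `download_weights` helper used. A minimal sketch of how the call sites in backends/src/lib.rs use it (the helper name here is hypothetical):

use text_embeddings_backend_core::BackendError;

// Hypothetical helper mirroring the call sites in backends/src/lib.rs:
// any displayable download error becomes a typed BackendError.
fn weights_error(err: impl std::fmt::Display) -> BackendError {
    BackendError::WeightsNotFound(err.to_string())
}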

backends/ort/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ homepage.workspace = true
 [dependencies]
 anyhow = { workspace = true }
 nohash-hasher = { workspace = true }
-ndarray = "0.15.6"
+ndarray = "0.16.1"
 num_cpus = { workspace = true }
 ort = { version = "2.0.0-rc.4", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] }
 text-embeddings-backend-core = { path = "../core" }

backends/ort/src/lib.rs

Lines changed: 5 additions & 4 deletions
@@ -1,9 +1,9 @@
 use ndarray::{s, Axis};
 use nohash_hasher::BuildNoHashHasher;
-use ort::{GraphOptimizationLevel, Session};
+use ort::session::{builder::GraphOptimizationLevel, Session};
 use std::collections::HashMap;
 use std::ops::{Div, Mul};
-use std::path::PathBuf;
+use std::path::Path;
 use text_embeddings_backend_core::{
     Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
 };
@@ -16,12 +16,12 @@ pub struct OrtBackend {

 impl OrtBackend {
     pub fn new(
-        model_path: PathBuf,
+        model_path: &Path,
         dtype: String,
         model_type: ModelType,
     ) -> Result<Self, BackendError> {
         // Check dtype
-        if &dtype == "float32" {
+        if dtype == "float32" {
         } else {
             return Err(BackendError::Start(format!(
                 "DType {dtype} is not supported"
@@ -246,6 +246,7 @@ impl Backend for OrtBackend {
         if has_raw_requests {
             // Reshape outputs
             let s = outputs.shape().to_vec();
+            #[allow(deprecated)]
             let outputs = outputs.into_shape((s[0] * s[1], s[2])).e()?;

             // We need to remove the padding tokens only if batch_size > 1 and there are some
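
The `#[allow(deprecated)]` above is a consequence of the `ndarray` bump from 0.15 to 0.16 in backends/ort/Cargo.toml: 0.16 deprecates `into_shape`. A small sketch of the non-deprecated equivalent, `into_shape_with_order`, assuming the array is in standard row-major layout:

use ndarray::Array3;

fn main() {
    // Flatten [batch, seq_len, hidden] into [batch * seq_len, hidden], as the
    // raw-request path above does, without the deprecated `into_shape`.
    let outputs = Array3::<f32>::zeros((2, 4, 8));
    let s = outputs.shape().to_vec();
    let flat = outputs.into_shape_with_order((s[0] * s[1], s[2])).unwrap();
    assert_eq!(flat.shape(), &[8, 8]);
}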

backends/src/lib.rs

Lines changed: 89 additions & 56 deletions
@@ -37,8 +37,9 @@ pub struct Backend {
 }

 impl Backend {
-    pub fn new(
+    pub async fn new(
         model_path: PathBuf,
+        api_repo: Option<ApiRepo>,
         dtype: DType,
         model_type: ModelType,
         uds_path: String,
@@ -49,12 +50,14 @@ impl Backend {

         let backend = init_backend(
             model_path,
+            api_repo,
             dtype,
             model_type.clone(),
             uds_path,
             otlp_endpoint,
             otlp_service_name,
-        )?;
+        )
+        .await?;
         let padded_model = backend.is_padded();
         let max_batch_size = backend.max_batch_size();

@@ -193,48 +196,102 @@ impl Backend {
 }

 #[allow(unused)]
-fn init_backend(
+async fn init_backend(
     model_path: PathBuf,
+    api_repo: Option<ApiRepo>,
     dtype: DType,
     model_type: ModelType,
     uds_path: String,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
 ) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
+    let mut backend_start_failed = false;
+
+    if cfg!(feature = "ort") {
+        #[cfg(feature = "ort")]
+        {
+            if let Some(api_repo) = api_repo.as_ref() {
+                let start = std::time::Instant::now();
+                download_onnx(api_repo)
+                    .await
+                    .map_err(|err| BackendError::WeightsNotFound(err.to_string()));
+                tracing::info!("Model ONNX weights downloaded in {:?}", start.elapsed());
+            }
+
+            let backend = OrtBackend::new(&model_path, dtype.to_string(), model_type.clone());
+            match backend {
+                Ok(b) => return Ok(Box::new(b)),
+                Err(err) => {
+                    tracing::error!("Could not start ORT backend: {err}");
+                    backend_start_failed = true;
+                }
+            }
+        }
+    }
+
+    if let Some(api_repo) = api_repo.as_ref() {
+        if cfg!(feature = "python") || cfg!(feature = "candle") {
+            let start = std::time::Instant::now();
+            if download_safetensors(api_repo).await.is_err() {
+                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
+                tracing::info!("Downloading `pytorch_model.bin`");
+                api_repo
+                    .get("pytorch_model.bin")
+                    .await
+                    .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
+            }
+
+            tracing::info!("Model weights downloaded in {:?}", start.elapsed());
+        }
+    }
+
     if cfg!(feature = "candle") {
         #[cfg(feature = "candle")]
-        return Ok(Box::new(CandleBackend::new(
-            model_path,
-            dtype.to_string(),
-            model_type,
-        )?));
-    } else if cfg!(feature = "python") {
+        {
+            let backend = CandleBackend::new(&model_path, dtype.to_string(), model_type.clone());
+            match backend {
+                Ok(b) => return Ok(Box::new(b)),
+                Err(err) => {
+                    tracing::error!("Could not start Candle backend: {err}");
+                    backend_start_failed = true;
+                }
+            }
+        }
+    }
+
+    if cfg!(feature = "python") {
         #[cfg(feature = "python")]
         {
-            return Ok(Box::new(
-                std::thread::spawn(move || {
-                    PythonBackend::new(
-                        model_path.to_str().unwrap().to_string(),
-                        dtype.to_string(),
-                        model_type,
-                        uds_path,
-                        otlp_endpoint,
-                        otlp_service_name,
-                    )
-                })
-                .join()
-                .expect("Python Backend management thread failed")?,
-            ));
+            let backend = std::thread::spawn(move || {
+                PythonBackend::new(
+                    model_path.to_str().unwrap().to_string(),
+                    dtype.to_string(),
+                    model_type,
+                    uds_path,
+                    otlp_endpoint,
+                    otlp_service_name,
+                )
+            })
+            .join()
+            .expect("Python Backend management thread failed");
+
+            match backend {
+                Ok(b) => return Ok(Box::new(b)),
+                Err(err) => {
+                    tracing::error!("Could not start Python backend: {err}");
+                    backend_start_failed = true;
+                }
+            }
         }
-    } else if cfg!(feature = "ort") {
-        #[cfg(feature = "ort")]
-        return Ok(Box::new(OrtBackend::new(
-            model_path,
-            dtype.to_string(),
-            model_type,
-        )?));
     }
-    Err(BackendError::NoBackend)
+
+    if backend_start_failed {
+        Err(BackendError::Start(
+            "Could not start a suitable backend".to_string(),
+        ))
+    } else {
+        Err(BackendError::NoBackend)
+    }
 }

 #[derive(Debug)]
@@ -298,31 +355,6 @@ enum BackendCommand {
     ),
 }

-pub async fn download_weights(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
-    let model_files = if cfg!(feature = "python") || cfg!(feature = "candle") {
-        match download_safetensors(api).await {
-            Ok(p) => p,
-            Err(_) => {
-                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
-                tracing::info!("Downloading `pytorch_model.bin`");
-                let p = api.get("pytorch_model.bin").await?;
-                vec![p]
-            }
-        }
-    } else if cfg!(feature = "ort") {
-        match download_onnx(api).await {
-            Ok(p) => p,
-            Err(err) => {
-                panic!("failed to download `model.onnx` or `model.onnx_data`. Check the onnx file exists in the repository. {err}");
-            }
-        }
-    } else {
-        unreachable!()
-    };
-
-    Ok(model_files)
-}
-
 async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     // Single file
     tracing::info!("Downloading `model.safetensors`");
@@ -362,6 +394,7 @@ async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     Ok(safetensors_files)
 }

+#[cfg(feature = "ort")]
 async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     let mut model_files: Vec<PathBuf> = Vec::new();

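The net effect of this rewrite: instead of picking exactly one backend at compile time via `if/else if`, `init_backend` now tries ORT first, then Candle, then Python, downloads the weight format each attempt needs, and distinguishes "a backend was compiled in but failed to start" from "no backend was compiled in at all". A self-contained sketch of the selection pattern (types and messages are illustrative, not the repository's API):

// Illustrative sketch of the try-each-backend-in-order pattern used by
// `init_backend` above.
fn pick_backend<T>(
    candidates: Vec<(&'static str, Box<dyn Fn() -> Result<T, String>>)>,
) -> Result<T, String> {
    let mut start_failed = false;
    for (name, start) in candidates {
        match start() {
            Ok(backend) => return Ok(backend),
            Err(err) => {
                eprintln!("Could not start {name} backend: {err}");
                start_failed = true;
            }
        }
    }
    if start_failed {
        // At least one candidate existed but none could be started.
        Err("Could not start a suitable backend".to_string())
    } else {
        // Nothing was compiled in: the old `NoBackend` case.
        Err("No backend available".to_string())
    }
}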

core/src/download.rs

Lines changed: 21 additions & 6 deletions
@@ -1,6 +1,5 @@
 use hf_hub::api::tokio::{ApiError, ApiRepo};
 use std::path::PathBuf;
-use text_embeddings_backend::download_weights;
 use tracing::instrument;

 // Old classes used other config names than 'sentence_bert_config.json'
@@ -15,20 +14,36 @@ pub const ST_CONFIG_NAMES: [&str; 7] = [
 ];

 #[instrument(skip_all)]
-pub async fn download_artifacts(api: &ApiRepo) -> Result<PathBuf, ApiError> {
+pub async fn download_artifacts(api: &ApiRepo, pool_config: bool) -> Result<PathBuf, ApiError> {
     let start = std::time::Instant::now();

     tracing::info!("Starting download");

+    // Optionally download the pooling config.
+    if pool_config {
+        // If a pooling config exist, download it
+        let _ = download_pool_config(api).await.map_err(|err| {
+            tracing::warn!("Download failed: {err}");
+            err
+        });
+    }
+
+    // Download legacy sentence transformers config
+    // We don't warn on failure as it is a legacy file
+    let _ = download_st_config(api).await;
+    // Download new sentence transformers config
+    let _ = download_new_st_config(api).await.map_err(|err| {
+        tracing::warn!("Download failed: {err}");
+        err
+    });
+
     tracing::info!("Downloading `config.json`");
     api.get("config.json").await?;

     tracing::info!("Downloading `tokenizer.json`");
-    api.get("tokenizer.json").await?;
-
-    let model_files = download_weights(api).await?;
-    let model_root = model_files[0].parent().unwrap().to_path_buf();
+    let tokenizer_path = api.get("tokenizer.json").await?;

+    let model_root = tokenizer_path.parent().unwrap().to_path_buf();
     tracing::info!("Model artifacts downloaded in {:?}", start.elapsed());
     Ok(model_root)
 }
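
With weight downloads moved into the backends crate, `download_artifacts` now only fetches configs and the tokenizer, and derives the model root from the tokenizer's location. A hypothetical caller of the new signature, assuming `hf-hub`'s tokio API and `anyhow` for error handling:

use hf_hub::{api::tokio::ApiBuilder, Repo, RepoType};
use std::path::PathBuf;

// Hypothetical caller: the new `pool_config` flag decides whether the
// pooling configuration is downloaded alongside the other artifacts.
async fn fetch_artifacts(model_id: &str) -> anyhow::Result<PathBuf> {
    let api = ApiBuilder::new().build()?;
    let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
    let model_root = download_artifacts(&repo, true).await?;
    Ok(model_root)
}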
