Commit 20bda42

add file back
1 parent 9941fcc commit 20bda42

File tree

1 file changed: +125 −0 lines changed


backends/python/src/lib.rs

Lines changed: 125 additions & 0 deletions
@@ -1 +1,126 @@
mod logging;
mod management;

use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
use tokio::runtime::Runtime;

pub struct PythonBackend {
    // Held only to keep the spawned Python server process alive for the
    // lifetime of the backend; never read directly.
    _backend_process: management::BackendProcess,
    tokio_runtime: Runtime,
    backend_client: Client,
}

impl PythonBackend {
    pub fn new(
        model_path: String,
        dtype: String,
        model_type: ModelType,
        uds_path: String,
        otlp_endpoint: Option<String>,
        otlp_service_name: String,
    ) -> Result<Self, BackendError> {
        let model_type_clone = model_type.clone();

        // Only embedding models with CLS or mean pooling are supported.
        match model_type {
            ModelType::Classifier => {
                return Err(BackendError::Start(
                    "`classifier` model type is not supported".to_string(),
                ))
            }
            ModelType::Embedding(pool) => {
                if pool != Pool::Cls && pool != Pool::Mean {
                    return Err(BackendError::Start(format!(
                        "{pool:?} is not supported in the TEI Python backend. Please open an issue."
                    )));
                }
                pool
            }
        };

        let pool_string = match &model_type_clone {
            ModelType::Classifier => &Pool::Cls,
            ModelType::Embedding(pool) => pool,
        }
        .to_string();

        // Spawn the Python server process, then connect to it over the same
        // Unix domain socket.
        let backend_process = management::BackendProcess::new(
            model_path,
            dtype,
            &uds_path,
            otlp_endpoint,
            otlp_service_name,
            pool_string,
        )?;
        let tokio_runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;

        let backend_client = tokio_runtime
            .block_on(Client::connect_uds(uds_path))
            .map_err(|err| {
                BackendError::Start(format!("Could not connect to backend process: {err}"))
            })?;

        Ok(Self {
            _backend_process: backend_process,
            tokio_runtime,
            backend_client,
        })
    }
}

impl Backend for PythonBackend {
    fn health(&self) -> Result<(), BackendError> {
        if self
            .tokio_runtime
            .block_on(self.backend_client.clone().health())
            .is_err()
        {
            return Err(BackendError::Unhealthy);
        }
        Ok(())
    }

    fn is_padded(&self) -> bool {
        false
    }

    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
        if !batch.raw_indices.is_empty() {
            return Err(BackendError::Inference(
                "raw embeddings are not supported for the Python backend.".to_string(),
            ));
        }
        let batch_size = batch.len();

        let results = self
            .tokio_runtime
            .block_on(self.backend_client.clone().embed(
                batch.input_ids,
                batch.token_type_ids,
                batch.position_ids,
                batch.cumulative_seq_lengths,
                batch.max_length,
            ))
            .map_err(|err| BackendError::Inference(err.to_string()))?;
        let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

        // Map each pooled embedding back to its index in the batch. The keys
        // are small integers, so an identity (no-op) hasher is sufficient.
        let mut embeddings =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
        for (i, e) in pooled_embeddings.into_iter().enumerate() {
            embeddings.insert(i, Embedding::Pooled(e));
        }

        Ok(embeddings)
    }

    fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
        Err(BackendError::Inference(
            "`predict` is not implemented".to_string(),
        ))
    }
}
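
For orientation, here is a minimal sketch of how a caller might drive this backend. The constructor arguments mirror the diff above, but every literal value (model path, dtype, socket path, token ids) is a hypothetical placeholder, the `text_embeddings_backend_python` crate name is assumed, and the `Batch` construction (including a `pooled_indices` field this diff never shows) is inferred from the fields `embed` reads rather than confirmed by this commit:

use text_embeddings_backend_core::{Backend, BackendError, Batch, ModelType, Pool};
use text_embeddings_backend_python::PythonBackend; // crate name assumed

fn run() -> Result<(), BackendError> {
    // All literal values below are hypothetical placeholders.
    let backend = PythonBackend::new(
        "/data/my-embedding-model".to_string(), // model_path (assumed)
        "float16".to_string(),                  // dtype (assumed)
        ModelType::Embedding(Pool::Cls),        // only Cls or Mean pooling is accepted
        "/tmp/tei-python.sock".to_string(),     // uds_path (assumed)
        None,                                   // otlp_endpoint
        "tei-python-backend".to_string(),       // otlp_service_name (assumed)
    )?;

    backend.health()?;

    // One three-token sequence; `Batch` layout assumed, not taken from this diff.
    let batch = Batch {
        input_ids: vec![101, 2023, 102],
        token_type_ids: vec![0, 0, 0],
        position_ids: vec![0, 1, 2],
        cumulative_seq_lengths: vec![0, 3], // a single sequence spanning tokens 0..3
        max_length: 3,
        pooled_indices: vec![0], // assumed field
        raw_indices: vec![],     // must stay empty: `embed` rejects raw embeddings
    };

    let embeddings = backend.embed(batch)?;
    println!("received {} pooled embedding(s)", embeddings.len());
    Ok(())
}

Since the `Backend` trait's methods are synchronous and the backend blocks on its own current-thread Tokio runtime, the sketch assumes a plain synchronous caller.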
