+mod logging;
+mod management;

+use backend_grpc_client::Client;
+use nohash_hasher::BuildNoHashHasher;
+use std::collections::HashMap;
+use text_embeddings_backend_core::{
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+};
+use tokio::runtime::Runtime;
+
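+/// Backend that runs the model in a Python gRPC server process and talks to
+/// it over a Unix domain socket. The child process handle is held in
+/// `_backend_process` purely so it lives as long as the backend itself.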
+pub struct PythonBackend {
+    _backend_process: management::BackendProcess,
+    tokio_runtime: Runtime,
+    backend_client: Client,
+}
+
+impl PythonBackend {
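+    /// Start the Python server for the model at `model_path` and connect to
+    /// it over the Unix domain socket at `uds_path`.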
+    pub fn new(
+        model_path: String,
+        dtype: String,
+        model_type: ModelType,
+        uds_path: String,
+        otlp_endpoint: Option<String>,
+        otlp_service_name: String,
+    ) -> Result<Self, BackendError> {
+        // Only embedding models with CLS or mean pooling are supported;
+        // reject everything else up front. Binding the pool here also avoids
+        // cloning `model_type` and re-matching on it below.
+        let pool = match model_type {
+            ModelType::Classifier => {
+                return Err(BackendError::Start(
+                    "`classifier` model type is not supported".to_string(),
+                ))
+            }
+            ModelType::Embedding(pool) => {
+                if pool != Pool::Cls && pool != Pool::Mean {
+                    return Err(BackendError::Start(format!(
+                        "{pool:?} is not supported in the TEI Python backend. Please open an issue."
+                    )));
+                }
+                pool
+            }
+        };
+        let pool_string = pool.to_string();
+
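+        // Spawn the Python gRPC server as a child process.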
+        let backend_process = management::BackendProcess::new(
+            model_path,
+            dtype,
+            &uds_path,
+            otlp_endpoint,
+            otlp_service_name,
+            pool_string,
+        )?;
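+        // Single-threaded runtime used to drive the async gRPC client from
+        // the synchronous `Backend` trait methods.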
+        let tokio_runtime = tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;
+
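+        // Connect to the freshly spawned server over the Unix domain socket.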
+        let backend_client = tokio_runtime
+            .block_on(Client::connect_uds(uds_path))
+            .map_err(|err| {
+                BackendError::Start(format!("Could not connect to backend process: {err}"))
+            })?;
+
+        Ok(Self {
+            _backend_process: backend_process,
+            tokio_runtime,
+            backend_client,
+        })
+    }
+}
+
+impl Backend for PythonBackend {
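+    /// Report healthy as long as the server answers its health RPC.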
+    fn health(&self) -> Result<(), BackendError> {
+        if self
+            .tokio_runtime
+            .block_on(self.backend_client.clone().health())
+            .is_err()
+        {
+            return Err(BackendError::Unhealthy);
+        }
+        Ok(())
+    }
+
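+    // Batches reach this backend flattened, with sequence boundaries carried
+    // in `cumulative_seq_lengths`, so no padding is involved.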
+    fn is_padded(&self) -> bool {
+        false
+    }
+
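+    /// Run a forward pass on the server and return one pooled embedding per
+    /// sequence in the batch.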
+    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
+        if !batch.raw_indices.is_empty() {
+            return Err(BackendError::Inference(
+                "raw embeddings are not supported for the Python backend.".to_string(),
+            ));
+        }
+        let batch_size = batch.len();
+
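+        // Forward the flattened batch to the server; one result is expected
+        // per sequence, in batch order.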
+        let results = self
+            .tokio_runtime
+            .block_on(self.backend_client.clone().embed(
+                batch.input_ids,
+                batch.token_type_ids,
+                batch.position_ids,
+                batch.cumulative_seq_lengths,
+                batch.max_length,
+            ))
+            .map_err(|err| BackendError::Inference(err.to_string()))?;
+        let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
+
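+        // Key each pooled embedding by its position in the batch.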
+        let mut embeddings =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+        for (i, e) in pooled_embeddings.into_iter().enumerate() {
+            embeddings.insert(i, Embedding::Pooled(e));
+        }
+
+        Ok(embeddings)
+    }
+
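+    // Classifier models are rejected in `new`, so prediction is unsupported.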
+    fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
+        Err(BackendError::Inference(
+            "`predict` is not implemented".to_string(),
+        ))
+    }
+}