Commit 2a2993a

support mean pooling in python backend
1 parent 6cdd454 commit 2a2993a

8 files changed: 64 additions & 17 deletions
backends/python/server/text_embeddings_server/layers/pooling.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import torch
+from flash_attn.bert_padding import pad_input
+
+from loguru import logger
+
+def mean_pooling(embedding, cu_seqlens, max_s):
+    # Ideally, rust would pass `indices` to the FlashBatch.
+    seqlens = cu_seqlens[1:].clone()
+    seqlens[0] = cu_seqlens[1]
+    seqlens[1:] -= cu_seqlens[1:-1]
+    batch_size = len(seqlens)
+
+    # Example: indices = [0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13]
+    mask = torch.zeros(batch_size, max_s, dtype=torch.int32, device=cu_seqlens.device)
+    mask[torch.arange(max_s) < seqlens[:, None].cpu()] = 1
+    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+    embedding_padded = pad_input(embedding, indices, batch_size, max_s)
+
+    sum_embeddings = torch.sum(embedding_padded, 1)
+
+    return sum_embeddings / seqlens[:, None]
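What the new helper computes: with flash attention the batch arrives as one flat (total_tokens, hidden_size) tensor and `cu_seqlens` holds cumulative sequence lengths, so the per-sequence lengths, the padding mask and the token `indices` handed to `pad_input` are derived as above. A minimal sketch with assumed toy values (not part of the commit) that reproduces the `indices` example from the comment:

    import torch

    cu_seqlens = torch.tensor([0, 4, 11], dtype=torch.int32)  # assumed: two sequences of lengths 4 and 7
    max_s = 7

    seqlens = cu_seqlens[1:].clone()
    seqlens[1:] -= cu_seqlens[1:-1]                    # -> tensor([4, 7])
    batch_size = len(seqlens)

    mask = torch.zeros(batch_size, max_s, dtype=torch.int32)
    mask[torch.arange(max_s) < seqlens[:, None]] = 1   # 1 where a real token exists
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    print(indices.tolist())                            # [0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13]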

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -25,8 +25,6 @@
     __all__.append(FlashBert)


-class
-
 def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str):
     if dtype == "float32":
         dtype = torch.float32
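For reference, a hedged usage sketch of the updated `get_model` factory; the model path is an assumed example value, while the signature and the "mean" mode come from this commit:

    from pathlib import Path
    from text_embeddings_server.models import get_model

    # Assumed example path; "float32" and "mean" follow the values handled in this diff.
    model = get_model(Path("/data/my-bert-model"), "float32", "mean")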

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 3 additions & 1 deletion
@@ -8,14 +8,16 @@

 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import PaddedBatch, Embedding
+from typing import Optional

 tracer = trace.get_tracer(__name__)


 class DefaultModel(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]):
         model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
         self.hidden_size = model.config.hidden_size
+        self.pooling_mode = pooling_mode

         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 16 additions & 6 deletions
@@ -12,7 +12,8 @@
 from text_embeddings_server.models.types import FlashBatch, Embedding
 from text_embeddings_server.layers.attention import attention
 from text_embeddings_server.layers.layernorm import FastLayerNorm
-from loguru import logger
+from text_embeddings_server.layers.pooling import mean_pooling
+from typing import Optional

 tracer = trace.get_tracer(__name__)

@@ -190,12 +191,13 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):


 class FlashBert(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]):
         config = BertConfig.from_pretrained(model_path)
         with safe_open(model_path / "model.safetensors", framework="pt") as f:
             model = FlashBertModel(f, device, dtype, config)

         self.hidden_size = config.hidden_size
+        self.pooling_mode = pooling_mode

         super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)

@@ -205,7 +207,6 @@ def batch_type(self) -> Type[FlashBatch]:

     @tracer.start_as_current_span("embed")
     def embed(self, batch: FlashBatch) -> List[Embedding]:
-        logger.info(f"batch.input_ids {batch.input_ids}")
         embedding = self.model.forward(
             input_ids=batch.input_ids,
             token_type_ids=batch.token_type_ids,
@@ -214,9 +215,8 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
             max_s=batch.max_s,
         )

-        if True:
+        if self.pooling_mode == "cls":
             embedding = embedding[batch.cu_seqlens[:-1]]
-            logger.info(f"embedding {embedding.shape}")
             cpu_results = embedding.view(-1).tolist()

             return [
@@ -225,4 +225,14 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
                 )
                 for i in range(len(batch))
             ]
-        elif
+        elif self.pooling_mode == "mean":
+            res = mean_pooling(embedding, batch.cu_seqlens, batch.max_s)
+            return [
+                Embedding(
+                    values=res[i]
+                )
+                for i in range(len(batch))
+            ]
+
+        else:
+            raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")
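On the flash path `embedding` is an unpadded (total_tokens, hidden_size) tensor, so CLS pooling is just indexing the first token of each sequence at `cu_seqlens[:-1]`, while mean pooling averages all tokens of each sequence (which the `mean_pooling` helper does via re-padding). A minimal sketch with assumed toy shapes, not the backend code:

    import torch

    hidden_size = 4
    cu_seqlens = torch.tensor([0, 2, 5])           # assumed: two sequences of lengths 2 and 3
    embedding = torch.randn(5, hidden_size)        # flattened batch, no padding tokens

    cls_emb = embedding[cu_seqlens[:-1]]           # rows 0 and 2: first token of each sequence
    mean_emb = torch.stack([
        embedding[start:end].mean(dim=0)
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
    ])
    print(cls_emb.shape, mean_emb.shape)           # torch.Size([2, 4]) torch.Size([2, 4])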

backends/python/src/lib.rs

Lines changed: 4 additions & 2 deletions
@@ -23,6 +23,7 @@ impl PythonBackend {
         uds_path: String,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        pooling_mode: String,
     ) -> Result<Self, BackendError> {
         match model_type {
             ModelType::Classifier => {
@@ -31,8 +32,8 @@
                 ))
             }
             ModelType::Embedding(pool) => {
-                if pool != Pool::Cls {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
+                if pool != Pool::Cls && pool != Pool::Mean {
+                    return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. Please open an issue.")));
                 }
                 pool
             }
@@ -44,6 +45,7 @@ impl PythonBackend {
             &uds_path,
             otlp_endpoint,
             otlp_service_name,
+            pooling_mode,
         )?;
         let tokio_runtime = tokio::runtime::Builder::new_current_thread()
             .enable_all()

backends/python/src/management.rs

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,7 @@ impl BackendProcess {
         uds_path: &str,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        pooling_mode: String,
     ) -> Result<Self, BackendError> {
         // Get UDS path
         let uds = Path::new(uds_path);
@@ -52,6 +53,9 @@
         python_server_args.push("--otlp-service-name".to_owned());
         python_server_args.push(otlp_service_name);

+        python_server_args.push("--pooling-mode".to_owned());
+        python_server_args.push(pooling_mode);
+
         // Copy current process env
         let envs: Vec<(OsString, OsString)> = env::vars_os().collect();
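The Rust side only forwards the new flag; in the Python server it ultimately reaches `get_model(model_path, dtype, pooling_mode)`. A purely hypothetical, self-contained sketch of how an entry point could consume such a flag (this is not the actual TEI Python server CLI; only the `--pooling-mode` name and the `get_model` signature come from this commit):

    import argparse
    from pathlib import Path

    from text_embeddings_server.models import get_model

    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--model-path", type=Path, required=True)
        parser.add_argument("--dtype", default=None)
        parser.add_argument("--pooling-mode", default="cls")
        args = parser.parse_args()

        # Hand the pooling mode through to the model factory, as this commit's plumbing does.
        return get_model(args.model_path, args.dtype, args.pooling_mode)

    if __name__ == "__main__":
        main()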

backends/src/lib.rs

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,7 @@ impl Backend {
         uds_path: String,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        pooling_mode: String,
     ) -> Result<Self, BackendError> {
         let (backend_sender, backend_receiver) = mpsc::unbounded_channel();

@@ -49,6 +50,7 @@
             uds_path,
             otlp_endpoint,
             otlp_service_name,
+            pooling_mode,
         )?;
         let padded_model = backend.is_padded();
         let max_batch_size = backend.max_batch_size();
@@ -138,6 +140,7 @@ fn init_backend(
     uds_path: String,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
+    pooling_mode: String,
 ) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
     if cfg!(feature = "candle") {
         #[cfg(feature = "candle")]
@@ -158,6 +161,7 @@
                 uds_path,
                 otlp_endpoint,
                 otlp_service_name,
+                pooling_mode,
             )
         })
         .join()

router/src/lib.rs

Lines changed: 11 additions & 6 deletions
@@ -105,7 +105,7 @@ pub async fn run(
         serde_json::from_str(&config).context("Failed to parse `config.json`")?;

     // Set model type from config
-    let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?;
+    let backend_model_type = get_backend_model_type(&config, &model_root, &pooling)?;

     // Info model type
     let model_type = match &backend_model_type {
@@ -191,6 +191,11 @@
         }
     });

+    let pooling_str = match pooling {
+        Some(pool) => pool.to_string(),
+        None => "none".to_string(),
+    };
+
     // Create backend
     tracing::info!("Starting model backend");
     let backend = text_embeddings_backend::Backend::new(
@@ -200,7 +205,7 @@
         uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()),
         otlp_endpoint.clone(),
         otlp_service_name.clone(),
-        pooling.to_string(),
+        pooling_str,
     )
     .context("Could not create backend")?;
     backend
@@ -307,10 +312,10 @@
 fn get_backend_model_type(
     config: &ModelConfig,
     model_root: &Path,
-    pooling: Option<text_embeddings_backend::Pool>,
+    pooling: &Option<text_embeddings_backend::Pool>,
 ) -> Result<text_embeddings_backend::ModelType> {
     for arch in &config.architectures {
-        if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") {
+        if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") {
             return Ok(text_embeddings_backend::ModelType::Embedding(
                 text_embeddings_backend::Pool::Splade,
             ));
@@ -324,15 +329,15 @@
         }
     }

-    if Some(text_embeddings_backend::Pool::Splade) == pooling {
+    if Some(text_embeddings_backend::Pool::Splade) == *pooling {
         return Err(anyhow!(
             "Splade pooling is not supported: model is not a ForMaskedLM model"
         ));
     }

     // Set pooling
     let pool = match pooling {
-        Some(pool) => pool,
+        Some(pool) => pool.clone(),
         None => {
             // Load pooling config
             let config_path = model_root.join("1_Pooling/config.json");