Skip to content

Commit e27a4fb

Browse files
feat: Python pooling (#442)
Co-authored-by: Oskar Liew <43787643+OskarLiew@users.noreply.github.com>
1 parent 0bfeb7e commit e27a4fb

File tree

10 files changed

+1749
-640
lines changed

10 files changed

+1749
-640
lines changed

backends/python/server/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ unit-tests:
66

77
gen-server:
88
# Compile protos
9-
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
9+
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
1010
mkdir text_embeddings_server/pb || true
1111
python -m grpc_tools.protoc -I../../proto --python_out=text_embeddings_server/pb \
1212
--grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb ../../proto/embed.proto
@@ -15,6 +15,7 @@ gen-server:
1515

1616
install: gen-server
1717
pip install pip --upgrade
18+
pip install torch==2.5.1
1819
pip install -r requirements.txt
1920
pip install -e .
2021

backends/python/server/poetry.lock

Lines changed: 1641 additions & 578 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backends/python/server/pyproject.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@ python-text-embeddings-server = 'text_embeddings_server.cli:app'
99

1010
[tool.poetry.dependencies]
1111
python = ">=3.9,<3.13"
12-
protobuf = "^4.21.7"
12+
protobuf = ">=4.25.3,<6"
1313
grpcio = "^1.51.1"
1414
grpcio-status = "^1.51.1"
1515
grpcio-reflection = "^1.51.1"
1616
grpc-interceptor = "^0.15.0"
1717
typer = "^0.6.1"
18-
safetensors = "^0.3.2"
18+
safetensors = "^0.4"
1919
loguru = "^0.6.0"
20-
opentelemetry-api = "^1.15.0"
21-
opentelemetry-exporter-otlp = "^1.15.0"
22-
opentelemetry-instrumentation-grpc = "^0.36b0"
23-
torch = { version = "^2.0.1" }
20+
opentelemetry-api = "^1.25.0"
21+
opentelemetry-exporter-otlp = "^1.25.0"
22+
opentelemetry-instrumentation-grpc = "^0.46b0"
23+
sentence-transformers = "^3.3.1"
24+
torch = "^2.5.1"
2425

2526
[tool.poetry.extras]
2627

backends/python/server/requirements.txt

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,68 @@
1-
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
2-
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
3-
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
1+
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
2+
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
43
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
54
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
6-
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
7-
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
8-
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
9-
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
10-
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
11-
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
12-
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
13-
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
14-
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
15-
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
16-
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
5+
deprecated==1.2.15 ; python_version >= "3.9" and python_version < "3.13"
6+
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
7+
fsspec==2024.10.0 ; python_version >= "3.9" and python_version < "3.13"
8+
googleapis-common-protos==1.66.0 ; python_version >= "3.9" and python_version < "3.13"
9+
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
10+
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
11+
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
12+
grpcio==1.68.0 ; python_version >= "3.9" and python_version < "3.13"
13+
huggingface-hub==0.26.2 ; python_version >= "3.9" and python_version < "3.13"
14+
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
15+
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
16+
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
17+
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
1718
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
18-
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
19+
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
1920
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
20-
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
21-
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
22-
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
23-
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
24-
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
25-
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
26-
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
27-
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
28-
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
29-
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
30-
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
31-
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
32-
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
33-
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
34-
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
35-
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
36-
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
37-
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
38-
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
21+
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
22+
numpy==2.0.2 ; python_version >= "3.9" and python_version < "3.13"
23+
nvidia-cublas-cu12==12.4.5.8 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
24+
nvidia-cuda-cupti-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
25+
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
26+
nvidia-cuda-runtime-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
27+
nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
28+
nvidia-cufft-cu12==11.2.1.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
29+
nvidia-curand-cu12==10.3.5.147 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
30+
nvidia-cusolver-cu12==11.6.1.9 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
31+
nvidia-cusparse-cu12==12.3.1.170 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
32+
nvidia-nccl-cu12==2.21.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
33+
nvidia-nvjitlink-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
34+
nvidia-nvtx-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
35+
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
36+
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
37+
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
38+
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
39+
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
40+
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
41+
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
42+
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
43+
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
44+
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
45+
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
46+
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13"
47+
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
48+
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
49+
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
50+
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
51+
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
52+
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
53+
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
54+
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
55+
setuptools==75.6.0 ; python_version >= "3.9" and python_version < "3.13"
56+
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
57+
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
58+
tokenizers==0.20.3 ; python_version >= "3.9" and python_version < "3.13"
59+
torch==2.5.1 ; python_version >= "3.9" and python_version < "3.13"
60+
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
61+
transformers==4.46.3 ; python_version >= "3.9" and python_version < "3.13"
62+
triton==3.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
3963
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
40-
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
41-
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
64+
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
65+
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
4266
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
43-
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
67+
wrapt==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
68+
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"

backends/python/server/text_embeddings_server/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def serve(
2424
json_output: bool = False,
2525
otlp_endpoint: Optional[str] = None,
2626
otlp_service_name: str = "text-embeddings-inference.server",
27+
pool: str = "cls",
2728
):
2829
# Remove default handler
2930
logger.remove()
@@ -48,7 +49,7 @@ def serve(
4849
# Downgrade enum into str for easier management later on
4950
dtype = None if dtype is None else dtype.value
5051

51-
server.serve(model_path, dtype, uds_path)
52+
server.serve(model_path, dtype, uds_path, pool)
5253

5354

5455
if __name__ == "__main__":

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
__all__.append(FlashBert)
2626

2727

28-
def get_model(model_path: Path, dtype: Optional[str]):
28+
def get_model(model_path: Path, dtype: Optional[str], pool: str):
2929
if dtype == "float32":
3030
dtype = torch.float32
3131
elif dtype == "float16":
@@ -38,8 +38,6 @@ def get_model(model_path: Path, dtype: Optional[str]):
3838
if torch.cuda.is_available():
3939
device = torch.device("cuda")
4040
else:
41-
if dtype != torch.float32:
42-
raise ValueError("CPU device only supports float32 dtype")
4341
device = torch.device("cpu")
4442

4543
config = AutoConfig.from_pretrained(model_path)
@@ -52,8 +50,10 @@ def get_model(model_path: Path, dtype: Optional[str]):
5250
and dtype in [torch.float16, torch.bfloat16]
5351
and FLASH_ATTENTION
5452
):
53+
if pool != "cls":
54+
raise ValueError("FlashBert only supports cls pooling")
5555
return FlashBert(model_path, device, dtype)
5656
else:
57-
return DefaultModel(model_path, device, dtype)
57+
return DefaultModel(model_path, device, dtype, pool)
5858

5959
raise NotImplementedError

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Type, List
66
from transformers import AutoModel
77
from opentelemetry import trace
8+
from sentence_transformers.models import Pooling
89

910
from text_embeddings_server.models import Model
1011
from text_embeddings_server.models.types import PaddedBatch, Embedding
@@ -13,9 +14,12 @@
1314

1415

1516
class DefaultModel(Model):
16-
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
17+
def __init__(
18+
self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
19+
):
1720
model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
1821
self.hidden_size = model.config.hidden_size
22+
self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
1923

2024
self.has_position_ids = (
2125
inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -41,7 +45,13 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
4145
kwargs["position_ids"] = batch.position_ids
4246

4347
output = self.model(**kwargs)
44-
embedding = output[0][:, 0]
48+
49+
pooling_features = {
50+
"token_embeddings": output[0],
51+
"attention_mask": batch.attention_mask,
52+
}
53+
embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
54+
4555
cpu_results = embedding.view(-1).tolist()
4656

4757
return [

backends/python/server/text_embeddings_server/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import torch
3-
43
from grpc import aio
54
from loguru import logger
65

@@ -37,6 +36,7 @@ def serve(
3736
model_path: Path,
3837
dtype: Optional[str],
3938
uds_path: Path,
39+
pool: str,
4040
):
4141
async def serve_inner(
4242
model_path: Path,
@@ -45,7 +45,7 @@ async def serve_inner(
4545
unix_socket = f"unix://{uds_path}"
4646

4747
try:
48-
model = get_model(model_path, dtype)
48+
model = get_model(model_path, dtype, pool)
4949
except Exception:
5050
logger.exception("Error when initializing model")
5151
raise

backends/python/src/lib.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
55
use nohash_hasher::BuildNoHashHasher;
66
use std::collections::HashMap;
77
use text_embeddings_backend_core::{
8-
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
8+
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
99
};
1010
use tokio::runtime::Runtime;
1111

@@ -24,18 +24,13 @@ impl PythonBackend {
2424
otlp_endpoint: Option<String>,
2525
otlp_service_name: String,
2626
) -> Result<Self, BackendError> {
27-
match model_type {
27+
let pool = match model_type {
2828
ModelType::Classifier => {
2929
return Err(BackendError::Start(
3030
"`classifier` model type is not supported".to_string(),
3131
))
3232
}
33-
ModelType::Embedding(pool) => {
34-
if pool != Pool::Cls {
35-
return Err(BackendError::Start(format!("{pool:?} is not supported")));
36-
}
37-
pool
38-
}
33+
ModelType::Embedding(pool) => pool,
3934
};
4035

4136
let backend_process = management::BackendProcess::new(
@@ -44,6 +39,7 @@ impl PythonBackend {
4439
&uds_path,
4540
otlp_endpoint,
4641
otlp_service_name,
42+
pool,
4743
)?;
4844
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
4945
.enable_all()

backends/python/src/management.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::sync::mpsc;
88
use std::thread::sleep;
99
use std::time::{Duration, Instant};
1010
use std::{env, fs, io, thread};
11-
use text_embeddings_backend_core::BackendError;
11+
use text_embeddings_backend_core::{BackendError, Pool};
1212

1313
#[derive(Debug)]
1414
pub(crate) struct BackendProcess {
@@ -22,6 +22,7 @@ impl BackendProcess {
2222
uds_path: &str,
2323
otlp_endpoint: Option<String>,
2424
otlp_service_name: String,
25+
pool: Pool,
2526
) -> Result<Self, BackendError> {
2627
// Get UDS path
2728
let uds = Path::new(uds_path);
@@ -31,6 +32,15 @@ impl BackendProcess {
3132
fs::remove_file(uds).expect("could not remove UDS file");
3233
}
3334

35+
let pool = match pool {
36+
Pool::Cls => "cls",
37+
Pool::Mean => "mean",
38+
Pool::LastToken => "lasttoken",
39+
Pool::Splade => {
40+
return Err(BackendError::Start(format!("{pool:?} is not supported")));
41+
}
42+
};
43+
3444
// Process args
3545
let mut python_server_args = vec![
3646
model_path,
@@ -41,6 +51,8 @@ impl BackendProcess {
4151
"--logger-level".to_owned(),
4252
"INFO".to_owned(),
4353
"--json-output".to_owned(),
54+
"--pool".to_owned(),
55+
pool.to_owned(),
4456
];
4557

4658
// OpenTelemetry

0 commit comments

Comments
 (0)