Commit 2be18ab

Enable splade embeddings for Python backend (#493)

Signed-off-by: Daniel Huang <daniel1.huang@intel.com>

1 parent 4ac772e commit 2be18ab

File tree

4 files changed: +138 −3 lines

- backends/python/server/text_embeddings_server/models/__init__.py
- backends/python/server/text_embeddings_server/models/masked_model.py (new)
- backends/python/server/text_embeddings_server/models/pooling.py (new)
- backends/python/src/management.rs

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 21 additions & 0 deletions

@@ -8,6 +8,7 @@
 from transformers.models.bert import BertConfig

 from text_embeddings_server.models.model import Model
+from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
 from text_embeddings_server.utils.device import get_device, use_ipex
@@ -53,6 +54,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         and FLASH_ATTENTION
     ):
         if pool != "cls":
+            if config.architectures[0].endswith("ForMaskedLM"):
+                return MaskedLanguageModel(
+                    model_path,
+                    device,
+                    datatype,
+                    pool,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
             return DefaultModel(
                 model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
             )
@@ -61,6 +70,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             return ClassificationModel(
                 model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
+        elif config.architectures[0].endswith("ForMaskedLM"):
+            return MaskedLanguageModel(
+                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+            )
         else:
             return DefaultModel(
                 model_path,
@@ -84,6 +97,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
+        elif config.architectures[0].endswith("ForMaskedLM"):
+            return MaskedLanguageModel(
+                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+            )
         else:
             model_handle = DefaultModel(
                 model_path,
@@ -102,6 +119,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
+        elif config.architectures[0].endswith("ForMaskedLM"):
+            return MaskedLanguageModel(
+                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+            )
         else:
             return DefaultModel(
                 model_path,
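
The dispatch above keys off the model's declared architecture: whenever config.architectures[0] ends in "ForMaskedLM", get_model now routes to the new MaskedLanguageModel instead of DefaultModel, in each of the four branches touched. A minimal standalone sketch of that selection rule (the pick_model_class helper and the sample architecture names are illustrative, not part of this commit):

def pick_model_class(architecture: str) -> str:
    # Mirrors the suffix checks in get_model above (illustrative only).
    if architecture.endswith("Classification"):
        return "ClassificationModel"
    if architecture.endswith("ForMaskedLM"):
        return "MaskedLanguageModel"  # the wrapper that supports splade pooling
    return "DefaultModel"

for arch in ["BertForMaskedLM", "BertForSequenceClassification", "BertModel"]:
    print(arch, "->", pick_model_class(arch))
# BertForMaskedLM -> MaskedLanguageModel
# BertForSequenceClassification -> ClassificationModel
# BertModel -> DefaultModel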
backends/python/server/text_embeddings_server/models/masked_model.py

Lines changed: 76 additions & 0 deletions (new file)

@@ -0,0 +1,76 @@
+import inspect
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModelForMaskedLM
+from opentelemetry import trace
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
+
+tracer = trace.get_tracer(__name__)
+
+
+class MaskedLanguageModel(Model):
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str,
+        trust_remote: bool = False,
+    ):
+        model = (
+            AutoModelForMaskedLM.from_pretrained(
+                model_path, trust_remote_code=trust_remote
+            )
+            .to(dtype)
+            .to(device)
+        )
+        self.hidden_size = model.config.hidden_size
+        self.vocab_size = model.config.vocab_size
+        self.pooling_mode = pool
+        if pool == "splade":
+            self.pooling = SpladePooling()
+        else:
+            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+        self.has_token_type_ids = (
+            inspect.signature(model.forward).parameters.get("token_type_ids", None)
+            is not None
+        )
+
+        super(MaskedLanguageModel, self).__init__(
+            model=model, dtype=dtype, device=device
+        )
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        if self.has_position_ids:
+            kwargs["position_ids"] = batch.position_ids
+        output = self.model(**kwargs)
+        embedding = self.pooling.forward(output, batch.attention_mask)
+        cpu_results = embedding.view(-1).tolist()
+
+        step_size = embedding.shape[-1]
+        return [
+            Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
+            for i in range(len(batch))
+        ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        pass
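
In embed above, the pooled batch is flattened to a single Python list and re-sliced row by row; step_size is the width of one embedding (the hidden size for dense pooling, the vocabulary size for splade). A small self-contained sketch of that flatten-and-slice step on a dummy tensor (shapes are made up for illustration):

import torch

embedding = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # 3 rows, width 4

cpu_results = embedding.view(-1).tolist()  # one flat list of 12 floats
step_size = embedding.shape[-1]            # 4

rows = [
    cpu_results[i * step_size : (i + 1) * step_size]
    for i in range(embedding.shape[0])
]
print(rows[0])  # [0.0, 1.0, 2.0, 3.0], i.e. the first row of `embedding`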
backends/python/server/text_embeddings_server/models/pooling.py

Lines changed: 40 additions & 0 deletions (new file)

@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+
+import torch
+from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from torch import Tensor
+
+tracer = trace.get_tracer(__name__)
+
+
+class _Pooling(ABC):
+    @abstractmethod
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pass
+
+
+class DefaultPooling(_Pooling):
+    def __init__(self, hidden_size, pooling_mode) -> None:
+        assert (
+            pooling_mode != "splade"
+        ), "Splade pooling is not supported for DefaultPooling"
+        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)
+
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pooling_features = {
+            "token_embeddings": model_output[0],
+            "attention_mask": attention_mask,
+        }
+        return self.pooling.forward(pooling_features)["sentence_embedding"]
+
+
+class SpladePooling(_Pooling):
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        # Splade pooling: max over tokens of log(1 + relu(logits)), masked
+        hidden_states = torch.relu(model_output[0])
+        hidden_states = (1 + hidden_states).log()
+        hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
+        return hidden_states.max(dim=1).values
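
SpladePooling above is the standard SPLADE aggregation: per-token MLM logits pass through log(1 + relu(x)), padding positions are zeroed with the attention mask, and the result is max-pooled over the sequence, giving one sparse vocabulary-sized vector per input. A quick shape check on random tensors (sizes are illustrative; torch.log1p is used here as a shorthand for (1 + x).log()):

import torch

batch, seq_len, vocab = 2, 5, 11             # illustrative sizes
logits = torch.randn(batch, seq_len, vocab)  # stand-in for model_output[0]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])       # 0 marks padding

x = torch.log1p(torch.relu(logits))          # log(1 + relu(logits))
x = x * mask.unsqueeze(-1)                   # zero out padded tokens
splade_vec = x.max(dim=1).values             # max over the sequence dimension

print(splade_vec.shape)               # torch.Size([2, 11]), one vector per input
print(bool((splade_vec >= 0).all()))  # True: activations are non-negative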

backends/python/src/management.rs

Lines changed: 1 addition & 3 deletions

@@ -36,9 +36,7 @@ impl BackendProcess {
             Pool::Cls => "cls",
             Pool::Mean => "mean",
             Pool::LastToken => "lasttoken",
-            Pool::Splade => {
-                return Err(BackendError::Start(format!("{pool:?} is not supported")));
-            }
+            Pool::Splade => "splade",
         };

         // Process args
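
Net effect: the launcher now forwards "splade" as the pool argument to the Python server, where get_model and SpladePooling above handle it, instead of failing at startup with a "not supported" error.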
