Commit c0b7080

Fix bug for MaskedLanguageModel class (#513)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 98ceed9 commit c0b7080

File tree: 4 files changed, +23 -25 lines


backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 8 additions & 9 deletions
@@ -54,12 +54,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             and FLASH_ATTENTION
         ):
             if pool != "cls":
-                if config.architectures[0].endswith("ForMaskedLM"):
+                if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
                     return MaskedLanguageModel(
                         model_path,
                         device,
                         datatype,
-                        pool,
                         trust_remote=TRUST_REMOTE_CODE,
                     )
                 return DefaultModel(
@@ -70,9 +69,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             return ClassificationModel(
                 model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
            )
-        elif config.architectures[0].endswith("ForMaskedLM"):
+        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
             return MaskedLanguageModel(
-                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
         else:
             return DefaultModel(
@@ -97,9 +96,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
-        elif config.architectures[0].endswith("ForMaskedLM"):
-            return MaskedLanguageModel(
-                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
+            model_handle = MaskedLanguageModel(
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
         else:
             model_handle = DefaultModel(
@@ -119,9 +118,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
-        elif config.architectures[0].endswith("ForMaskedLM"):
+        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
             return MaskedLanguageModel(
-                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
         else:
             return DefaultModel(
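
All four hunks apply the same fix at each dispatch site in get_model: a *ForMaskedLM checkpoint is routed to MaskedLanguageModel only when SPLADE pooling was requested, and the now-unused pool argument is dropped from the constructor call. A condensed sketch of the new routing rule (the helper name is illustrative; the real get_model also branches on device, dtype, and flash attention):

def routes_to_masked_language_model(architecture: str, pool: str) -> bool:
    # MaskedLanguageModel is now reserved for SPLADE pooling on *ForMaskedLM
    # checkpoints; every other pooling mode falls through to DefaultModel.
    return architecture.endswith("ForMaskedLM") and pool == "splade"

# Before this commit, the second case below was also routed to
# MaskedLanguageModel.
assert routes_to_masked_language_model("BertForMaskedLM", "splade")
assert not routes_to_masked_language_model("BertForMaskedLM", "mean")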

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 3 additions & 7 deletions
@@ -5,7 +5,7 @@
 from typing import Type, List
 from transformers import AutoModel
 from opentelemetry import trace
-from sentence_transformers.models import Pooling
+from text_embeddings_server.models.pooling import DefaultPooling

 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
@@ -28,7 +28,7 @@ def __init__(
             .to(device)
         )
         self.hidden_size = model.config.hidden_size
-        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
+        self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)

         position_offset = 0
         model_type = model.config.model_type
@@ -65,11 +65,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["position_ids"] = batch.position_ids
         output = self.model(**kwargs)

-        pooling_features = {
-            "token_embeddings": output[0],
-            "attention_mask": batch.attention_mask,
-        }
-        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
+        embedding = self.pooling.forward(output, batch.attention_mask)

         cpu_results = embedding.view(-1).tolist()
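
DefaultModel now pools with the in-repo DefaultPooling, passing the raw model output and attention mask directly instead of building sentence-transformers' feature dict. pooling.py itself is not part of this diff, so the following is only a minimal sketch of the interface the new call site implies, with mean and cls as assumed example modes:

import torch

class DefaultPooling:
    """Illustrative stand-in matching the call sites in the diff above."""

    def __init__(self, hidden_size: int, pooling_mode: str = "mean"):
        self.hidden_size = hidden_size
        self.pooling_mode = pooling_mode

    def forward(self, output, attention_mask: torch.Tensor) -> torch.Tensor:
        token_embeddings = output[0]  # (batch, seq_len, hidden_size)
        if self.pooling_mode == "cls":
            # CLS pooling: take the first token's embedding.
            return token_embeddings[:, 0]
        # Mean pooling: average token embeddings, ignoring padding positions.
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts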

backends/python/server/text_embeddings_server/models/masked_model.py

Lines changed: 11 additions & 9 deletions
@@ -8,7 +8,7 @@

 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
-from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
+from text_embeddings_server.models.pooling import SpladePooling

 tracer = trace.get_tracer(__name__)

@@ -19,7 +19,6 @@ def __init__(
         model_path: Path,
         device: torch.device,
         dtype: torch.dtype,
-        pool: str,
         trust_remote: bool = False,
     ):
         model = (
@@ -29,14 +28,17 @@ def __init__(
             .to(dtype)
             .to(device)
         )
-        self.hidden_size = model.config.hidden_size
-        self.vocab_size = model.config.vocab_size
-        self.pooling_mode = pool
-        if pool == "splade":
-            self.pooling = SpladePooling()
+        self.pooling = SpladePooling()
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
         else:
-            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
-
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
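
With the pool parameter removed, MaskedLanguageModel is SPLADE-only, and it now derives max_input_length the same way DefaultModel does: for the roberta family the padding offset is subtracted (e.g. xlm-roberta has max_position_embeddings = 514 and pad_token_id = 1, giving 514 - 2 = 512 usable positions). SpladePooling itself is not shown in this diff; a minimal sketch assuming the standard SPLADE formulation, max-pooling log(1 + ReLU(logits)) over the sequence:

import torch

class SpladePooling:
    """Illustrative stand-in; forward(output, attention_mask) mirrors the
    pooling call site shown in default_model.py above."""

    def forward(self, output, attention_mask: torch.Tensor) -> torch.Tensor:
        logits = output[0]  # MLM logits: (batch, seq_len, vocab_size)
        # SPLADE activation: log-saturate the positive logits.
        activations = torch.log1p(torch.relu(logits))
        # Mask out padding, then max-pool over the sequence so each vocab
        # entry keeps its strongest activation, yielding one sparse
        # lexical vector per input.
        mask = attention_mask.unsqueeze(-1).to(activations.dtype)
        return (activations * mask).max(dim=1).values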

backends/python/server/text_embeddings_server/models/types.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 import os
+import math
 import torch

 from abc import ABC, abstractmethod
