Commit 37d2931: working cls pooling

Parent: 2fc644c
8 files changed, +38 -157 lines

backends/python/server/pyproject.toml
Lines changed: 6 additions & 1 deletion

@@ -20,7 +20,7 @@ loguru = "^0.6.0"
 opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
-torch = { version = "^2.0.1" }
+torch = { version = "==2.3.1" }
 
 [tool.poetry.extras]
 
@@ -33,6 +33,11 @@ name = "pytorch-gpu-src"
 url = "https://download.pytorch.org/whl/cu118"
 priority = "explicit"
 
+[[tool.poetry.source]]
+name = "pytorch-gpu-src-rocm"
+url = "https://download.pytorch.org/whl/rocm6.0"
+priority = "explicit"
+
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
 
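The new source is declared with priority = "explicit", so Poetry only uses it for dependencies that opt in via source = "pytorch-gpu-src-rocm". As a quick sanity check (a sketch, not part of the commit), the installed wheel reveals which accelerator the pinned torch 2.3.1 build targets:

import torch

# ROCm wheels report a HIP version; CUDA wheels report a CUDA version instead.
print(torch.__version__)   # e.g. "2.3.1+rocm6.0" or "2.3.1+cu118"
print(torch.version.hip)   # HIP version string on ROCm builds, otherwise None
print(torch.version.cuda)  # CUDA version string on CUDA builds, otherwise None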

backends/python/server/requirements.txt
Lines changed: 0 additions & 12 deletions

@@ -4,20 +4,13 @@ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
 grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
-jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
-mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
-networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -27,15 +20,10 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
-pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
-sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
-torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
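The dropped pins line up with the torch change in pyproject.toml: torch==2.0.1 goes away together with what appear to be its transitive dependencies (filelock, fsspec, jinja2, markupsafe, mpmath, networkx, sympy) and huggingface-hub with its helpers (packaging, pyyaml, tqdm), leaving the requirements export free of a hard-pinned torch.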

backends/python/server/text_embeddings_server/cli.py
Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    pooling_mode: Optional[str] = None,
 ):
     # Remove default handler
     logger.remove()
@@ -47,7 +48,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
 
-    server.serve(model_path, dtype, uds_path)
+    server.serve(model_path, dtype, uds_path, pooling_mode)
 
 
 if __name__ == "__main__":
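Judging by the typer==0.6.1 pin in requirements.txt, serve is a Typer command, so the added keyword parameter should surface as a --pooling-mode CLI option with no further wiring. A minimal sketch of that pattern, reduced to just the new parameter (the app and command names here are illustrative):

from typing import Optional

import typer

app = typer.Typer()


@app.command()
def serve(pooling_mode: Optional[str] = None):
    # Typer derives a --pooling-mode option from the parameter name;
    # omitting the flag leaves the value as None.
    typer.echo(f"pooling_mode={pooling_mode}")


if __name__ == "__main__":
    app()

Running `python cli.py --pooling-mode cls` would then print pooling_mode=cls.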

backends/python/server/text_embeddings_server/models/__init__.py
Lines changed: 10 additions & 8 deletions

@@ -15,17 +15,19 @@
 torch.set_grad_enabled(False)
 
 FLASH_ATTENTION = True
-try:
-    from text_embeddings_server.models.flash_bert import FlashBert
-except ImportError as e:
-    logger.warning(f"Could not import Flash Attention enabled models: {e}")
-    FLASH_ATTENTION = False
+# try:
+from text_embeddings_server.models.flash_bert import FlashBert
+# except ImportError as e:
+#     logger.warning(f"Could not import Flash Attention enabled models: {e}")
+#     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashBert)
 
 
-def get_model(model_path: Path, dtype: Optional[str]):
+class
+
+def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str):
     if dtype == "float32":
         dtype = torch.float32
     elif dtype == "float16":
@@ -52,8 +54,8 @@ def get_model(model_path: Path, dtype: Optional[str]):
         and dtype in [torch.float16, torch.bfloat16]
         and FLASH_ATTENTION
     ):
-        return FlashBert(model_path, device, dtype)
+        return FlashBert(model_path, device, dtype, pooling_mode)
     else:
-        return DefaultModel(model_path, device, dtype)
+        return DefaultModel(model_path, device, dtype, pooling_mode)
 
     raise NotImplementedError
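Net effect in this file: the flash import no longer has a fallback (a missing kernel now raises at import time instead of logging and setting FLASH_ATTENTION = False), and pooling_mode is threaded into both constructors. A condensed sketch of the resulting dispatch, with the bfloat16 and model-type checks elided; FLASH_ATTENTION, FlashBert and DefaultModel are the module-level names from the diff above:

import torch
from pathlib import Path
from typing import Optional


def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str):
    # Map the CLI dtype string onto a torch dtype, defaulting to float32.
    torch_dtype = {"float32": torch.float32, "float16": torch.float16}[dtype or "float32"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Half precision on GPU with the kernels importable takes the flash path.
    if device.type == "cuda" and torch_dtype == torch.float16 and FLASH_ATTENTION:
        return FlashBert(model_path, device, torch_dtype, pooling_mode)
    return DefaultModel(model_path, device, torch_dtype, pooling_mode)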

backends/python/server/text_embeddings_server/models/flash_bert.py
Lines changed: 17 additions & 42 deletions

@@ -8,46 +8,15 @@
 from transformers.models.bert import BertConfig
 from opentelemetry import trace
 
-# Flash attention imports
-import dropout_layer_norm
-
 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import FlashBatch, Embedding
-from text_embeddings_server.utils.flash_attn import attention
+from text_embeddings_server.layers.attention import attention
+from text_embeddings_server.layers.layernorm import FastLayerNorm
+from loguru import logger
 
 tracer = trace.get_tracer(__name__)
 
 
-class FastLayerNorm:
-    def __init__(self, prefix, handle, device, dtype, config: BertConfig):
-        self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
-        self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
-        self.variance_epsilon = config.layer_norm_eps
-
-    def forward(self, hidden_states, residual=None):
-        normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.weight,
-            self.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.variance_epsilon,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
-        if res is None:
-            res = hidden_states
-
-        return normed_hidden_states, res
-
-
 class BertEmbeddings:
     def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.word_embeddings_weight = (
@@ -217,7 +186,7 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
         encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s)
 
-        return encoder_outputs[cu_seqlens[:-1]]
+        return encoder_outputs
 
 
 class FlashBert(Model):
@@ -236,18 +205,24 @@ def batch_type(self) -> Type[FlashBatch]:
 
     @tracer.start_as_current_span("embed")
     def embed(self, batch: FlashBatch) -> List[Embedding]:
+        logger.info(f"batch.input_ids {batch.input_ids}")
         embedding = self.model.forward(
             input_ids=batch.input_ids,
             token_type_ids=batch.token_type_ids,
             position_ids=batch.position_ids,
             cu_seqlens=batch.cu_seqlens,
             max_s=batch.max_s,
         )
-        cpu_results = embedding.view(-1).tolist()
 
-        return [
-            Embedding(
-                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
-            )
-            for i in range(len(batch))
-        ]
+        if True:
+            embedding = embedding[batch.cu_seqlens[:-1]]
+            logger.info(f"embedding {embedding.shape}")
+            cpu_results = embedding.view(-1).tolist()
+
+            return [
+                Embedding(
+                    values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
+                )
+                for i in range(len(batch))
+            ]
+        elif
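The pooling change itself is the point of the commit: the model's forward now returns the full packed token matrix, and embed picks out the CLS rows. With flash-attention style packing, cu_seqlens holds cumulative sequence lengths, so cu_seqlens[:-1] indexes the first (CLS) token of every sequence. A standalone sketch with illustrative shapes:

import torch

# Three sequences of lengths 3, 4 and 2 packed into one token dimension.
cu_seqlens = torch.tensor([0, 3, 7, 9])
hidden_size = 4
encoder_outputs = torch.randn(9, hidden_size)  # one row per packed token

# cu_seqlens[:-1] == [0, 3, 7]: the offsets of each sequence's CLS token.
cls_embeddings = encoder_outputs[cu_seqlens[:-1]]
print(cls_embeddings.shape)  # torch.Size([3, 4])

The if True: guard and the dangling elif read like scaffolding for the other pooling modes that pooling_mode is meant to select.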

backends/python/server/text_embeddings_server/server.py
Lines changed: 2 additions & 1 deletion

@@ -37,6 +37,7 @@ def serve(
     model_path: Path,
     dtype: Optional[str],
     uds_path: Path,
+    pooling_mode: Optional[str],
 ):
     async def serve_inner(
         model_path: Path,
@@ -45,7 +46,7 @@ async def serve_inner(
         unix_socket = f"unix://{uds_path}"
 
         try:
-            model = get_model(model_path, dtype)
+            model = get_model(model_path, dtype, pooling_mode)
         except Exception:
             logger.exception("Error when initializing model")
             raise

backends/python/server/text_embeddings_server/utils/flash_attn.py
Lines changed: 0 additions & 92 deletions

This file was deleted; per the import change in flash_bert.py above, its attention helper now lives in text_embeddings_server.layers.attention.

router/src/lib.rs
Lines changed: 1 addition & 0 deletions

@@ -198,6 +198,7 @@ pub async fn run(
         backend_model_type,
         uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()),
         otlp_endpoint.clone(),
+        pooling.to_string(),
     )
     .context("Could not create backend")?;
     backend
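Taken together, the pooling setting now travels the full chain: run passes pooling.to_string() when creating the backend, the Python CLI accepts it as --pooling-mode, and server.serve forwards it through get_model into the model constructors (the Rust-to-Python hand-off itself sits outside these hunks).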
