
Commit a8c02db: add tests
Parent: 36b3a72

File tree

11 files changed: +183 −9 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 .idea
 target
+__pycache__/

backends/python/server/text_embeddings_server/layers/attention/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -2,7 +2,10 @@
 import os
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
-    raise ImportError("`USE_FLASH_ATTENTION` is false.")
+    class Attention:
+        def __getattr__(self, name):
+            raise RuntimeError(f"TEI is used with USE_FLASH_ATTENTION=false, accessing `attention` is prohibited")
+    attention = Attention()
 if SYSTEM == "cuda":
     from .cuda import attention
 elif SYSTEM == "rocm":
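Instead of failing at import time with an ImportError, the module now exposes a sentinel object: any attribute lookup on `attention` raises at use time, so code that imports the package without ever touching attention keeps working. A minimal standalone sketch of the same `__getattr__` pattern (`FeatureDisabled` and `flash_op` are illustrative names, not from the repo):

# Sketch of the fail-on-use sentinel pattern from the diff above.
class FeatureDisabled:
    def __init__(self, reason: str):
        self._reason = reason

    def __getattr__(self, name):
        # Called for any attribute that is not found on the instance
        raise RuntimeError(f"cannot access `{name}`: {self._reason}")

flash_op = FeatureDisabled("flash attention disabled via USE_FLASH_ATTENTION=false")

# Importing or holding the sentinel is fine; using it raises:
# flash_op.forward(x)  ->  RuntimeError: cannot access `forward`: ...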

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 0 additions & 1 deletion

@@ -233,6 +233,5 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
                 )
                 for i in range(len(batch))
             ]
-
         else:
             raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")

router/src/lib.rs

Lines changed: 7 additions & 7 deletions

@@ -105,7 +105,7 @@ pub async fn run(
         serde_json::from_str(&config).context("Failed to parse `config.json`")?;
 
     // Set model type from config
-    let backend_model_type = get_backend_model_type(&config, &model_root, &pooling)?;
+    let (backend_model_type, inferred_pooling) = get_backend_model_type(&config, &model_root, &pooling)?;
 
     // Info model type
     let model_type = match &backend_model_type {
@@ -191,7 +191,7 @@ pub async fn run(
         }
     });
 
-    let pooling_str = match pooling {
+    let pooling_str = match inferred_pooling {
         Some(pool) => pool.to_string(),
         None => "none".to_string(),
     };
@@ -313,19 +313,19 @@ fn get_backend_model_type(
     config: &ModelConfig,
     model_root: &Path,
     pooling: &Option<text_embeddings_backend::Pool>,
-) -> Result<text_embeddings_backend::ModelType> {
+) -> Result<(text_embeddings_backend::ModelType, Option<text_embeddings_backend::Pool>)> {
     for arch in &config.architectures {
         if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") {
-            return Ok(text_embeddings_backend::ModelType::Embedding(
+            return Ok((text_embeddings_backend::ModelType::Embedding(
                 text_embeddings_backend::Pool::Splade,
-            ));
+            ), Some(text_embeddings_backend::Pool::Splade)));
         } else if arch.ends_with("Classification") {
             if pooling.is_some() {
                 tracing::warn!(
                     "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
                 );
             }
-            return Ok(text_embeddings_backend::ModelType::Classifier);
+            return Ok((text_embeddings_backend::ModelType::Classifier, None));
         }
     }
 
@@ -353,7 +353,7 @@ fn get_backend_model_type(
             }
         }
     };
-    Ok(text_embeddings_backend::ModelType::Embedding(pool))
+    Ok((text_embeddings_backend::ModelType::Embedding(pool.clone()), Some(pool)))
 }
 
 #[derive(Debug, Deserialize)]
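The router change makes `get_backend_model_type` return the pooling it actually resolved alongside the model type, so the logged `pooling_str` reflects the effective value (e.g. a default inferred for the model) rather than only the raw `--pooling` CLI argument. A language-neutral sketch of the pattern in Python (`resolve_model_type` and its arguments are illustrative, not TEI's API):

# Return the effective setting together with the result, so the caller
# can report what was actually used instead of what was requested.
def resolve_model_type(architectures, cli_pooling):
    if any(a.endswith("Classification") for a in architectures):
        return "classifier", None            # pooling is irrelevant for classifiers
    pooling = cli_pooling or "cls"           # fall back to an inferred default
    return "embedding", pooling

model_type, inferred_pooling = resolve_model_type(["BertModel"], None)
print(inferred_pooling or "none")            # prints the effective pooling: "cls"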

tests/__init__.py

Whitespace-only changes.

tests/assets/default_bert.pt

Whitespace-only changes.

tests/assets/flash_bert.pt

Whitespace-only changes.

tests/conftest.py

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+import pytest
+import asyncio
+import contextlib
+import random
+import os
+import tempfile
+import subprocess
+import shutil
+import sys
+from typing import Optional
+from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+import requests
+import time
+from requests.exceptions import ConnectionError as RequestsConnectionError
+
+
+@pytest.fixture(scope="module")
+def event_loop():
+    loop = asyncio.get_event_loop()
+    yield loop
+    loop.close()
+
+
+class ProcessLauncherHandle:
+    def __init__(self, process, port: int):
+        self.port = port
+        self.process = process
+
+    def _inner_health(self) -> bool:
+        # Healthy as long as the router subprocess is still running
+        return self.process.poll() is None
+
+    def health(self, timeout: int = 60):
+        assert timeout > 0
+        for _ in range(timeout):
+            if not self._inner_health():
+                raise RuntimeError("Launcher crashed")
+
+            try:
+                url = f"http://0.0.0.0:{self.port}/health"
+                headers = {"Content-Type": "application/json"}
+
+                response = requests.post(url, headers=headers)
+                return
+            except (ClientConnectorError, ClientOSError, ServerDisconnectedError, RequestsConnectionError):
+                print("Connecting")
+                time.sleep(1)
+        raise RuntimeError("Health check failed")
+
+
+@pytest.fixture(scope="module")
+def launcher(event_loop):
+    @contextlib.contextmanager
+    def local_launcher(
+        model_id: str,
+        trust_remote_code: bool = False,
+        use_flash_attention: bool = True,
+        dtype: Optional[str] = None,
+        revision: Optional[str] = None,
+        pooling: Optional[str] = None,
+    ):
+        port = random.randint(8000, 10_000)
+        shard_uds_path = f"/tmp/tei-tests-{model_id.split('/')[-1]}-server"
+
+        args = [
+            "text-embeddings-router",
+            "--model-id",
+            model_id,
+            "--port",
+            str(port),
+            "--uds-path",
+            shard_uds_path,
+        ]
+
+        env = os.environ
+
+        if dtype is not None:
+            args.append("--dtype")
+            args.append(dtype)
+        if revision is not None:
+            args.append("--revision")
+            args.append(revision)
+        if trust_remote_code:
+            args.append("--trust-remote-code")
+        if pooling:
+            args.append("--pooling")
+            args.append(pooling)
+
+        env["LOG_LEVEL"] = "debug"
+
+        if not use_flash_attention:
+            env["USE_FLASH_ATTENTION"] = "false"
+
+        with tempfile.TemporaryFile("w+") as tmp:
+            # We'll output stdout/stderr to a temporary file. Using a pipe
+            # causes the process to block until stdout is read.
+            print("call subprocess.Popen, with args", args)
+            with subprocess.Popen(
+                args,
+                stdout=tmp,
+                stderr=subprocess.STDOUT,
+                env=env,
+            ) as process:
+                yield ProcessLauncherHandle(process, port)
+
+                process.terminate()
+                process.wait(60)
+
+                tmp.seek(0)
+                shutil.copyfileobj(tmp, sys.stderr)
+
+        if not use_flash_attention:
+            del env["USE_FLASH_ATTENTION"]
+
+    return local_launcher
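Test modules consume the `launcher` fixture as a context manager; test_default_model.py below exercises the CPU path with flash attention disabled. A hypothetical flash-attention variant could be spun up the same way (the model id, dtype, and CUDA requirement here are assumptions, not part of this commit):

import pytest

@pytest.fixture(scope="module")
def flash_model_handle(launcher):
    # Hypothetical: requires a CUDA host where flash attention is available.
    # The model id is illustrative, not taken from this commit.
    with launcher(
        "BAAI/bge-small-en-v1.5",
        use_flash_attention=True,
        dtype="float16",
    ) as handle:
        yield handle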

tests/pytest.ini

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
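With pytest-asyncio's `asyncio_mode = auto`, coroutine tests and fixtures are collected automatically, so per-function `@pytest.mark.asyncio` markers (like the one in test_default_model.py) become optional. A tiny sketch:

# Under asyncio_mode = auto, a bare coroutine is run as an asyncio test;
# no decorator is needed.
async def test_runs_without_marker():
    assert 1 + 1 == 2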

tests/test_default_model.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+import pytest
+import requests
+import json
+import torch
+
+
+@pytest.fixture(scope="module")
+def default_model_handle(launcher):
+    with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=False) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def default_model(default_model_handle):
+    default_model_handle.health(300)
+    return default_model_handle
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_single_query(default_model):
+    url = f"http://0.0.0.0:{default_model.port}/embed"
+    data = {"inputs": "What is Deep Learning?"}
+    headers = {"Content-Type": "application/json"}
+
+    response = requests.post(url, json=data, headers=headers)
+
+    embedding = torch.Tensor(json.loads(response.text))
+    # reference_embedding = torch.load("assets/default_model.pt")
+
+    # assert torch.allclose(embedding, reference_embedding)
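The reference comparison is left commented out because the checked-in assets are empty placeholders (note also that the comment references assets/default_model.pt while the committed placeholder is assets/default_bert.pt). Once a real tensor is saved, the check would want explicit tolerances, since float embeddings vary slightly across hardware and kernels. A sketch of the completed assertion, assuming a populated asset and running inside test_single_query where `embedding` is defined:

# Hypothetical completion of the commented-out check, inside test_single_query.
# Assumes assets/default_bert.pt holds a real reference tensor.
reference_embedding = torch.load("assets/default_bert.pt")
assert torch.allclose(embedding, reference_embedding, rtol=1e-3, atol=1e-5)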
