
Commit ae3da10
add reference tensors
1 parent 309d255 · commit ae3da10

10 files changed, +64 -4 lines changed

tests/README.md

Lines changed: 23 additions & 0 deletions
@@ -8,4 +8,27 @@ pip install -r requirements.txt
 and mounting a volume for the tests, they can be run from within the container with
 ```
 pytest tests/ -s -vvvvv
+```
+
+## Reference outputs
+
+For example, collecting the reference on an RTX 4090 on the Candle backend:
+```
+docker run --rm -it --gpus all --net host --entrypoint "/bin/bash" -v $(pwd):/tei ghcr.io/huggingface/text-embeddings-inference:89-1.2.3
+```
+and
+```
+text-embeddings-router --model-id sentence-transformers/all-MiniLM-L6-v2
+```
+
+and then
+```
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 --flash
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 --flash
+```
+
+Restart the server with `USE_FLASH_ATTENTION=0`, and
+```
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3
 ```
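
After both passes, four reference tensors should exist under `tests/assets/`. The snippet below is a minimal sanity check, not part of the commit, assuming the filenames follow the save pattern used by `collect.py` and that the script was run from the `tests/` directory so the files land in `tests/assets/` relative to the repo root:

```
import os

# Filenames assumed from collect.py's save pattern:
# {model-id with '/' replaced by '-'}_inp{n}{'_no_flash' if flash was disabled}.pt
expected = [
    "sentence-transformers-all-MiniLM-L6-v2_inp1.pt",           # flash attention enabled
    "sentence-transformers-all-MiniLM-L6-v2_inp3.pt",
    "sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt",  # USE_FLASH_ATTENTION=0
    "sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt",
]
missing = [name for name in expected if not os.path.exists(os.path.join("tests", "assets", name))]
print("missing assets:", missing or "none")
```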

tests/assets/default_bert.pt

Whitespace-only changes.

tests/assets/flash_bert.pt

Whitespace-only changes.
Binary file not shown. (4 files)

tests/collect.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+
+import requests
+import torch
+import argparse
+import json
+import os
+
+parser = argparse.ArgumentParser(description='Assets collection')
+parser.add_argument('--model-id', help='Model id', required=True)
+parser.add_argument('--n_inp', help='Number of inputs', required=True, type=int)
+parser.add_argument('--flash', action='store_true')
+
+args = parser.parse_args()
+
+url = f"http://0.0.0.0:80/embed"
+
+INPUTS = [
+    "What is Deep Learning?",
+    "Today I am in Paris and I would like to",
+    "Paris weather is",
+    "Great job"
+]
+
+data = {"inputs": INPUTS[:args.n_inp]}
+headers = {"Content-Type": "application/json"}
+
+response = requests.post(url, json=data, headers=headers)
+
+embedding = torch.Tensor(json.loads(response.text))
+
+postfix = ""
+if not args.flash:
+    postfix = "_no_flash"
+
+save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt"
+print(f"Saving embedding of shape {embedding.shape} to {save_path}")
+torch.save(embedding, save_path)
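
As a quick check that a collected tensor is sane, it can be loaded back and inspected. A minimal sketch, not part of the commit, assuming the single-input no-flash asset produced by the README steps above and assuming all-MiniLM-L6-v2 returns 384-dimensional embeddings:

```
import torch

# Path assumed from collect.py's save pattern, relative to the repo root.
ref = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt")
print(ref.shape)  # expected torch.Size([1, 384]): one input, 384-dim MiniLM-L6-v2 embedding
print(ref.dtype)  # torch.float32, since the JSON response is parsed into a default torch.Tensor
```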

tests/test_default_model.py

Lines changed: 2 additions & 2 deletions
@@ -23,6 +23,6 @@ async def test_single_query(default_model):
     response = requests.post(url, json=data, headers=headers)
 
     embedding = torch.Tensor(json.loads(response.text))
-    # reference_embedding = torch.load("assets/default_model.pt")
+    reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt")
 
-    # assert torch.allclose(embedding, reference_embedding)
+    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
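
For reference, `torch.allclose(a, b, atol=..., rtol=...)` passes when `|a - b| <= atol + rtol * |b|` holds elementwise, so the 1e-3/1e-3 setting tolerates roughly a 2e-3 absolute deviation on unit-scale values. A small illustration, not part of the commit:

```
import torch

a = torch.tensor([1.0000, 0.5000])
b = torch.tensor([1.0009, 0.5004])
# each |a - b| is within 1e-3 + 1e-3 * |b|, so the comparison passes
print(torch.allclose(a, b, atol=1e-3, rtol=1e-3))  # True
```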

tests/test_flash_bert.py

Lines changed: 2 additions & 2 deletions
@@ -23,6 +23,6 @@ async def test_single_query(default_model):
     response = requests.post(url, json=data, headers=headers)
 
     embedding = torch.Tensor(json.loads(response.text))
-    # reference_embedding = torch.load("assets/default_model.pt")
+    reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt")
 
-    # assert torch.allclose(embedding, reference_embedding)
+    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)

0 commit comments