Skip to content

Commit 6ebf313

Browse files
authored
Avoid direct comparison of floating point numbers (#21002)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
1 parent cfbcb9e commit 6ebf313

File tree

5 files changed

+44
-7
lines changed

5 files changed

+44
-7
lines changed

tests/entrypoints/openai/test_classification.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,8 @@ async def test_invocations(server: RemoteOpenAIServer):
176176
invocation_output = invocation_response.json()
177177

178178
assert classification_output.keys() == invocation_output.keys()
179-
assert classification_output["data"] == invocation_output["data"]
179+
for classification_data, invocation_data in zip(
180+
classification_output["data"], invocation_output["data"]):
181+
assert classification_data.keys() == invocation_data.keys()
182+
assert classification_data["probs"] == pytest.approx(
183+
invocation_data["probs"], rel=0.01)

tests/entrypoints/openai/test_embedding.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ...models.language.pooling.embed_utils import (
1616
run_embedding_correctness_test)
17+
from ...models.utils import check_embeddings_close
1718
from ...utils import RemoteOpenAIServer
1819

1920
MODEL_NAME = "intfloat/multilingual-e5-small"
@@ -321,7 +322,13 @@ async def test_invocations(server: RemoteOpenAIServer,
321322
invocation_output = invocation_response.json()
322323

323324
assert completion_output.keys() == invocation_output.keys()
324-
assert completion_output["data"] == invocation_output["data"]
325+
for completion_data, invocation_data in zip(completion_output["data"],
326+
invocation_output["data"]):
327+
assert completion_data.keys() == invocation_data.keys()
328+
check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]],
329+
embeddings_1_lst=[invocation_data["embedding"]],
330+
name_0="completion",
331+
name_1="invocation")
325332

326333

327334
@pytest.mark.asyncio
@@ -355,4 +362,10 @@ async def test_invocations_conversation(server: RemoteOpenAIServer):
355362
invocation_output = invocation_response.json()
356363

357364
assert chat_output.keys() == invocation_output.keys()
358-
assert chat_output["data"] == invocation_output["data"]
365+
for chat_data, invocation_data in zip(chat_output["data"],
366+
invocation_output["data"]):
367+
assert chat_data.keys() == invocation_data.keys()
368+
check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]],
369+
embeddings_1_lst=[invocation_data["embedding"]],
370+
name_0="chat",
371+
name_1="invocation")

tests/entrypoints/openai/test_pooling.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,13 @@ async def test_invocations(server: RemoteOpenAIServer):
281281
invocation_output = invocation_response.json()
282282

283283
assert completion_output.keys() == invocation_output.keys()
284-
assert completion_output["data"] == invocation_output["data"]
284+
for completion_data, invocation_data in zip(completion_output["data"],
285+
invocation_output["data"]):
286+
assert completion_data.keys() == invocation_data.keys()
287+
check_embeddings_close(embeddings_0_lst=completion_data["data"],
288+
embeddings_1_lst=invocation_data["data"],
289+
name_0="completion",
290+
name_1="invocation")
285291

286292

287293
@pytest.mark.asyncio
@@ -314,4 +320,10 @@ async def test_invocations_conversation(server: RemoteOpenAIServer):
314320
invocation_output = invocation_response.json()
315321

316322
assert chat_output.keys() == invocation_output.keys()
317-
assert chat_output["data"] == invocation_output["data"]
323+
for chat_data, invocation_data in zip(chat_output["data"],
324+
invocation_output["data"]):
325+
assert chat_data.keys() == invocation_data.keys()
326+
check_embeddings_close(embeddings_0_lst=chat_data["data"],
327+
embeddings_1_lst=invocation_data["data"],
328+
name_0="chat",
329+
name_1="invocation")

tests/entrypoints/openai/test_rerank.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,8 @@ def test_invocations(server: RemoteOpenAIServer):
120120
invocation_output = invocation_response.json()
121121

122122
assert rerank_output.keys() == invocation_output.keys()
123-
assert rerank_output["results"] == invocation_output["results"]
123+
for rerank_result, invocations_result in zip(rerank_output["results"],
124+
invocation_output["results"]):
125+
assert rerank_result.keys() == invocations_result.keys()
126+
assert rerank_result["relevance_score"] == pytest.approx(
127+
invocations_result["relevance_score"], rel=0.01)

tests/entrypoints/openai/test_score.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,8 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
215215
invocation_output = invocation_response.json()
216216

217217
assert score_output.keys() == invocation_output.keys()
218-
assert score_output["data"] == invocation_output["data"]
218+
for score_data, invocation_data in zip(score_output["data"],
219+
invocation_output["data"]):
220+
assert score_data.keys() == invocation_data.keys()
221+
assert score_data["score"] == pytest.approx(
222+
invocation_data["score"], rel=0.01)

0 commit comments

Comments
 (0)