Skip to content

Commit b467ea0

Browse files
Perform fuzzy assertion for OCR based tests
1 parent: db264a2 · commit: b467ea0

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

tests/client_test.py

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import logging
22
import os
33
import unittest
4+
from difflib import SequenceMatcher, unified_diff
45
from pathlib import Path
56

67
import pytest
@@ -23,9 +24,7 @@ def test_get_usage_info(client):
2324
"subscription_plan",
2425
"today_page_count",
2526
]
26-
assert set(usage_info.keys()) == set(
27-
expected_keys
28-
), f"usage_info {usage_info} does not contain the expected keys"
27+
assert set(usage_info.keys()) == set(expected_keys), f"usage_info {usage_info} does not contain the expected keys"
2928

3029

3130
@pytest.mark.parametrize(
@@ -56,7 +55,21 @@ def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
5655

5756
assert isinstance(response, dict)
5857
assert response["status_code"] == 200
59-
assert response["extracted_text"] == exp
58+
59+
# For text based processing, perform a strict match
60+
if processing_mode == "text" and output_mode == "text":
61+
assert response["extracted_text"] == exp
62+
# For OCR based processing, perform a fuzzy match
63+
else:
64+
extracted_text = response["extracted_text"]
65+
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
66+
threshold = 0.97
67+
68+
if similarity < threshold:
69+
diff = "\n".join(
70+
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
71+
)
72+
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
6073

6174

6275
# TODO: Review and port to pytest based tests
@@ -78,9 +91,7 @@ def test_whisper(self):
7891
# @unittest.skip("Skipping test_whisper")
7992
def test_whisper_stream(self):
8093
client = LLMWhispererClient()
81-
download_url = (
82-
"https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83-
)
94+
download_url = "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
8495
# Create a stream of download_url and pass it to whisper
8596
response_download = requests.get(download_url, stream=True)
8697
response_download.raise_for_status()
@@ -95,18 +106,14 @@ def test_whisper_stream(self):
95106
@unittest.skip("Skipping test_whisper_status")
96107
def test_whisper_status(self):
97108
client = LLMWhispererClient()
98-
response = client.whisper_status(
99-
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100-
)
109+
response = client.whisper_status(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
101110
logger.info(response)
102111
self.assertIsInstance(response, dict)
103112

104113
@unittest.skip("Skipping test_whisper_retrieve")
105114
def test_whisper_retrieve(self):
106115
client = LLMWhispererClient()
107-
response = client.whisper_retrieve(
108-
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109-
)
116+
response = client.whisper_retrieve(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
110117
logger.info(response)
111118
self.assertIsInstance(response, dict)
112119

0 commit comments

Comments (0)