1
1
import logging
2
2
import os
3
3
import unittest
4
+ from difflib import SequenceMatcher , unified_diff
4
5
from pathlib import Path
5
6
6
7
import pytest
@@ -23,9 +24,7 @@ def test_get_usage_info(client):
23
24
"subscription_plan" ,
24
25
"today_page_count" ,
25
26
]
26
- assert set (usage_info .keys ()) == set (
27
- expected_keys
28
- ), f"usage_info { usage_info } does not contain the expected keys"
27
+ assert set (usage_info .keys ()) == set (expected_keys ), f"usage_info { usage_info } does not contain the expected keys"
29
28
30
29
31
30
@pytest .mark .parametrize (
@@ -56,7 +55,21 @@ def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
56
55
57
56
assert isinstance (response , dict )
58
57
assert response ["status_code" ] == 200
59
- assert response ["extracted_text" ] == exp
58
+
59
+ # For text-based processing with text output, perform a strict match
60
+ if processing_mode == "text" and output_mode == "text" :
61
+ assert response ["extracted_text" ] == exp
62
+ # For every other mode combination (e.g. OCR-based processing), perform a fuzzy match
63
+ else :
64
+ extracted_text = response ["extracted_text" ]
65
+ similarity = SequenceMatcher (None , extracted_text , exp ).ratio ()
66
+ threshold = 0.97
67
+
68
+ if similarity < threshold :
69
+ diff = "\n " .join (
70
+ unified_diff (exp .splitlines (), extracted_text .splitlines (), fromfile = "Expected" , tofile = "Extracted" )
71
+ )
72
+ pytest .fail (f"Texts are not similar enough: { similarity * 100 :.2f} % similarity. Diff:\n { diff } " )
60
73
61
74
62
75
# TODO: Review and port to pytest-based tests
@@ -78,9 +91,7 @@ def test_whisper(self):
78
91
# @unittest.skip("Skipping test_whisper")
79
92
def test_whisper_stream (self ):
80
93
client = LLMWhispererClient ()
81
- download_url = (
82
- "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83
- )
94
+ download_url = "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
84
95
# Create a stream of download_url and pass it to whisper
85
96
response_download = requests .get (download_url , stream = True )
86
97
response_download .raise_for_status ()
@@ -95,18 +106,14 @@ def test_whisper_stream(self):
95
106
@unittest .skip ("Skipping test_whisper_status" )
96
107
def test_whisper_status (self ):
97
108
client = LLMWhispererClient ()
98
- response = client .whisper_status (
99
- whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100
- )
109
+ response = client .whisper_status (whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a" )
101
110
logger .info (response )
102
111
self .assertIsInstance (response , dict )
103
112
104
113
@unittest .skip ("Skipping test_whisper_retrieve" )
105
114
def test_whisper_retrieve (self ):
106
115
client = LLMWhispererClient ()
107
- response = client .whisper_retrieve (
108
- whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109
- )
116
+ response = client .whisper_retrieve (whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a" )
110
117
logger .info (response )
111
118
self .assertIsInstance (response , dict )
112
119
0 commit comments