19
19
import json
20
20
import logging
21
21
import os
22
+ from typing import IO
22
23
23
24
import requests
24
25
@@ -57,7 +58,9 @@ class LLMWhispererClient:
57
58
client's activities and errors.
58
59
"""
59
60
60
- formatter = logging .Formatter ("%(asctime)s - %(name)s - %(levelname)s - %(message)s" )
61
+ formatter = logging .Formatter (
62
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
63
+ )
61
64
logger = logging .getLogger (__name__ )
62
65
log_stream_handler = logging .StreamHandler ()
63
66
log_stream_handler .setFormatter (formatter )
@@ -114,7 +117,9 @@ def __init__(
114
117
self .api_key = os .getenv ("LLMWHISPERER_API_KEY" , "" )
115
118
else :
116
119
self .api_key = api_key
117
- self .logger .debug ("api_key set to %s" , LLMWhispererUtils .redact_key (self .api_key ))
120
+ self .logger .debug (
121
+ "api_key set to %s" , LLMWhispererUtils .redact_key (self .api_key )
122
+ )
118
123
119
124
self .api_timeout = api_timeout
120
125
@@ -150,6 +155,7 @@ def get_usage_info(self) -> dict:
150
155
def whisper (
151
156
self ,
152
157
file_path : str = "" ,
158
+ stream : IO [bytes ] = None ,
153
159
url : str = "" ,
154
160
processing_mode : str = "ocr" ,
155
161
output_mode : str = "line-printer" ,
@@ -170,6 +176,7 @@ def whisper(
170
176
171
177
Args:
172
178
file_path (str, optional): The path to the file to be processed. Defaults to "".
179
+ stream (IO[bytes], optional): A stream of bytes to be processed. Defaults to None.
173
180
url (str, optional): The URL of the file to be processed. Defaults to "".
174
181
processing_mode (str, optional): The processing mode. Can be "ocr" or "text". Defaults to "ocr".
175
182
output_mode (str, optional): The output mode. Can be "line-printer" or "text". Defaults to "line-printer".
@@ -212,11 +219,11 @@ def whisper(
212
219
self .logger .debug ("api_url: %s" , api_url )
213
220
self .logger .debug ("params: %s" , params )
214
221
215
- if url == "" and file_path == "" :
222
+ if url == "" and file_path == "" and stream is None :
216
223
raise LLMWhispererClientException (
217
224
{
218
225
"status_code" : - 1 ,
219
- "message" : "Either url or file_path must be provided" ,
226
+ "message" : "Either url, stream or file_path must be provided" ,
220
227
}
221
228
)
222
229
@@ -228,21 +235,39 @@ def whisper(
228
235
}
229
236
)
230
237
238
+ should_stream = False
231
239
if url == "" :
232
- with open (file_path , "rb" ) as f :
233
- data = f .read ()
234
- req = requests .Request (
235
- "POST" ,
236
- api_url ,
237
- params = params ,
238
- headers = self .headers ,
239
- data = data ,
240
- )
240
+ if stream is not None :
241
+
242
+ should_stream = True
243
+
244
+ def generate ():
245
+ for chunk in stream :
246
+ yield chunk
247
+
248
+ req = requests .Request (
249
+ "POST" ,
250
+ api_url ,
251
+ params = params ,
252
+ headers = self .headers ,
253
+ data = generate (),
254
+ )
255
+
256
+ else :
257
+ with open (file_path , "rb" ) as f :
258
+ data = f .read ()
259
+ req = requests .Request (
260
+ "POST" ,
261
+ api_url ,
262
+ params = params ,
263
+ headers = self .headers ,
264
+ data = data ,
265
+ )
241
266
else :
242
267
req = requests .Request ("POST" , api_url , params = params , headers = self .headers )
243
268
prepared = req .prepare ()
244
269
s = requests .Session ()
245
- response = s .send (prepared , timeout = self .api_timeout )
270
+ response = s .send (prepared , timeout = self .api_timeout , stream = should_stream )
246
271
if response .status_code != 200 and response .status_code != 202 :
247
272
message = json .loads (response .text )
248
273
message ["status_code" ] = response .status_code
0 commit comments