Skip to content

Commit 103176c

Browse files
Added streaming upload
1 parent e3b7470 commit 103176c

File tree

2 files changed

+68
-19
lines changed

2 files changed

+68
-19
lines changed

src/unstract/llmwhisperer/client.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import json
2020
import logging
2121
import os
22+
from typing import IO
2223

2324
import requests
2425

@@ -57,7 +58,9 @@ class LLMWhispererClient:
5758
client's activities and errors.
5859
"""
5960

60-
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
61+
formatter = logging.Formatter(
62+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
63+
)
6164
logger = logging.getLogger(__name__)
6265
log_stream_handler = logging.StreamHandler()
6366
log_stream_handler.setFormatter(formatter)
@@ -114,7 +117,9 @@ def __init__(
114117
self.api_key = os.getenv("LLMWHISPERER_API_KEY", "")
115118
else:
116119
self.api_key = api_key
117-
self.logger.debug("api_key set to %s", LLMWhispererUtils.redact_key(self.api_key))
120+
self.logger.debug(
121+
"api_key set to %s", LLMWhispererUtils.redact_key(self.api_key)
122+
)
118123

119124
self.api_timeout = api_timeout
120125

@@ -150,6 +155,7 @@ def get_usage_info(self) -> dict:
150155
def whisper(
151156
self,
152157
file_path: str = "",
158+
stream: IO[bytes] = None,
153159
url: str = "",
154160
processing_mode: str = "ocr",
155161
output_mode: str = "line-printer",
@@ -170,6 +176,7 @@ def whisper(
170176
171177
Args:
172178
file_path (str, optional): The path to the file to be processed. Defaults to "".
179+
stream (IO[bytes], optional): A stream of bytes to be processed. Defaults to None.
173180
url (str, optional): The URL of the file to be processed. Defaults to "".
174181
processing_mode (str, optional): The processing mode. Can be "ocr" or "text". Defaults to "ocr".
175182
output_mode (str, optional): The output mode. Can be "line-printer" or "text". Defaults to "line-printer".
@@ -212,11 +219,11 @@ def whisper(
212219
self.logger.debug("api_url: %s", api_url)
213220
self.logger.debug("params: %s", params)
214221

215-
if url == "" and file_path == "":
222+
if url == "" and file_path == "" and stream is None:
216223
raise LLMWhispererClientException(
217224
{
218225
"status_code": -1,
219-
"message": "Either url or file_path must be provided",
226+
"message": "Either url, stream or file_path must be provided",
220227
}
221228
)
222229

@@ -228,21 +235,39 @@ def whisper(
228235
}
229236
)
230237

238+
should_stream = False
231239
if url == "":
232-
with open(file_path, "rb") as f:
233-
data = f.read()
234-
req = requests.Request(
235-
"POST",
236-
api_url,
237-
params=params,
238-
headers=self.headers,
239-
data=data,
240-
)
240+
if stream is not None:
241+
242+
should_stream = True
243+
244+
def generate():
245+
for chunk in stream:
246+
yield chunk
247+
248+
req = requests.Request(
249+
"POST",
250+
api_url,
251+
params=params,
252+
headers=self.headers,
253+
data=generate(),
254+
)
255+
256+
else:
257+
with open(file_path, "rb") as f:
258+
data = f.read()
259+
req = requests.Request(
260+
"POST",
261+
api_url,
262+
params=params,
263+
headers=self.headers,
264+
data=data,
265+
)
241266
else:
242267
req = requests.Request("POST", api_url, params=params, headers=self.headers)
243268
prepared = req.prepare()
244269
s = requests.Session()
245-
response = s.send(prepared, timeout=self.api_timeout)
270+
response = s.send(prepared, timeout=self.api_timeout, stream=should_stream)
246271
if response.status_code != 200 and response.status_code != 202:
247272
message = json.loads(response.text)
248273
message["status_code"] = response.status_code

tests/client_test.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import pytest
7+
import requests
78

89
from unstract.llmwhisperer import LLMWhispererClient
910

@@ -22,7 +23,9 @@ def test_get_usage_info(client):
2223
"subscription_plan",
2324
"today_page_count",
2425
]
25-
assert set(usage_info.keys()) == set(expected_keys), f"usage_info {usage_info} does not contain the expected keys"
26+
assert set(usage_info.keys()) == set(
27+
expected_keys
28+
), f"usage_info {usage_info} does not contain the expected keys"
2629

2730

2831
@pytest.mark.parametrize(
@@ -65,24 +68,45 @@ def test_whisper(self):
6568
# url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
6669
# )
6770
response = client.whisper(
68-
file_path="test_files/restaurant_invoice_photo.pdf",
71+
file_path="test_data/restaurant_invoice_photo.pdf",
6972
timeout=200,
7073
store_metadata_for_highlighting=True,
7174
)
72-
logger.info(response)
75+
print(response)
76+
# self.assertIsInstance(response, dict)
77+
78+
# @unittest.skip("Skipping test_whisper")
79+
def test_whisper_stream(self):
80+
client = LLMWhispererClient()
81+
download_url = (
82+
"https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83+
)
84+
# Create a stream of download_url and pass it to whisper
85+
response_download = requests.get(download_url, stream=True)
86+
response_download.raise_for_status()
87+
response = client.whisper(
88+
stream=response_download.iter_content(chunk_size=1024),
89+
timeout=200,
90+
store_metadata_for_highlighting=True,
91+
)
92+
print(response)
7393
# self.assertIsInstance(response, dict)
7494

7595
@unittest.skip("Skipping test_whisper_status")
7696
def test_whisper_status(self):
7797
client = LLMWhispererClient()
78-
response = client.whisper_status(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
98+
response = client.whisper_status(
99+
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100+
)
79101
logger.info(response)
80102
self.assertIsInstance(response, dict)
81103

82104
@unittest.skip("Skipping test_whisper_retrieve")
83105
def test_whisper_retrieve(self):
84106
client = LLMWhispererClient()
85-
response = client.whisper_retrieve(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
107+
response = client.whisper_retrieve(
108+
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109+
)
86110
logger.info(response)
87111
self.assertIsInstance(response, dict)
88112

0 commit comments

Comments
 (0)