Skip to content

Commit 2871187

Browse files
committed
✨ add support for workflow polling
1 parent 9e4885a commit 2871187

File tree

2 files changed

+59
-38
lines changed

2 files changed

+59
-38
lines changed

mindee/client.py

Lines changed: 58 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mindee.input import WorkflowOptions
88
from mindee.input.local_response import LocalResponse
99
from mindee.input.page_options import PageOptions
10+
from mindee.input.predict_options import AsyncPredictOptions, PredictOptions
1011
from mindee.input.sources.base_64_input import Base64Input
1112
from mindee.input.sources.bytes_input import BytesInput
1213
from mindee.input.sources.file_input import FileInput
@@ -123,14 +124,13 @@ def parse(
123124
page_options.on_min_pages,
124125
page_options.page_indexes,
125126
)
127+
options = PredictOptions(cropper, full_text, include_words)
126128
return self._make_request(
127129
product_class,
128130
input_source,
129131
endpoint,
130-
include_words,
132+
options,
131133
close_file,
132-
cropper,
133-
full_text,
134134
)
135135

136136
def enqueue(
@@ -143,6 +143,8 @@ def enqueue(
143143
cropper: bool = False,
144144
endpoint: Optional[Endpoint] = None,
145145
full_text: bool = False,
146+
workflow_id: Optional[str] = None,
147+
rag: bool = False,
146148
) -> AsyncPredictResponse:
147149
"""
148150
Enqueues a document to an asynchronous endpoint.
@@ -169,6 +171,11 @@ def enqueue(
169171
:param endpoint: For custom endpoints, an endpoint has to be given.
170172
171173
:param full_text: Whether to include the full OCR text response in compatible APIs.
174+
175+
:param workflow_id: Workflow ID.
176+
177+
:param rag: If set, will enable Retrieval-Augmented Generation.
178+
Only works if a valid ``workflow_id`` is set.
172179
"""
173180
if input_source is None:
174181
raise MindeeClientError("No input document provided.")
@@ -185,14 +192,15 @@ def enqueue(
185192
page_options.on_min_pages,
186193
page_options.page_indexes,
187194
)
195+
options = AsyncPredictOptions(
196+
cropper, full_text, include_words, workflow_id, rag
197+
)
188198
return self._predict_async(
189199
product_class,
190200
input_source,
201+
options,
191202
endpoint,
192-
include_words,
193203
close_file,
194-
cropper,
195-
full_text,
196204
)
197205

198206
def load_prediction(
@@ -246,8 +254,9 @@ def execute_workflow(
246254
:param input_source: The document/source file to use.
247255
Has to be created beforehand.
248256
:param workflow_id: ID of the workflow.
249-
:param page_options: If set, remove pages from the document as specified. This is done before sending the file\
250-
to the server. It is useful to avoid page limitations.
257+
:param page_options: If set, remove pages from the document as specified.
258+
This is done before sending the file to the server.
259+
It is useful to avoid page limitations.
251260
:param options: Options for the workflow.
252261
:return:
253262
"""
@@ -259,13 +268,11 @@ def execute_workflow(
259268
page_options.page_indexes,
260269
)
261270

262-
logger.debug("Sending document to workflow: %s", workflow_id)
263-
264271
if not options:
265272
options = WorkflowOptions(
266273
alias=None, priority=None, full_text=False, public_url=None, rag=False
267274
)
268-
275+
logger.debug("Sending document to workflow: %s", workflow_id)
269276
return self._send_to_workflow(GeneratedV1, input_source, workflow_id, options)
270277

271278
def _validate_async_params(
@@ -285,7 +292,7 @@ def _validate_async_params(
285292
if max_retries < min_retries:
286293
raise MindeeClientError(f"Cannot set retries to less than {min_retries}.")
287294

288-
def enqueue_and_parse(
295+
def enqueue_and_parse( # pylint: disable=too-many-locals
289296
self,
290297
product_class: Type[Inference],
291298
input_source: Union[LocalInputSource, UrlInputSource],
@@ -298,40 +305,51 @@ def enqueue_and_parse(
298305
delay_sec: float = 1.5,
299306
max_retries: int = 80,
300307
full_text: bool = False,
308+
workflow_id: Optional[str] = None,
309+
rag: bool = False,
301310
) -> AsyncPredictResponse:
302311
"""
303312
Enqueues to an asynchronous endpoint and automatically polls for a response.
304313
305-
:param product_class: The document class to use. The response object will be instantiated based on this\
306-
parameter.
314+
:param product_class: The document class to use.
315+
The response object will be instantiated based on this parameter.
307316
308-
:param input_source: The document/source file to use. Has to be created beforehand.
317+
:param input_source: The document/source file to use.
318+
Has to be created beforehand.
309319
310-
:param include_words: Whether to include the full text for each page. This performs a full OCR operation on\
311-
the server and will increase response time.
320+
:param include_words: Whether to include the full text for each page.
321+
This performs a full OCR operation on the server and will increase response time.
312322
313-
:param close_file: Whether to ``close()`` the file after parsing it. Set to ``False`` if you need to access\
314-
the file after this operation.
323+
:param close_file: Whether to ``close()`` the file after parsing it.
324+
Set to ``False`` if you need to access the file after this operation.
315325
316-
:param page_options: If set, remove pages from the document as specified. This is done before sending the file\
317-
to the server. It is useful to avoid page limitations.
326+
:param page_options: If set, remove pages from the document as specified.
327+
This is done before sending the file to the server.
328+
It is useful to avoid page limitations.
318329
319-
:param cropper: Whether to include cropper results for each page. This performs a cropping operation on the\
320-
server and will increase response time.
330+
:param cropper: Whether to include cropper results for each page.
331+
This performs a cropping operation on the server and will increase response time.
321332
322333
:param endpoint: For custom endpoints, an endpoint has to be given.
323334
324-
:param initial_delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
335+
:param initial_delay_sec: Delay between each polling attempts.
336+
This should not be shorter than 1 second.
325337
326-
:param delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
338+
:param delay_sec: Delay between each polling attempts.
339+
This should not be shorter than 1 second.
327340
328341
:param max_retries: Total amount of polling attempts.
329342
330343
:param full_text: Whether to include the full OCR text response in compatible APIs.
344+
345+
:param workflow_id: Workflow ID.
346+
347+
:param rag: If set, will enable Retrieval-Augmented Generation.
348+
Only works if a valid ``workflow_id`` is set.
331349
"""
332350
self._validate_async_params(initial_delay_sec, delay_sec, max_retries)
333351
if not endpoint:
334-
endpoint = self._initialize_ots_endpoint(product_class)
352+
endpoint = self._initialize_ots_endpoint(product_class=product_class)
335353
queue_result = self.enqueue(
336354
product_class,
337355
input_source,
@@ -341,6 +359,8 @@ def enqueue_and_parse(
341359
cropper,
342360
endpoint,
343361
full_text,
362+
workflow_id,
363+
rag,
344364
)
345365
logger.debug(
346366
"Successfully enqueued document with job id: %s", queue_result.job.id
@@ -406,15 +426,16 @@ def _make_request(
406426
product_class: Type[Inference],
407427
input_source: Union[LocalInputSource, UrlInputSource],
408428
endpoint: Endpoint,
409-
include_words: bool,
429+
options: PredictOptions,
410430
close_file: bool,
411-
cropper: bool,
412-
full_text: bool,
413431
) -> PredictResponse:
414432
response = endpoint.predict_req_post(
415-
input_source, include_words, close_file, cropper, full_text
433+
input_source,
434+
options.include_words,
435+
close_file,
436+
options.cropper,
437+
options.full_text,
416438
)
417-
418439
dict_response = response.json()
419440

420441
if not is_valid_sync_response(response):
@@ -423,28 +444,28 @@ def _make_request(
423444
str(product_class.endpoint_name),
424445
clean_response,
425446
)
426-
427447
return PredictResponse(product_class, dict_response)
428448

429449
def _predict_async(
430450
self,
431451
product_class: Type[Inference],
432452
input_source: Union[LocalInputSource, UrlInputSource],
453+
options: AsyncPredictOptions,
433454
endpoint: Optional[Endpoint] = None,
434-
include_words: bool = False,
435455
close_file: bool = True,
436-
cropper: bool = False,
437-
full_text: bool = False,
438456
) -> AsyncPredictResponse:
439457
"""Sends a document to the queue, and sends back an asynchronous predict response."""
440458
if input_source is None:
441459
raise MindeeClientError("No input document provided")
442460
if not endpoint:
443461
endpoint = self._initialize_ots_endpoint(product_class)
444462
response = endpoint.predict_async_req_post(
445-
input_source, include_words, close_file, cropper, full_text
463+
input_source,
464+
options.include_words,
465+
close_file,
466+
options.cropper,
467+
options.full_text,
446468
)
447-
448469
dict_response = response.json()
449470

450471
if not is_valid_async_response(response):

mindee/parsing/common/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __str__(self) -> str:
5757
@staticmethod
5858
def get_endpoint_info(klass: Type["Inference"]) -> Dict[str, str]:
5959
"""
60-
Retrives the endpoint information for an Inference.
60+
Retrieves the endpoint information for an Inference.
6161
6262
Should never retrieve info for CustomV1, as a custom endpoint should be created to use CustomV1.
6363

0 commit comments

Comments
 (0)