Skip to content

✨ add support for workflow polling #323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 60 additions & 37 deletions mindee/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mindee.input import WorkflowOptions
from mindee.input.local_response import LocalResponse
from mindee.input.page_options import PageOptions
from mindee.input.predict_options import AsyncPredictOptions, PredictOptions
from mindee.input.sources.base_64_input import Base64Input
from mindee.input.sources.bytes_input import BytesInput
from mindee.input.sources.file_input import FileInput
Expand Down Expand Up @@ -123,14 +124,13 @@ def parse(
page_options.on_min_pages,
page_options.page_indexes,
)
options = PredictOptions(cropper, full_text, include_words)
return self._make_request(
product_class,
input_source,
endpoint,
include_words,
options,
close_file,
cropper,
full_text,
)

def enqueue(
Expand All @@ -143,6 +143,8 @@ def enqueue(
cropper: bool = False,
endpoint: Optional[Endpoint] = None,
full_text: bool = False,
workflow_id: Optional[str] = None,
rag: bool = False,
) -> AsyncPredictResponse:
"""
Enqueues a document to an asynchronous endpoint.
Expand All @@ -169,6 +171,11 @@ def enqueue(
:param endpoint: For custom endpoints, an endpoint has to be given.

:param full_text: Whether to include the full OCR text response in compatible APIs.

:param workflow_id: Workflow ID.

:param rag: If set, will enable Retrieval-Augmented Generation.
Only works if a valid ``workflow_id`` is set.
"""
if input_source is None:
raise MindeeClientError("No input document provided.")
Expand All @@ -185,14 +192,15 @@ def enqueue(
page_options.on_min_pages,
page_options.page_indexes,
)
options = AsyncPredictOptions(
cropper, full_text, include_words, workflow_id, rag
)
return self._predict_async(
product_class,
input_source,
options,
endpoint,
include_words,
close_file,
cropper,
full_text,
)

def load_prediction(
Expand Down Expand Up @@ -246,8 +254,9 @@ def execute_workflow(
:param input_source: The document/source file to use.
Has to be created beforehand.
:param workflow_id: ID of the workflow.
:param page_options: If set, remove pages from the document as specified. This is done before sending the file\
to the server. It is useful to avoid page limitations.
:param page_options: If set, remove pages from the document as specified.
This is done before sending the file to the server.
It is useful to avoid page limitations.
:param options: Options for the workflow.
:return:
"""
Expand All @@ -259,13 +268,11 @@ def execute_workflow(
page_options.page_indexes,
)

logger.debug("Sending document to workflow: %s", workflow_id)

if not options:
options = WorkflowOptions(
alias=None, priority=None, full_text=False, public_url=None, rag=False
)

logger.debug("Sending document to workflow: %s", workflow_id)
return self._send_to_workflow(GeneratedV1, input_source, workflow_id, options)

def _validate_async_params(
Expand All @@ -285,7 +292,7 @@ def _validate_async_params(
if max_retries < min_retries:
raise MindeeClientError(f"Cannot set retries to less than {min_retries}.")

def enqueue_and_parse(
def enqueue_and_parse( # pylint: disable=too-many-locals
self,
product_class: Type[Inference],
input_source: Union[LocalInputSource, UrlInputSource],
Expand All @@ -298,40 +305,51 @@ def enqueue_and_parse(
delay_sec: float = 1.5,
max_retries: int = 80,
full_text: bool = False,
workflow_id: Optional[str] = None,
rag: bool = False,
) -> AsyncPredictResponse:
"""
Enqueues to an asynchronous endpoint and automatically polls for a response.

:param product_class: The document class to use. The response object will be instantiated based on this\
parameter.
:param product_class: The document class to use.
The response object will be instantiated based on this parameter.

:param input_source: The document/source file to use. Has to be created beforehand.
:param input_source: The document/source file to use.
Has to be created beforehand.

:param include_words: Whether to include the full text for each page. This performs a full OCR operation on\
the server and will increase response time.
:param include_words: Whether to include the full text for each page.
This performs a full OCR operation on the server and will increase response time.

:param close_file: Whether to ``close()`` the file after parsing it. Set to ``False`` if you need to access\
the file after this operation.
:param close_file: Whether to ``close()`` the file after parsing it.
Set to ``False`` if you need to access the file after this operation.

:param page_options: If set, remove pages from the document as specified. This is done before sending the file\
to the server. It is useful to avoid page limitations.
:param page_options: If set, remove pages from the document as specified.
This is done before sending the file to the server.
It is useful to avoid page limitations.

:param cropper: Whether to include cropper results for each page. This performs a cropping operation on the\
server and will increase response time.
:param cropper: Whether to include cropper results for each page.
This performs a cropping operation on the server and will increase response time.

:param endpoint: For custom endpoints, an endpoint has to be given.

:param initial_delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
:param initial_delay_sec: Delay between each polling attempt.
This should not be shorter than 1 second.

:param delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
:param delay_sec: Delay between each polling attempt.
This should not be shorter than 1 second.

:param max_retries: Total amount of polling attempts.

:param full_text: Whether to include the full OCR text response in compatible APIs.

:param workflow_id: Workflow ID.

:param rag: If set, will enable Retrieval-Augmented Generation.
Only works if a valid ``workflow_id`` is set.
"""
self._validate_async_params(initial_delay_sec, delay_sec, max_retries)
if not endpoint:
endpoint = self._initialize_ots_endpoint(product_class)
endpoint = self._initialize_ots_endpoint(product_class=product_class)
queue_result = self.enqueue(
product_class,
input_source,
Expand All @@ -341,6 +359,8 @@ def enqueue_and_parse(
cropper,
endpoint,
full_text,
workflow_id,
rag,
)
logger.debug(
"Successfully enqueued document with job id: %s", queue_result.job.id
Expand Down Expand Up @@ -406,15 +426,16 @@ def _make_request(
product_class: Type[Inference],
input_source: Union[LocalInputSource, UrlInputSource],
endpoint: Endpoint,
include_words: bool,
options: PredictOptions,
close_file: bool,
cropper: bool,
full_text: bool,
) -> PredictResponse:
response = endpoint.predict_req_post(
input_source, include_words, close_file, cropper, full_text
input_source,
options.include_words,
close_file,
options.cropper,
options.full_text,
)

dict_response = response.json()

if not is_valid_sync_response(response):
Expand All @@ -423,28 +444,30 @@ def _make_request(
str(product_class.endpoint_name),
clean_response,
)

return PredictResponse(product_class, dict_response)

def _predict_async(
self,
product_class: Type[Inference],
input_source: Union[LocalInputSource, UrlInputSource],
options: AsyncPredictOptions,
endpoint: Optional[Endpoint] = None,
include_words: bool = False,
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
) -> AsyncPredictResponse:
"""Sends a document to the queue, and sends back an asynchronous predict response."""
if input_source is None:
raise MindeeClientError("No input document provided")
if not endpoint:
endpoint = self._initialize_ots_endpoint(product_class)
response = endpoint.predict_async_req_post(
input_source, include_words, close_file, cropper, full_text
input_source=input_source,
include_words=options.include_words,
close_file=close_file,
cropper=options.cropper,
full_text=options.full_text,
workflow_id=options.workflow_id,
rag=options.rag,
)

dict_response = response.json()

if not is_valid_async_response(response):
Expand Down
31 changes: 31 additions & 0 deletions mindee/input/predict_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional


class PredictOptions:
    """Options to pass to a prediction."""

    def __init__(
        self,
        cropper: bool = False,
        full_text: bool = False,
        include_words: bool = False,
    ):
        """
        :param cropper: Whether to include cropper results for each page.
        :param full_text: Whether to include the full OCR text response in compatible APIs.
        :param include_words: Whether to include the full text for each page.
        """
        # Store every option as a plain instance attribute in one pass.
        vars(self).update(
            cropper=cropper,
            full_text=full_text,
            include_words=include_words,
        )


class AsyncPredictOptions(PredictOptions):
    """Options to pass to an asynchronous prediction."""

    def __init__(
        self,
        cropper: bool = False,
        full_text: bool = False,
        include_words: bool = False,
        workflow_id: Optional[str] = None,
        rag: bool = False,
    ):
        """
        :param cropper: Whether to include cropper results for each page.
        :param full_text: Whether to include the full OCR text response in compatible APIs.
        :param include_words: Whether to include the full text for each page.
        :param workflow_id: Workflow ID.
        :param rag: If set, will enable Retrieval-Augmented Generation.
        """
        # Synchronous options are handled by the parent class; keyword form
        # keeps the mapping explicit.
        super().__init__(
            cropper=cropper,
            full_text=full_text,
            include_words=include_words,
        )
        self.rag = rag
        self.workflow_id = workflow_id
28 changes: 24 additions & 4 deletions mindee/mindee_http/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Union
from typing import Optional, Union

import requests
from requests import Response
Expand Down Expand Up @@ -60,6 +60,8 @@ def predict_async_req_post(
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
workflow_id: Optional[str] = None,
rag: bool = False,
) -> requests.Response:
"""
Make an asynchronous request to POST a document for prediction.
Expand All @@ -69,10 +71,19 @@ def predict_async_req_post(
:param close_file: Whether to `close()` the file after parsing it.
:param cropper: Including Mindee cropping results.
:param full_text: Whether to include the full OCR text response in compatible APIs.
:param workflow_id: Workflow ID.
:param rag: If set, will enable Retrieval-Augmented Generation.
:return: requests response
"""
return self._custom_request(
"predict_async", input_source, include_words, close_file, cropper, full_text
"predict_async",
input_source,
include_words,
close_file,
cropper,
full_text,
workflow_id,
rag,
)

def _custom_request(
Expand All @@ -83,6 +94,8 @@ def _custom_request(
close_file: bool = True,
cropper: bool = False,
full_text: bool = False,
workflow_id: Optional[str] = None,
rag: bool = False,
):
data = {}
if include_words:
Expand All @@ -93,11 +106,18 @@ def _custom_request(
params["full_text_ocr"] = "true"
if cropper:
params["cropper"] = "true"
if rag:
params["rag"] = "true"

if workflow_id:
url = f"{self.settings.base_url}/workflows/{workflow_id}/{route}"
else:
url = f"{self.settings.url_root}/{route}"

if isinstance(input_source, UrlInputSource):
data["document"] = input_source.url
response = requests.post(
f"{self.settings.url_root}/{route}",
url=url,
headers=self.settings.base_headers,
data=data,
params=params,
Expand All @@ -106,7 +126,7 @@ def _custom_request(
else:
files = {"document": input_source.read_contents(close_file)}
response = requests.post(
f"{self.settings.url_root}/{route}",
url=url,
files=files,
headers=self.settings.base_headers,
data=data,
Expand Down
1 change: 0 additions & 1 deletion mindee/parsing/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from mindee.parsing.common.page import Page
from mindee.parsing.common.predict_response import PredictResponse
from mindee.parsing.common.prediction import Prediction
from mindee.parsing.common.string_dict import StringDict
from mindee.parsing.common.summary_helper import (
clean_out_string,
format_for_display,
Expand Down
2 changes: 1 addition & 1 deletion mindee/parsing/common/async_predict_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AsyncPredictResponse(Generic[TypeInference], ApiResponse):

job: Job
"""Job object link to the prediction. As long as it isn't complete, the prediction doesn't exist."""
document: Optional[Document]
document: Optional[Document] = None

def __init__(
self, inference_type: Type[TypeInference], raw_response: StringDict
Expand Down
4 changes: 2 additions & 2 deletions mindee/parsing/common/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class Document(Generic[TypePrediction, TypePage]):
"""Result of the base inference"""
id: str
"""Id of the document as sent back by the server"""
extras: Optional[Extras]
extras: Optional[Extras] = None
"""Potential Extras fields sent back along the prediction"""
ocr: Optional[Ocr]
ocr: Optional[Ocr] = None
"""Potential raw text results read by the OCR (limited feature)"""
n_pages: int
"""Amount of pages in the document"""
Expand Down
Loading