From f9d261e22061a8e1b1f3eafb537d8f81809e4791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ianar=C3=A9=20S=C3=A9vi?= Date: Mon, 28 Apr 2025 22:38:50 +0200 Subject: [PATCH] :sparkles: add support for workflow polling --- mindee/client.py | 97 ++++++++++++------- mindee/input/predict_options.py | 31 ++++++ mindee/mindee_http/endpoint.py | 28 +++++- mindee/parsing/common/__init__.py | 1 - .../parsing/common/async_predict_response.py | 2 +- mindee/parsing/common/document.py | 4 +- mindee/parsing/common/execution.py | 8 +- mindee/parsing/common/extras/extras.py | 10 +- mindee/parsing/common/extras/rag_extra.py | 13 +++ mindee/parsing/common/inference.py | 8 +- mindee/parsing/common/job.py | 4 +- mindee/parsing/common/page.py | 4 +- tests/workflows/test_workflow_integration.py | 73 +++++++++++++- 13 files changed, 225 insertions(+), 58 deletions(-) create mode 100644 mindee/input/predict_options.py create mode 100644 mindee/parsing/common/extras/rag_extra.py diff --git a/mindee/client.py b/mindee/client.py index 07a70219..5e508e3a 100644 --- a/mindee/client.py +++ b/mindee/client.py @@ -7,6 +7,7 @@ from mindee.input import WorkflowOptions from mindee.input.local_response import LocalResponse from mindee.input.page_options import PageOptions +from mindee.input.predict_options import AsyncPredictOptions, PredictOptions from mindee.input.sources.base_64_input import Base64Input from mindee.input.sources.bytes_input import BytesInput from mindee.input.sources.file_input import FileInput @@ -123,14 +124,13 @@ def parse( page_options.on_min_pages, page_options.page_indexes, ) + options = PredictOptions(cropper, full_text, include_words) return self._make_request( product_class, input_source, endpoint, - include_words, + options, close_file, - cropper, - full_text, ) def enqueue( @@ -143,6 +143,8 @@ def enqueue( cropper: bool = False, endpoint: Optional[Endpoint] = None, full_text: bool = False, + workflow_id: Optional[str] = None, + rag: bool = False, ) -> AsyncPredictResponse: """ Enqueues a document to an asynchronous endpoint. @@ -169,6 +171,11 @@ def enqueue( :param endpoint: For custom endpoints, an endpoint has to be given. :param full_text: Whether to include the full OCR text response in compatible APIs. + + :param workflow_id: Workflow ID. + + :param rag: If set, will enable Retrieval-Augmented Generation. + Only works if a valid ``workflow_id`` is set. """ if input_source is None: raise MindeeClientError("No input document provided.") @@ -185,14 +192,15 @@ def enqueue( page_options.on_min_pages, page_options.page_indexes, ) + options = AsyncPredictOptions( + cropper, full_text, include_words, workflow_id, rag + ) return self._predict_async( product_class, input_source, + options, endpoint, - include_words, close_file, - cropper, - full_text, ) def load_prediction( @@ -246,8 +254,9 @@ def execute_workflow( :param input_source: The document/source file to use. Has to be created beforehand. :param workflow_id: ID of the workflow. - :param page_options: If set, remove pages from the document as specified. This is done before sending the file\ - to the server. It is useful to avoid page limitations. + :param page_options: If set, remove pages from the document as specified. + This is done before sending the file to the server. + It is useful to avoid page limitations. :param options: Options for the workflow. :return: """ @@ -259,13 +268,11 @@ def execute_workflow( page_options.page_indexes, ) - logger.debug("Sending document to workflow: %s", workflow_id) - if not options: options = WorkflowOptions( alias=None, priority=None, full_text=False, public_url=None, rag=False ) - + logger.debug("Sending document to workflow: %s", workflow_id) return self._send_to_workflow(GeneratedV1, input_source, workflow_id, options) def _validate_async_params( @@ -285,7 +292,7 @@ def _validate_async_params( if max_retries < min_retries: raise MindeeClientError(f"Cannot set retries to less than {min_retries}.") - def enqueue_and_parse( + def enqueue_and_parse( # pylint: disable=too-many-locals self, product_class: Type[Inference], input_source: Union[LocalInputSource, UrlInputSource], @@ -298,40 +305,51 @@ def enqueue_and_parse( delay_sec: float = 1.5, max_retries: int = 80, full_text: bool = False, + workflow_id: Optional[str] = None, + rag: bool = False, ) -> AsyncPredictResponse: """ Enqueues to an asynchronous endpoint and automatically polls for a response. - :param product_class: The document class to use. The response object will be instantiated based on this\ -parameter. + :param product_class: The document class to use. + The response object will be instantiated based on this parameter. - :param input_source: The document/source file to use. Has to be created beforehand. + :param input_source: The document/source file to use. + Has to be created beforehand. - :param include_words: Whether to include the full text for each page. This performs a full OCR operation on\ - the server and will increase response time. + :param include_words: Whether to include the full text for each page. + This performs a full OCR operation on the server and will increase response time. - :param close_file: Whether to ``close()`` the file after parsing it. Set to ``False`` if you need to access\ - the file after this operation. + :param close_file: Whether to ``close()`` the file after parsing it. + Set to ``False`` if you need to access the file after this operation. - :param page_options: If set, remove pages from the document as specified. This is done before sending the file\ - to the server. It is useful to avoid page limitations. + :param page_options: If set, remove pages from the document as specified. + This is done before sending the file to the server. + It is useful to avoid page limitations. - :param cropper: Whether to include cropper results for each page. This performs a cropping operation on the\ - server and will increase response time. + :param cropper: Whether to include cropper results for each page. + This performs a cropping operation on the server and will increase response time. :param endpoint: For custom endpoints, an endpoint has to be given. - :param initial_delay_sec: Delay between each polling attempts This should not be shorter than 1 second. + :param initial_delay_sec: Delay between each polling attempts. + This should not be shorter than 1 second. - :param delay_sec: Delay between each polling attempts This should not be shorter than 1 second. + :param delay_sec: Delay between each polling attempts. + This should not be shorter than 1 second. :param max_retries: Total amount of polling attempts. :param full_text: Whether to include the full OCR text response in compatible APIs. + + :param workflow_id: Workflow ID. + + :param rag: If set, will enable Retrieval-Augmented Generation. + Only works if a valid ``workflow_id`` is set. """ self._validate_async_params(initial_delay_sec, delay_sec, max_retries) if not endpoint: - endpoint = self._initialize_ots_endpoint(product_class) + endpoint = self._initialize_ots_endpoint(product_class=product_class) queue_result = self.enqueue( product_class, input_source, @@ -341,6 +359,8 @@ def enqueue_and_parse( cropper, endpoint, full_text, + workflow_id, + rag, ) logger.debug( "Successfully enqueued document with job id: %s", queue_result.job.id @@ -406,15 +426,16 @@ def _make_request( product_class: Type[Inference], input_source: Union[LocalInputSource, UrlInputSource], endpoint: Endpoint, - include_words: bool, + options: PredictOptions, close_file: bool, - cropper: bool, - full_text: bool, ) -> PredictResponse: response = endpoint.predict_req_post( - input_source, include_words, close_file, cropper, full_text + input_source, + options.include_words, + close_file, + options.cropper, + options.full_text, ) - dict_response = response.json() if not is_valid_sync_response(response): @@ -423,18 +444,15 @@ def _make_request( str(product_class.endpoint_name), clean_response, ) - return PredictResponse(product_class, dict_response) def _predict_async( self, product_class: Type[Inference], input_source: Union[LocalInputSource, UrlInputSource], + options: AsyncPredictOptions, endpoint: Optional[Endpoint] = None, - include_words: bool = False, close_file: bool = True, - cropper: bool = False, - full_text: bool = False, ) -> AsyncPredictResponse: """Sends a document to the queue, and sends back an asynchronous predict response.""" if input_source is None: @@ -442,9 +460,14 @@ def _predict_async( if not endpoint: endpoint = self._initialize_ots_endpoint(product_class) response = endpoint.predict_async_req_post( - input_source, include_words, close_file, cropper, full_text + input_source=input_source, + include_words=options.include_words, + close_file=close_file, + cropper=options.cropper, + full_text=options.full_text, + workflow_id=options.workflow_id, + rag=options.rag, ) - dict_response = response.json() if not is_valid_async_response(response): diff --git a/mindee/input/predict_options.py b/mindee/input/predict_options.py new file mode 100644 index 00000000..b3aa17b2 --- /dev/null +++ b/mindee/input/predict_options.py @@ -0,0 +1,31 @@ +from typing import Optional + + +class PredictOptions: + """Options to pass to a prediction.""" + + def __init__( + self, + cropper: bool = False, + full_text: bool = False, + include_words: bool = False, + ): + self.cropper = cropper + self.full_text = full_text + self.include_words = include_words + + +class AsyncPredictOptions(PredictOptions): + """Options to pass to an asynchronous prediction.""" + + def __init__( + self, + cropper: bool = False, + full_text: bool = False, + include_words: bool = False, + workflow_id: Optional[str] = None, + rag: bool = False, + ): + super().__init__(cropper, full_text, include_words) + self.workflow_id = workflow_id + self.rag = rag diff --git a/mindee/mindee_http/endpoint.py b/mindee/mindee_http/endpoint.py index 227c1e2f..0275328d 100644 --- a/mindee/mindee_http/endpoint.py +++ b/mindee/mindee_http/endpoint.py @@ -1,5 +1,5 @@ import json -from typing import Union +from typing import Optional, Union import requests from requests import Response @@ -60,6 +60,8 @@ def predict_async_req_post( close_file: bool = True, cropper: bool = False, full_text: bool = False, + workflow_id: Optional[str] = None, + rag: bool = False, ) -> requests.Response: """ Make an asynchronous request to POST a document for prediction. @@ -69,10 +71,19 @@ def predict_async_req_post( :param close_file: Whether to `close()` the file after parsing it. :param cropper: Including Mindee cropping results. :param full_text: Whether to include the full OCR text response in compatible APIs. + :param workflow_id: Workflow ID. + :param rag: If set, will enable Retrieval-Augmented Generation. :return: requests response """ return self._custom_request( - "predict_async", input_source, include_words, close_file, cropper, full_text + "predict_async", + input_source, + include_words, + close_file, + cropper, + full_text, + workflow_id, + rag, ) def _custom_request( @@ -83,6 +94,8 @@ def _custom_request( close_file: bool = True, cropper: bool = False, full_text: bool = False, + workflow_id: Optional[str] = None, + rag: bool = False, ): data = {} if include_words: @@ -93,11 +106,18 @@ def _custom_request( params["full_text_ocr"] = "true" if cropper: params["cropper"] = "true" + if rag: + params["rag"] = "true" + + if workflow_id: + url = f"{self.settings.base_url}/workflows/{workflow_id}/{route}" + else: + url = f"{self.settings.url_root}/{route}" if isinstance(input_source, UrlInputSource): data["document"] = input_source.url response = requests.post( - f"{self.settings.url_root}/{route}", + url=url, headers=self.settings.base_headers, data=data, params=params, @@ -106,7 +126,7 @@ def _custom_request( else: files = {"document": input_source.read_contents(close_file)} response = requests.post( - f"{self.settings.url_root}/{route}", + url=url, files=files, headers=self.settings.base_headers, data=data, diff --git a/mindee/parsing/common/__init__.py b/mindee/parsing/common/__init__.py index 4707e85d..8b43a6ff 100644 --- a/mindee/parsing/common/__init__.py +++ b/mindee/parsing/common/__init__.py @@ -15,7 +15,6 @@ from mindee.parsing.common.page import Page from mindee.parsing.common.predict_response import PredictResponse from mindee.parsing.common.prediction import Prediction -from mindee.parsing.common.string_dict import StringDict from mindee.parsing.common.summary_helper import ( clean_out_string, format_for_display, diff --git a/mindee/parsing/common/async_predict_response.py b/mindee/parsing/common/async_predict_response.py index 6897bd6f..e3101633 100644 --- a/mindee/parsing/common/async_predict_response.py +++ b/mindee/parsing/common/async_predict_response.py @@ -15,7 +15,7 @@ class AsyncPredictResponse(Generic[TypeInference], ApiResponse): job: Job """Job object link to the prediction. As long as it isn't complete, the prediction doesn't exist.""" - document: Optional[Document] + document: Optional[Document] = None def __init__( self, inference_type: Type[TypeInference], raw_response: StringDict diff --git a/mindee/parsing/common/document.py b/mindee/parsing/common/document.py index cb0af6a3..cf5f5e9c 100644 --- a/mindee/parsing/common/document.py +++ b/mindee/parsing/common/document.py @@ -29,9 +29,9 @@ class Document(Generic[TypePrediction, TypePage]): """Result of the base inference""" id: str """Id of the document as sent back by the server""" - extras: Optional[Extras] + extras: Optional[Extras] = None """Potential Extras fields sent back along the prediction""" - ocr: Optional[Ocr] + ocr: Optional[Ocr] = None """Potential raw text results read by the OCR (limited feature)""" n_pages: int """Amount of pages in the document""" diff --git a/mindee/parsing/common/execution.py b/mindee/parsing/common/execution.py index f218656e..3d0a725a 100644 --- a/mindee/parsing/common/execution.py +++ b/mindee/parsing/common/execution.py @@ -16,7 +16,7 @@ class Execution(Generic[TypePrediction]): batch_name: str """Identifier for the batch to which the execution belongs.""" - created_at: Optional[datetime] + created_at: Optional[datetime] = None """The time at which the execution started.""" file: ExecutionFile @@ -28,7 +28,7 @@ class Execution(Generic[TypePrediction]): inference: Optional[Inference[TypePrediction, Page[TypePrediction]]] """Deserialized inference object.""" - priority: Optional["ExecutionPriority"] + priority: Optional["ExecutionPriority"] = None """Priority of the execution.""" reviewed_at: Optional[datetime] @@ -37,7 +37,7 @@ class Execution(Generic[TypePrediction]): available_at: Optional[datetime] """The time at which the file was uploaded to a workflow.""" - reviewed_prediction: Optional["GeneratedV1Document"] + reviewed_prediction: Optional["GeneratedV1Document"] = None """Reviewed fields and values.""" status: str @@ -46,7 +46,7 @@ class Execution(Generic[TypePrediction]): type: Optional[str] """Execution type.""" - uploaded_at: Optional[datetime] + uploaded_at: Optional[datetime] = None """The time at which the file was uploaded to a workflow.""" workflow_id: str diff --git a/mindee/parsing/common/extras/extras.py b/mindee/parsing/common/extras/extras.py index 64ce8463..8d1f8f94 100644 --- a/mindee/parsing/common/extras/extras.py +++ b/mindee/parsing/common/extras/extras.py @@ -2,6 +2,7 @@ from mindee.parsing.common.extras.cropper_extra import CropperExtra from mindee.parsing.common.extras.full_text_ocr_extra import FullTextOcrExtra +from mindee.parsing.common.extras.rag_extra import RagExtra from mindee.parsing.common.string_dict import StringDict @@ -12,16 +13,19 @@ class Extras: Is roughly equivalent to a dict of Extras, with a bit more utility. """ - cropper: Optional[CropperExtra] - full_text_ocr: Optional[FullTextOcrExtra] + cropper: Optional[CropperExtra] = None + full_text_ocr: Optional[FullTextOcrExtra] = None + rag: Optional[RagExtra] = None def __init__(self, raw_prediction: StringDict) -> None: if "cropper" in raw_prediction and raw_prediction["cropper"]: self.cropper = CropperExtra(raw_prediction["cropper"]) if "full_text_ocr" in raw_prediction and raw_prediction["full_text_ocr"]: self.full_text_ocr = FullTextOcrExtra(raw_prediction["full_text_ocr"]) + if "rag" in raw_prediction and raw_prediction["rag"]: + self.rag = RagExtra(raw_prediction["rag"]) for key, extra in raw_prediction.items(): - if key not in ["cropper", "full_text_ocr"]: + if key not in ["cropper", "full_text_ocr", "rag"]: setattr(self, key, extra) def __str__(self) -> str: diff --git a/mindee/parsing/common/extras/rag_extra.py b/mindee/parsing/common/extras/rag_extra.py new file mode 100644 index 00000000..088675dd --- /dev/null +++ b/mindee/parsing/common/extras/rag_extra.py @@ -0,0 +1,13 @@ +from typing import Optional + +from mindee.parsing.common.string_dict import StringDict + + +class RagExtra: + """Contains information on the Retrieval-Augmented-Generation of a prediction.""" + + matching_document_id: Optional[str] = None + + def __init__(self, raw_prediction: StringDict) -> None: + if raw_prediction and "matching_document_id" in raw_prediction: + self.matching_document_id = raw_prediction["matching_document_id"] diff --git a/mindee/parsing/common/inference.py b/mindee/parsing/common/inference.py index f4c91b6f..5a61d353 100644 --- a/mindee/parsing/common/inference.py +++ b/mindee/parsing/common/inference.py @@ -1,6 +1,7 @@ from typing import Dict, Generic, List, Optional, Type, TypeVar from mindee.error.mindee_error import MindeeError +from mindee.parsing.common.extras import Extras from mindee.parsing.common.page import TypePage from mindee.parsing.common.prediction import TypePrediction from mindee.parsing.common.product import Product @@ -24,6 +25,8 @@ class Inference(Generic[TypePrediction, TypePage]): """Whether the document has had any rotation applied to it.""" page_id: Optional[int] """Optional page id for page-level predictions.""" + extras: Optional[Extras] = None + """Potential Extras fields sent back along with the prediction.""" def __init__(self, raw_prediction: StringDict, page_id: Optional[int] = None): self.is_rotation_applied = None @@ -33,6 +36,9 @@ def __init__(self, raw_prediction: StringDict, page_id: Optional[int] = None): if page_id: self.page_id = page_id + if "extras" in raw_prediction and raw_prediction["extras"]: + self.extras = Extras(raw_prediction["extras"]) + def __str__(self) -> str: rotation_applied_str = "Yes" if self.is_rotation_applied else "No" prediction_str = "" @@ -57,7 +63,7 @@ def __str__(self) -> str: @staticmethod def get_endpoint_info(klass: Type["Inference"]) -> Dict[str, str]: """ - Retrives the endpoint information for an Inference. + Retrieves the endpoint information for an Inference. Should never retrieve info for CustomV1, as a custom endpoint should be created to use CustomV1. diff --git a/mindee/parsing/common/job.py b/mindee/parsing/common/job.py index 5aceac9c..486015ee 100644 --- a/mindee/parsing/common/job.py +++ b/mindee/parsing/common/job.py @@ -14,11 +14,11 @@ class Job: id: str """ID of the job sent by the API in response to an enqueue request.""" - error: Optional[StringDict] + error: Optional[StringDict] = None """Information about an error that occurred during the job processing.""" issued_at: datetime """Timestamp of the request reception by the API.""" - available_at: Optional[datetime] + available_at: Optional[datetime] = None """Timestamp of the request after it has been completed.""" status: str """Status of the request, as seen by the API.""" diff --git a/mindee/parsing/common/page.py b/mindee/parsing/common/page.py index 8266c862..d9be8cab 100644 --- a/mindee/parsing/common/page.py +++ b/mindee/parsing/common/page.py @@ -11,11 +11,11 @@ class Page(Generic[TypePrediction]): id: int """Id of the current page.""" - orientation: Optional[OrientationField] + orientation: Optional[OrientationField] = None """Orientation of the page""" prediction: TypePrediction """Type of Page prediction.""" - extras: Optional[Extras] + extras: Optional[Extras] = None def __init__( self, diff --git a/tests/workflows/test_workflow_integration.py b/tests/workflows/test_workflow_integration.py index 9acbec39..4c4a2cdb 100644 --- a/tests/workflows/test_workflow_integration.py +++ b/tests/workflows/test_workflow_integration.py @@ -6,6 +6,7 @@ from mindee import Client from mindee.input import WorkflowOptions from mindee.parsing.common.execution_priority import ExecutionPriority +from mindee.product import FinancialDocumentV1, GeneratedV1 from tests.product import PRODUCT_DATA_DIR @@ -25,7 +26,7 @@ def input_path(): @pytest.mark.integration -def test_workflow(mindee_client: Client, workflow_id: str, input_path: str): +def test_workflow_execution(mindee_client: Client, workflow_id: str, input_path: str): input_source = mindee_client.source_from_path(str(input_path)) current_date_time = datetime.now().strftime("%Y-%m-%d-%H:%M:%S") alias = f"python-{current_date_time}" @@ -37,3 +38,73 @@ def test_workflow(mindee_client: Client, workflow_id: str, input_path: str): assert response.api_request.status_code == 202 assert response.execution.file.alias == f"python-{current_date_time}" assert response.execution.priority == "low" + + +@pytest.mark.integration +def test_workflow_predict_ots_rag( + mindee_client: Client, workflow_id: str, input_path: str +): + input_source = mindee_client.source_from_path(str(input_path)) + + response = mindee_client.enqueue_and_parse( + FinancialDocumentV1, + input_source, + workflow_id=workflow_id, + rag=True, + ) + assert len(response.document.inference.extras.rag.matching_document_id) > 5 + + +@pytest.mark.integration +def test_workflow_predict_ots_no_rag( + mindee_client: Client, workflow_id: str, input_path: str +): + input_source = mindee_client.source_from_path(str(input_path)) + + response = mindee_client.enqueue_and_parse( + FinancialDocumentV1, + input_source, + workflow_id=workflow_id, + ) + assert response.document.inference.extras is None + + +@pytest.mark.integration +def test_workflow_predict_custom_rag( + mindee_client: Client, workflow_id: str, input_path: str +): + my_endpoint = mindee_client.create_endpoint( + account_name="mindee", + endpoint_name="financial_document", + ) + + input_source = mindee_client.source_from_path(str(input_path)) + + response = mindee_client.enqueue_and_parse( + GeneratedV1, + input_source, + endpoint=my_endpoint, + workflow_id=workflow_id, + rag=True, + ) + assert len(response.document.inference.extras.rag.matching_document_id) > 5 + + +@pytest.mark.integration +def test_workflow_predict_custom_no_rag( + mindee_client: Client, workflow_id: str, input_path: str +): + my_endpoint = mindee_client.create_endpoint( + account_name="mindee", + endpoint_name="financial_document", + ) + + input_source = mindee_client.source_from_path(str(input_path)) + + response = mindee_client.enqueue_and_parse( + GeneratedV1, + input_source, + endpoint=my_endpoint, + workflow_id=workflow_id, + ) + assert response.document.inference.extras is None