Skip to content

Commit 7eea2fc

Browse files
authored
♻️ rework inference parameters (#338)
1 parent d16e96e commit 7eea2fc

File tree

11 files changed

+90
-60
lines changed

11 files changed

+90
-60
lines changed

docs/extras/code_samples/default_v2.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from mindee import ClientV2, InferencePredictOptions
1+
from mindee import ClientV2, InferenceParameters
22

33
input_path = "/path/to/the/file.ext"
44
api_key = "MY_API_KEY"
@@ -7,20 +7,20 @@ model_id = "MY_MODEL_ID"
77
# Init a new client
88
mindee_client = ClientV2(api_key)
99

10-
# Set inference options
11-
options = InferencePredictOptions(
10+
# Set inference paramters
11+
params = InferenceParameters(
1212
# ID of the model, required.
1313
model_id=model_id,
1414
# If set to `True`, will enable Retrieval-Augmented Generation.
1515
rag=False,
1616
)
1717

1818
# Load a file from disk
19-
input_doc = mindee_client.source_from_path(input_path)
19+
input_source = mindee_client.source_from_path(input_path)
2020

2121
# Upload the file
2222
response = mindee_client.enqueue_and_parse(
23-
input_doc, options
23+
input_source, params
2424
)
2525

2626
# Print a brief summary of the parsed data

mindee/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
11
from mindee import product
22
from mindee.client import Client
33
from mindee.client_v2 import ClientV2
4-
from mindee.input.inference_predict_options import InferencePredictOptions
4+
from mindee.input.inference_parameters import InferenceParameters
55
from mindee.input.local_response import LocalResponse
66
from mindee.input.page_options import PageOptions
7+
from mindee.input.polling_options import PollingOptions
78
from mindee.parsing.common.api_response import ApiResponse
89
from mindee.parsing.common.async_predict_response import AsyncPredictResponse
910
from mindee.parsing.common.feedback_response import FeedbackResponse
1011
from mindee.parsing.common.job import Job
1112
from mindee.parsing.common.predict_response import PredictResponse
1213
from mindee.parsing.common.workflow_response import WorkflowResponse
14+
from mindee.parsing.v2.inference_response import InferenceResponse
15+
from mindee.parsing.v2.job_response import JobResponse
16+
17+
__all__ = [
18+
"Client",
19+
"ClientV2",
20+
"InferenceParameters",
21+
"LocalResponse",
22+
"PageOptions",
23+
"PollingOptions",
24+
"ApiResponse",
25+
"AsyncPredictResponse",
26+
"FeedbackResponse",
27+
"PredictResponse",
28+
"WorkflowResponse",
29+
"JobResponse",
30+
"InferenceResponse",
31+
"product",
32+
]

mindee/client_v2.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mindee.client_mixin import ClientMixin
55
from mindee.error.mindee_error import MindeeError
66
from mindee.error.mindee_http_error_v2 import handle_error_v2
7-
from mindee.input.inference_predict_options import InferencePredictOptions
7+
from mindee.input.inference_parameters import InferenceParameters
88
from mindee.input.local_response import LocalResponse
99
from mindee.input.polling_options import PollingOptions
1010
from mindee.input.sources.local_input_source import LocalInputSource
@@ -38,28 +38,21 @@ def __init__(self, api_key: Optional[str] = None) -> None:
3838
self.mindee_api = MindeeApiV2(api_key)
3939

4040
def enqueue(
41-
self, input_source: LocalInputSource, options: InferencePredictOptions
41+
self, input_source: LocalInputSource, params: InferenceParameters
4242
) -> JobResponse:
4343
"""
4444
Enqueues a document to a given model.
4545
4646
:param input_source: The document/source file to use.
4747
Has to be created beforehand.
4848
49-
:param options: Options for the prediction.
49+
:param params: Parameters to set when sending a file.
5050
:return: A valid inference response.
5151
"""
52-
logger.debug("Enqueuing document to '%s'", options.model_id)
53-
54-
if options.page_options and input_source.is_pdf():
55-
input_source.process_pdf(
56-
options.page_options.operation,
57-
options.page_options.on_min_pages,
58-
options.page_options.page_indexes,
59-
)
52+
logger.debug("Enqueuing document to '%s'", params.model_id)
6053

6154
response = self.mindee_api.predict_async_req_post(
62-
input_source=input_source, options=options
55+
input_source=input_source, options=params
6356
)
6457
dict_response = response.json()
6558

@@ -89,35 +82,35 @@ def parse_queued(
8982
return InferenceResponse(dict_response)
9083

9184
def enqueue_and_parse(
92-
self, input_source: LocalInputSource, options: InferencePredictOptions
85+
self, input_source: LocalInputSource, params: InferenceParameters
9386
) -> InferenceResponse:
9487
"""
9588
Enqueues to an asynchronous endpoint and automatically polls for a response.
9689
9790
:param input_source: The document/source file to use.
9891
Has to be created beforehand.
9992
100-
:param options: Options for the prediction.
93+
:param params: Parameters to set when sending a file.
10194
10295
:return: A valid inference response.
10396
"""
104-
if not options.polling_options:
105-
options.polling_options = PollingOptions()
97+
if not params.polling_options:
98+
params.polling_options = PollingOptions()
10699
self._validate_async_params(
107-
options.polling_options.initial_delay_sec,
108-
options.polling_options.delay_sec,
109-
options.polling_options.max_retries,
100+
params.polling_options.initial_delay_sec,
101+
params.polling_options.delay_sec,
102+
params.polling_options.max_retries,
110103
)
111-
queue_result = self.enqueue(input_source, options)
104+
queue_result = self.enqueue(input_source, params)
112105
logger.debug(
113106
"Successfully enqueued document with job id: %s", queue_result.job.id
114107
)
115-
sleep(options.polling_options.initial_delay_sec)
108+
sleep(params.polling_options.initial_delay_sec)
116109
retry_counter = 1
117110
poll_results = self.parse_queued(
118111
queue_result.job.id,
119112
)
120-
while retry_counter < options.polling_options.max_retries:
113+
while retry_counter < params.polling_options.max_retries:
121114
if not isinstance(poll_results, JobResponse):
122115
break
123116
if poll_results.job.status == "Failed":
@@ -133,7 +126,7 @@ def enqueue_and_parse(
133126
queue_result.job.id,
134127
)
135128
retry_counter += 1
136-
sleep(options.polling_options.delay_sec)
129+
sleep(params.polling_options.delay_sec)
137130
poll_results = self.parse_queued(queue_result.job.id)
138131

139132
if not isinstance(poll_results, InferenceResponse):

mindee/input/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from mindee.input.inference_predict_options import InferencePredictOptions
21
from mindee.input.local_response import LocalResponse
32
from mindee.input.page_options import PageOptions
43
from mindee.input.polling_options import PollingOptions
Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from dataclasses import dataclass
22
from typing import List, Optional
33

4-
from mindee.input.page_options import PageOptions
54
from mindee.input.polling_options import PollingOptions
65

76

87
@dataclass
9-
class InferencePredictOptions:
10-
"""Inference prediction options."""
8+
class InferenceParameters:
9+
"""Inference parameters to set when sending a file."""
1110

1211
model_id: str
1312
"""ID of the model, required."""
@@ -17,9 +16,7 @@ class InferencePredictOptions:
1716
"""Optional alias for the file."""
1817
webhook_ids: Optional[List[str]] = None
1918
"""IDs of webhooks to propagate the API response to."""
20-
page_options: Optional[PageOptions] = None
21-
"""Options for page-level inference."""
2219
polling_options: Optional[PollingOptions] = None
23-
"""Options for polling."""
20+
"""Options for polling. Set only if having timeout issues."""
2421
close_file: bool = True
2522
"""Whether to close the file after parsing."""

mindee/input/polling_options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ class PollingOptions:
44
initial_delay_sec: float
55
"""Initial delay before the first polling attempt."""
66
delay_sec: float
7-
"""Delay between each polling attempts."""
7+
"""Delay between each polling attempt."""
88
max_retries: int
9-
"""Total amount of polling attempts."""
9+
"""Total number of polling attempts."""
1010

1111
def __init__(
1212
self,

mindee/input/sources/local_input_source.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mindee.error.mimetype_error import MimeTypeError
99
from mindee.error.mindee_error import MindeeError, MindeeSourceError
1010
from mindee.image_operations.image_compressor import compress_image
11-
from mindee.input.page_options import KEEP_ONLY, REMOVE
11+
from mindee.input.page_options import KEEP_ONLY, REMOVE, PageOptions
1212
from mindee.input.sources.input_type import InputType
1313
from mindee.logger import logger
1414
from mindee.pdf.pdf_compressor import compress_pdf
@@ -112,6 +112,16 @@ def count_doc_pages(self) -> int:
112112
return len(pdf)
113113
return 1
114114

115+
def apply_page_options(self, page_options: PageOptions) -> None:
116+
"""Apply cut and merge options on multipage documents."""
117+
if not self.is_pdf():
118+
raise MindeeSourceError(f"File is not a PDF: {self.filename}")
119+
self.process_pdf(
120+
page_options.operation,
121+
page_options.on_min_pages,
122+
page_options.page_indexes,
123+
)
124+
115125
def process_pdf(
116126
self,
117127
behavior: str,

mindee/mindee_http/mindee_api_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mindee.error.mindee_error import MindeeApiV2Error
77
from mindee.input import LocalInputSource
8-
from mindee.input.inference_predict_options import InferencePredictOptions
8+
from mindee.input.inference_parameters import InferenceParameters
99
from mindee.logger import logger
1010
from mindee.mindee_http.base_settings import USER_AGENT
1111
from mindee.mindee_http.settings_mixin import SettingsMixin
@@ -68,7 +68,7 @@ def set_from_env(self) -> None:
6868
logger.debug("Value was set from env: %s", name)
6969

7070
def predict_async_req_post(
71-
self, input_source: LocalInputSource, options: InferencePredictOptions
71+
self, input_source: LocalInputSource, options: InferenceParameters
7272
) -> requests.Response:
7373
"""
7474
Make an asynchronous request to POST a document for prediction on the V2 API.

tests/test_client_v2.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from mindee import ClientV2, InferencePredictOptions, LocalResponse
5+
from mindee import ClientV2, InferenceParameters, LocalResponse
66
from mindee.error.mindee_error import MindeeApiV2Error
77
from mindee.error.mindee_http_error_v2 import MindeeHTTPErrorV2
88
from mindee.input import LocalInputSource, PathInput
@@ -96,9 +96,7 @@ def test_enqueue_path_with_env_token(custom_base_url_client):
9696
f"{FILE_TYPES_DIR}/receipt.jpg"
9797
)
9898
with pytest.raises(MindeeHTTPErrorV2):
99-
custom_base_url_client.enqueue(
100-
input_doc, InferencePredictOptions("dummy-model")
101-
)
99+
custom_base_url_client.enqueue(input_doc, InferenceParameters("dummy-model"))
102100

103101

104102
@pytest.mark.v2
@@ -108,7 +106,7 @@ def test_enqueue_and_parse_path_with_env_token(custom_base_url_client):
108106
)
109107
with pytest.raises(MindeeHTTPErrorV2):
110108
custom_base_url_client.enqueue_and_parse(
111-
input_doc, InferencePredictOptions("dummy-model")
109+
input_doc, InferenceParameters("dummy-model")
112110
)
113111

114112

@@ -128,7 +126,7 @@ def test_error_handling(custom_base_url_client):
128126
PathInput(
129127
V2_DATA_DIR / "products" / "financial_document" / "default_sample.jpg"
130128
),
131-
InferencePredictOptions("dummy-model"),
129+
InferenceParameters("dummy-model"),
132130
)
133131
assert e.status_code == -1
134132
assert e.detail == "forced failure from test"

tests/test_client_v2_integration.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from mindee import ClientV2, InferencePredictOptions
8+
from mindee import ClientV2, InferenceParameters
99
from mindee.error.mindee_http_error_v2 import MindeeHTTPErrorV2
1010
from mindee.parsing.v2.inference_response import InferenceResponse
1111
from tests.test_inputs import FILE_TYPES_DIR, PRODUCT_DATA_DIR
@@ -40,7 +40,7 @@ def test_parse_file_empty_multiple_pages_must_succeed(
4040
assert input_path.exists(), f"sample file missing: {input_path}"
4141

4242
input_doc = v2_client.source_from_path(input_path)
43-
options = InferencePredictOptions(findoc_model_id)
43+
options = InferenceParameters(findoc_model_id)
4444

4545
response: InferenceResponse = v2_client.enqueue_and_parse(input_doc, options)
4646

@@ -66,7 +66,7 @@ def test_parse_file_filled_single_page_must_succeed(
6666
assert input_path.exists(), f"sample file missing: {input_path}"
6767

6868
input_doc = v2_client.source_from_path(input_path)
69-
options = InferencePredictOptions(findoc_model_id)
69+
options = InferenceParameters(findoc_model_id)
7070

7171
response: InferenceResponse = v2_client.enqueue_and_parse(input_doc, options)
7272

@@ -95,7 +95,7 @@ def test_invalid_uuid_must_throw_error_422(v2_client: ClientV2) -> None:
9595
assert input_path.exists()
9696

9797
input_doc = v2_client.source_from_path(input_path)
98-
options = InferencePredictOptions("INVALID MODEL ID")
98+
options = InferenceParameters("INVALID MODEL ID")
9999

100100
with pytest.raises(MindeeHTTPErrorV2) as exc_info:
101101
v2_client.enqueue(input_doc, options)

0 commit comments

Comments
 (0)