Skip to content

Commit 440fe3f

Browse files
fix polling & add tests
1 parent 9a77744 commit 440fe3f

File tree

8 files changed

+52
-19
lines changed

8 files changed

+52
-19
lines changed

mindee/client_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ def parse_queued(
9191
9292
:param queue_id: queue_id received from the API.
9393
"""
94-
logger.debug("Fetching from queue ''%s", queue_id)
94+
logger.debug("Fetching from queue '%s'.", queue_id)
9595

9696
response = self.mindee_api.document_queue_req_get(queue_id)
9797
if not is_valid_get_response(response):
9898
handle_error_v2(response.json())
9999

100100
dict_response = response.json()
101-
if dict_response.get("job"):
101+
if "job" in dict_response:
102102
return PollingResponse(dict_response)
103103
return InferenceResponse(dict_response)
104104

mindee/mindee_http/mindee_api_v2.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,23 @@ def predict_async_req_post(
8282
:return: requests response.
8383
"""
8484
data = {"model_id": options.model_id}
85-
params = {}
8685
url = f"{self.url_root}/inferences/enqueue"
8786

8887
if options.full_text:
89-
params["full_text_ocr"] = "true"
88+
data["full_text_ocr"] = "true"
9089
if options.rag:
91-
params["rag"] = "true"
90+
data["rag"] = "true"
9291
if options.webhook_ids and len(options.webhook_ids) > 0:
93-
params["webhook_ids"] = ",".join(options.webhook_ids)
92+
data["webhook_ids"] = ",".join(options.webhook_ids)
9493
if options.alias and len(options.alias):
9594
data["alias"] = options.alias
9695

97-
files = {
98-
"file": input_source.read_contents(close_file)
99-
+ (input_source.file_mimetype,)
100-
}
96+
files = {"file": input_source.read_contents(close_file)}
10197
response = requests.post(
10298
url=url,
10399
files=files,
104100
headers=self.base_headers,
105101
data=data,
106-
params=params,
107102
timeout=self.request_timeout,
108103
)
109104

@@ -116,7 +111,7 @@ def document_queue_req_get(self, queue_id: str) -> requests.Response:
116111
:param queue_id: queue_id received from the API
117112
"""
118113
return requests.get(
119-
f"{self.url_root}/inferences/{queue_id}",
114+
f"{self.url_root}/jobs/{queue_id}",
120115
headers=self.base_headers,
121116
timeout=self.request_timeout,
122117
)

mindee/mindee_http/response_validation_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ def is_valid_get_response(response: requests.Response) -> bool:
3636
if not is_valid_sync_response(response):
3737
return False
3838
response_json = json.loads(response.content)
39-
if not "inference" in response_json:
39+
if not "inference" in response_json and not "job" in response_json:
4040
return False
4141
return True

mindee/parsing/v2/inference.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from mindee.parsing.common.string_dict import StringDict
24
from mindee.parsing.v2.inference_file import InferenceFile
35
from mindee.parsing.v2.inference_model import InferenceModel
@@ -13,11 +15,14 @@ class Inference:
1315
"""File info for the inference."""
1416
result: InferenceResult
1517
"""Result of the inference."""
18+
id: Optional[str]
19+
"""ID of the inference."""
1620

1721
def __init__(self, raw_response: StringDict):
1822
self.model = InferenceModel(raw_response["model"])
1923
self.file = InferenceFile(raw_response["file"])
2024
self.result = InferenceResult(raw_response["result"])
25+
self.id = raw_response["id"] if "id" in raw_response else None
2126

2227
def __str__(self) -> str:
2328
return (

mindee/parsing/v2/inference_response.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class InferenceResponse(CommonResponse):
77
"""Represent an inference response from Mindee V2 API."""
88

99
inference: Inference
10+
"""Inference result."""
1011

1112
def __init__(self, raw_response: StringDict) -> None:
1213
super().__init__(raw_response)

mindee/parsing/v2/polling_response.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class PollingResponse(CommonResponse):
77
"""Represent an inference response from Mindee V2 API."""
88

99
job: Job
10+
"""Job for the polling."""
1011

1112
def __init__(self, raw_response: StringDict) -> None:
1213
super().__init__(raw_response)

tests/test_client_v2_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_parse_file_empty_multiple_pages_must_succeed(
4242
input_doc = v2_client.source_from_path(input_path)
4343
options = InferencePredictOptions(findoc_model_id)
4444

45-
response: InferenceResponse = v2_client.enqueue_and_parse(input_doc, options, False)
45+
response: InferenceResponse = v2_client.enqueue_and_parse(input_doc, options)
4646

4747
assert response is not None
4848
assert response.inference is not None
@@ -82,7 +82,7 @@ def test_parse_file_filled_single_page_must_succeed(
8282
assert response.inference.result is not None
8383
supplier_name = response.inference.result.fields["supplier_name"]
8484
assert supplier_name is not None
85-
assert supplier_name.simple_field.value == "John Smith"
85+
assert supplier_name.value == "John Smith"
8686

8787

8888
@pytest.mark.integration

tests/v2/test_inference_response.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
import json
2+
13
import pytest
24

5+
from mindee import ClientV2, LocalResponse
36
from mindee.parsing.common.string_dict import StringDict
47
from mindee.parsing.v2 import (
58
Inference,
9+
InferenceFile,
10+
InferenceModel,
611
InferenceResponse,
7-
InferenceResult,
812
ListField,
913
ObjectField,
1014
SimpleField,
1115
)
16+
from tests.test_inputs import V2_DATA_DIR
1217

1318

1419
@pytest.fixture
15-
def inference_json() -> StringDict:
20+
def inference_result_json() -> StringDict:
1621
return {
1722
"inference": {
1823
"model": {"id": "test-model-id"},
@@ -92,8 +97,8 @@ def inference_json() -> StringDict:
9297

9398

9499
@pytest.mark.v2
95-
def test_inference(inference_json):
96-
inference_result = InferenceResponse(inference_json)
100+
def test_inference_response(inference_result_json):
101+
inference_result = InferenceResponse(inference_result_json)
97102
assert isinstance(inference_result.inference, Inference)
98103
assert isinstance(
99104
inference_result.inference.result.fields.field_simple, SimpleField
@@ -158,3 +163,29 @@ def test_inference(inference_json):
158163
.value
159164
== "value_9"
160165
)
166+
167+
168+
@pytest.mark.v2
169+
def test_full_inference_response():
170+
client_v2 = ClientV2("dummy")
171+
load_response = client_v2.load_inference(
172+
LocalResponse(V2_DATA_DIR / "products" / "financial_document" / "complete.json")
173+
)
174+
175+
assert isinstance(load_response.inference, Inference)
176+
assert load_response.inference.id == "12345678-1234-1234-1234-123456789abc"
177+
assert isinstance(load_response.inference.result.fields.date, SimpleField)
178+
assert load_response.inference.result.fields.date.value == "2019-11-02"
179+
assert isinstance(load_response.inference.result.fields.taxes, ListField)
180+
assert isinstance(load_response.inference.result.fields.taxes.items[0], ObjectField)
181+
assert (
182+
load_response.inference.result.fields.taxes.items[0].fields["base"].value
183+
== 31.5
184+
)
185+
186+
assert isinstance(load_response.inference.model, InferenceModel)
187+
assert load_response.inference.model.id == "12345678-1234-1234-1234-123456789abc"
188+
189+
assert isinstance(load_response.inference.file, InferenceFile)
190+
assert load_response.inference.file.name == "complete.jpg"
191+
assert load_response.inference.file.alias == None

0 commit comments

Comments
 (0)