Skip to content

Commit 8eb20dd

Browse files
authored
✨ add support for workflow polling (#238)
1 parent dda75f5 commit 8eb20dd

File tree

8 files changed

+199
-27
lines changed

8 files changed

+199
-27
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ import com.mindee.product.us.bankcheck.BankCheckV1;
7474
### Custom Documents (docTI & Custom APIs)
7575
```java
7676
import com.mindee.MindeeClient;
77+
import com.mindee.PredictOptions;
7778
import com.mindee.input.LocalInputSource;
7879
import com.mindee.parsing.common.PredictResponse;
7980
import com.mindee.product.generated.GeneratedV1;
@@ -98,6 +99,7 @@ public class SimpleMindeeClient {
9899
Document<GeneratedV1> customDocument = mindeeClient.enqueueAndParse(
99100
localInputSource,
100101
endpoint
102+
// PredictOptions.builder().build(),
101103
);
102104
}
103105
}
@@ -116,6 +118,7 @@ This is the easiest way to get started.
116118

117119
```java
118120
import com.mindee.MindeeClient;
121+
import com.mindee.PredictOptions;
119122
import com.mindee.input.LocalInputSource;
120123
import com.mindee.parsing.common.AsyncPredictResponse;
121124
import com.mindee.product.internationalid.InternationalIdV2;
@@ -138,6 +141,7 @@ public class SimpleMindeeClient {
138141
AsyncPredictResponse<InternationalIdV2> response = mindeeClient.enqueueAndParse(
139142
InternationalIdV2.class,
140143
inputSource
144+
// PredictOptions.builder().build(),
141145
);
142146

143147
// Print a summary of the response

docs/code_samples/workflow_execution.txt

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ public class SimpleMindeeClient {
2424
inputSource
2525
);
2626

27-
2827
// Alternatively: give an alias to the document
2928
// WorkflowResponse response = mindeeClient.executeWorkflow(
3029
// workflowId,

src/main/java/com/mindee/MindeeClient.java

+78
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,30 @@ public <T extends Inference> AsyncPredictResponse<T> enqueue(
146146
);
147147
}
148148

149+
/**
150+
* Send a local file to an async queue.
151+
* @param <T> Type of inference.
152+
* @param type Type of inference.
153+
* @param localInputSource A local input source file.
154+
* @param predictOptions Prediction options for the enqueuing.
155+
* @return an instance of {@link AsyncPredictResponse}.
156+
* @throws IOException Throws if the file can't be accessed.
157+
*/
158+
public <T extends Inference> AsyncPredictResponse<T> enqueue(
159+
Class<T> type,
160+
LocalInputSource localInputSource,
161+
PredictOptions predictOptions
162+
) throws IOException {
163+
return this.enqueue(
164+
type,
165+
new Endpoint(type),
166+
localInputSource.getFile(),
167+
localInputSource.getFilename(),
168+
predictOptions,
169+
null
170+
);
171+
}
172+
149173
/**
150174
* Send a remote file to an async queue.
151175
* @param <T> Type of inference.
@@ -290,6 +314,60 @@ public <T extends Inference> AsyncPredictResponse<T> enqueueAndParse(
290314
);
291315
}
292316

317+
/**
318+
* Send a local file to an async queue, poll, and parse when complete.
319+
* @param <T> Type of inference.
320+
* @param type Type of inference.
321+
* @param localInputSource A local input source file.
322+
* @param predictOptions Prediction options for the enqueuing.
323+
* @param pollingOptions Options for async call parameters.
324+
* @return an instance of {@link AsyncPredictResponse}.
325+
* @throws IOException Throws if the file can't be accessed.
326+
* @throws InterruptedException Throws in the event of a timeout.
327+
*/
328+
public <T extends Inference> AsyncPredictResponse<T> enqueueAndParse(
329+
Class<T> type,
330+
LocalInputSource localInputSource,
331+
PredictOptions predictOptions,
332+
AsyncPollingOptions pollingOptions
333+
) throws IOException, InterruptedException {
334+
return this.enqueueAndParse(
335+
type,
336+
new Endpoint(type),
337+
pollingOptions,
338+
localInputSource.getFile(),
339+
localInputSource.getFilename(),
340+
predictOptions,
341+
null
342+
);
343+
}
344+
345+
/**
346+
* Send a local file to an async queue, poll, and parse when complete.
347+
* @param <T> Type of inference.
348+
* @param type Type of inference.
349+
* @param localInputSource A local input source file.
350+
* @param predictOptions Prediction options for the enqueuing.
351+
* @return an instance of {@link AsyncPredictResponse}.
352+
* @throws IOException Throws if the file can't be accessed.
353+
* @throws InterruptedException Throws in the event of a timeout.
354+
*/
355+
public <T extends Inference> AsyncPredictResponse<T> enqueueAndParse(
356+
Class<T> type,
357+
LocalInputSource localInputSource,
358+
PredictOptions predictOptions
359+
) throws IOException, InterruptedException {
360+
return this.enqueueAndParse(
361+
type,
362+
new Endpoint(type),
363+
null,
364+
localInputSource.getFile(),
365+
localInputSource.getFilename(),
366+
predictOptions,
367+
null
368+
);
369+
}
370+
293371
/**
294372
* Send a remote file to an async queue, poll, and parse when complete.
295373
* @param <T> Type of inference.

src/main/java/com/mindee/PredictOptions.java

+14-1
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,28 @@ public class PredictOptions {
2424
* size.
2525
*/
2626
Boolean fullText;
27+
/**
28+
* If set, will enqueue to a workflow queue instead of a product's endpoint.
29+
*/
30+
String workflowId;
31+
/**
32+
* If set, will enable Retrieval-Augmented Generation.
33+
* Only works if a valid workflowId is set.
34+
*/
35+
Boolean rag;
2736

2837
@Builder
2938
private PredictOptions(
3039
Boolean allWords,
3140
Boolean fullText,
32-
Boolean cropper
41+
Boolean cropper,
42+
String workflowId,
43+
Boolean rag
3344
) {
3445
this.allWords = allWords == null ? Boolean.FALSE : allWords;
3546
this.fullText = fullText == null ? Boolean.FALSE : fullText;
3647
this.cropper = cropper == null ? Boolean.FALSE : cropper;
48+
this.workflowId = workflowId;
49+
this.rag = rag == null ? Boolean.FALSE : rag;
3750
}
3851
}

src/main/java/com/mindee/http/MindeeHttpApi.java

+42-16
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@
4141
public final class MindeeHttpApi extends MindeeApi {
4242

4343
private static final ObjectMapper mapper = new ObjectMapper();
44-
private final Function<Endpoint, String> buildBaseUrl = this::buildProductUrl;
45-
private final Function<String, String> buildWorkflowBaseUrl = this::buildWorkflowUrl;
44+
private final Function<Endpoint, String> buildProductPredicBasetUrl = this::buildProductPredictBaseUrl;
45+
private final Function<String, String> buildWorkflowPredictBaseUrl = this::buildWorkflowPredictBaseUrl;
46+
private final Function<String, String> buildWorkflowExecutionBaseUrl = this::buildWorkflowExecutionUrl;
4647
/**
4748
* The MindeeSetting needed to make the api call.
4849
*/
@@ -53,24 +54,27 @@ public final class MindeeHttpApi extends MindeeApi {
5354
*/
5455
private final HttpClientBuilder httpClientBuilder;
5556
/**
56-
* The function used to generate the API endpoint URL.
57+
* The function used to generate the synchronous API endpoint URL.
5758
* Only needs to be set if the api calls need to be directed through internal URLs.
5859
*/
5960
private final Function<Endpoint, String> urlFromEndpoint;
60-
6161
/**
62-
* The function used to generate the API endpoint URL for workflow execution calls.
62+
* The function used to generate the asynchronous API endpoint URL for a product.
6363
* Only needs to be set if the api calls need to be directed through internal URLs.
6464
*/
6565
private final Function<Endpoint, String> asyncUrlFromEndpoint;
66+
/**
67+
* The function used to generate the asynchronous API endpoint URL for a workflow.
68+
* Only needs to be set if the api calls need to be directed through internal URLs.
69+
*/
70+
private final Function<String, String> asyncUrlFromWorkflow;
6671
/**
6772
* The function used to generate the Job status URL for Async calls.
6873
* Only needs to be set if the api calls need to be directed through internal URLs.
6974
*/
7075
private final Function<Endpoint, String> documentUrlFromEndpoint;
71-
7276
/**
73-
* The function used to generate the Job status URL for Async calls.
77+
* The function used to generate the Job status URL for workflow execution calls.
7478
* Only needs to be set if the api calls need to be directed through internal URLs.
7579
*/
7680
private final Function<String, String> workflowUrlFromId;
@@ -82,6 +86,7 @@ public MindeeHttpApi(MindeeSettings mindeeSettings) {
8286
null,
8387
null,
8488
null,
89+
null,
8590
null
8691
);
8792
}
@@ -93,7 +98,8 @@ private MindeeHttpApi(
9398
Function<Endpoint, String> urlFromEndpoint,
9499
Function<Endpoint, String> asyncUrlFromEndpoint,
95100
Function<Endpoint, String> documentUrlFromEndpoint,
96-
Function<String, String> workflowUrlFromEndpoint
101+
Function<String, String> workflowUrlFromEndpoint,
102+
Function<String, String> asyncUrlFromWorkflow
97103
) {
98104
this.mindeeSettings = mindeeSettings;
99105

@@ -106,26 +112,35 @@ private MindeeHttpApi(
106112
if (urlFromEndpoint != null) {
107113
this.urlFromEndpoint = urlFromEndpoint;
108114
} else {
109-
this.urlFromEndpoint = buildBaseUrl.andThen((url) -> url.concat("/predict"));
115+
this.urlFromEndpoint = buildProductPredicBasetUrl.andThen(
116+
(url) -> url.concat("/predict"));
117+
}
118+
119+
if (asyncUrlFromWorkflow != null) {
120+
this.asyncUrlFromWorkflow = asyncUrlFromWorkflow;
121+
} else {
122+
this.asyncUrlFromWorkflow = this.buildWorkflowPredictBaseUrl.andThen(
123+
(url) -> url.concat("/predict_async"));
110124
}
111125

112126
if (asyncUrlFromEndpoint != null) {
113127
this.asyncUrlFromEndpoint = asyncUrlFromEndpoint;
114128
} else {
115-
this.asyncUrlFromEndpoint = this.urlFromEndpoint.andThen((url) -> url.concat("_async"));
129+
this.asyncUrlFromEndpoint = this.buildProductPredicBasetUrl.andThen(
130+
(url) -> url.concat("/predict_async"));
116131
}
117132

118133
if (documentUrlFromEndpoint != null) {
119134
this.documentUrlFromEndpoint = documentUrlFromEndpoint;
120135
} else {
121-
this.documentUrlFromEndpoint = this.buildBaseUrl.andThen(
136+
this.documentUrlFromEndpoint = this.buildProductPredicBasetUrl.andThen(
122137
(url) -> url.concat("/documents/queue/"));
123138
}
124139

125140
if (workflowUrlFromEndpoint != null) {
126141
this.workflowUrlFromId = workflowUrlFromEndpoint;
127142
} else {
128-
this.workflowUrlFromId = this.buildWorkflowBaseUrl;
143+
this.workflowUrlFromId = this.buildWorkflowExecutionBaseUrl;
129144
}
130145
}
131146

@@ -233,7 +248,12 @@ public <DocT extends Inference> AsyncPredictResponse<DocT> predictAsyncPost(
233248
RequestParameters requestParameters
234249
) throws IOException {
235250

236-
String url = asyncUrlFromEndpoint.apply(endpoint);
251+
String url;
252+
if (requestParameters.getPredictOptions().getWorkflowId() != null) {
253+
url = asyncUrlFromWorkflow.apply(requestParameters.getPredictOptions().getWorkflowId());
254+
} else {
255+
url = asyncUrlFromEndpoint.apply(endpoint);
256+
}
237257
HttpPost post = buildHttpPost(url, requestParameters);
238258

239259
// required to register jackson date module format to deserialize
@@ -340,7 +360,7 @@ private <ResponseT extends ApiResponse> MindeeHttpException getHttpError(
340360
return new MindeeHttpException(statusCode, message, details, errorCode);
341361
}
342362

343-
private String buildProductUrl(Endpoint endpoint) {
363+
private String buildProductPredictBaseUrl(Endpoint endpoint) {
344364
return this.mindeeSettings.getBaseUrl()
345365
+ "/products/"
346366
+ endpoint.getAccountName()
@@ -350,7 +370,11 @@ private String buildProductUrl(Endpoint endpoint) {
350370
+ endpoint.getVersion();
351371
}
352372

353-
private String buildWorkflowUrl(String workflowId) {
373+
private String buildWorkflowPredictBaseUrl(String workflowId) {
374+
return this.mindeeSettings.getBaseUrl() + "/workflows/" + workflowId;
375+
}
376+
377+
private String buildWorkflowExecutionUrl(String workflowId) {
354378
return this.mindeeSettings.getBaseUrl() + "/workflows/" + workflowId + "/executions";
355379
}
356380

@@ -388,7 +412,9 @@ private List<NameValuePair> buildPostParams(
388412
if (Boolean.TRUE.equals(requestParameters.getPredictOptions().getFullText())) {
389413
params.add(new BasicNameValuePair("full_text_ocr", "true"));
390414
}
391-
if (Boolean.TRUE.equals(requestParameters.getWorkflowOptions().getRag())) {
415+
if (Boolean.TRUE.equals(requestParameters.getWorkflowOptions().getRag())
416+
|| Boolean.TRUE.equals(requestParameters.getPredictOptions().getRag())
417+
) {
392418
params.add(new BasicNameValuePair("rag", "true"));
393419
}
394420
return params;

src/main/java/com/mindee/parsing/common/InferenceExtras.java

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.mindee.parsing.common;
22

33
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
45
import lombok.EqualsAndHashCode;
56
import lombok.Getter;
67
import lombok.Setter;
@@ -17,4 +18,9 @@ public class InferenceExtras {
1718
* Full Text OCR result.
1819
*/
1920
private String fullTextOcr;
21+
/**
22+
* Retrieval-Augmented Generation results.
23+
*/
24+
@JsonProperty("rag")
25+
private Rag rag;
2026
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package com.mindee.parsing.common;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.EqualsAndHashCode;
6+
import lombok.Getter;
7+
import lombok.Setter;
8+
9+
/**
10+
* Retrieval-Augmented Generation info class.
11+
*/
12+
@Getter
13+
@EqualsAndHashCode
14+
@JsonIgnoreProperties(ignoreUnknown = true)
15+
public class Rag {
16+
/**
17+
* The document ID that was matched.
18+
*/
19+
@Setter
20+
@JsonProperty("matching_document_id")
21+
private String matchingDocumentId;
22+
}

0 commit comments

Comments
 (0)