Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 44c6083

Browse files
My Namemahithsuresh
authored andcommitted
Add support for JSONLines as input
cr https://code.amazon.com/reviews/CR-31529785
1 parent 2ca74cc commit 44c6083

File tree

6 files changed

+488
-45
lines changed

6 files changed

+488
-45
lines changed

src/main/java/com/amazonaws/sagemaker/controller/ServingController.java

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import com.amazonaws.sagemaker.dto.BatchExecutionParameter;
2323
import com.amazonaws.sagemaker.dto.DataSchema;
24-
import com.amazonaws.sagemaker.dto.SageMakerDataListObject;
24+
import com.amazonaws.sagemaker.dto.SageMakerRequestListObject;
2525
import com.amazonaws.sagemaker.dto.SageMakerRequestObject;
2626
import com.amazonaws.sagemaker.helper.DataConversionHelper;
2727
import com.amazonaws.sagemaker.helper.ResponseHelper;
@@ -30,11 +30,13 @@
3030
import com.amazonaws.sagemaker.utils.ScalaUtils;
3131
import com.amazonaws.sagemaker.utils.SystemUtils;
3232
import com.fasterxml.jackson.core.JsonProcessingException;
33+
import com.fasterxml.jackson.databind.JsonMappingException;
3334
import com.fasterxml.jackson.databind.ObjectMapper;
3435
import com.google.common.annotations.VisibleForTesting;
3536
import com.google.common.base.Preconditions;
3637
import com.google.common.collect.Lists;
3738
import java.io.IOException;
39+
import java.util.Arrays;
3840
import java.util.List;
3941
import ml.combust.mleap.runtime.frame.ArrayRow;
4042
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
@@ -44,7 +46,6 @@
4446
import org.slf4j.LoggerFactory;
4547
import org.springframework.beans.factory.annotation.Autowired;
4648
import org.springframework.http.HttpHeaders;
47-
import org.springframework.http.HttpStatus;
4849
import org.springframework.http.MediaType;
4950
import org.springframework.http.ResponseEntity;
5051
import org.springframework.web.bind.annotation.RequestBody;
@@ -104,7 +105,7 @@ public ResponseEntity returnBatchExecutionParameter() throws JsonProcessingExcep
104105
* Implements the invocations POST API for application/json input
105106
*
106107
* @param sro, the request object
107-
* @param accept, accept parameter from request
108+
* @param accept, indicates the content types that the http method is able to understand
108109
* @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input
109110
*/
110111
@RequestMapping(path = "/invocations", method = POST, consumes = MediaType.APPLICATION_JSON_VALUE)
@@ -125,10 +126,10 @@ public ResponseEntity<String> transformRequestJson(@RequestBody final SageMakerR
125126
}
126127

127128
/**
128-
* Implements the invocations POST API for application/json input
129+
* Implements the invocations POST API for text/csv input
129130
*
130131
* @param csvRow, data in row format in CSV
131-
* @param accept, accept parameter from request
132+
* @param accept, indicates the content types that the http method is able to understand
132133
* @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input
133134
*/
134135
@RequestMapping(path = "/invocations", method = POST, consumes = AdditionalMediaType.TEXT_CSV_VALUE)
@@ -154,45 +155,31 @@ public ResponseEntity<String> transformRequestCsv(@RequestBody final byte[] csvR
154155
* Implements the invocations POST API for application/jsonlines input
155156
*
156157
* @param jsonLines, lines of json values
157-
* @param accept, accept parameter from request
158+
* @param accept, indicates the content types that the http method is able to understand
158159
* @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input
159160
*/
160-
@RequestMapping(path = "/invocations", method = POST, consumes = AdditionalMediaType.APPLICATION_JSONLINES_VALUE_MULTIPLE)
161-
public ResponseEntity<String> transformRequestJsonLines(@RequestBody final byte[] jsonLines,
162-
@RequestHeader(value = HttpHeaders.ACCEPT, required = false) final String accept) {
161+
@RequestMapping(path = "/invocations", method = POST, consumes = AdditionalMediaType.APPLICATION_JSONLINES_VALUE)
162+
public ResponseEntity<String> transformRequestJsonLines(
163+
@RequestBody final byte[] jsonLines,
164+
@RequestHeader(value = HttpHeaders.ACCEPT, required = false)
165+
final String accept) {
166+
163167
if (jsonLines == null) {
168+
LOG.error("Input passed to the request is null");
169+
return ResponseEntity.badRequest().build();
170+
171+
} else if (jsonLines.length == 0) {
172+
164173
LOG.error("Input passed to the request is empty");
165174
return ResponseEntity.noContent().build();
166175
}
176+
167177
try {
168178
final String acceptVal = this.retrieveAndVerifyAccept(accept);
169-
final DataSchema schema = this.retrieveAndVerifySchema(null, mapper);
170-
final String jsonStringLine = new String(jsonLines);
171-
172-
// Map list of inputs to DataList object
173-
final SageMakerDataListObject sro = mapper.readValue(jsonStringLine, SageMakerDataListObject.class);
174-
List<List<Object>> inputDatas = sro.getData();
175-
List<ResponseEntity<String>> responseList = Lists.newArrayList();
176-
177-
// Process each input separately and add response to a list
178-
final int inputDatasSize = inputDatas.size();
179-
for (int idx = 0; idx < inputDatasSize; ++idx) {
180-
ResponseEntity<String> response = this.processInputData(inputDatas.get(idx), schema, acceptVal);
181-
responseList.add(response);
182-
}
179+
return this.processInputDataForJsonLines(new String(jsonLines), acceptVal);
183180

184-
// Merge response body to a new output response
185-
List<List<String>> bodyList = Lists.newArrayList();
186-
HttpHeaders headers = null;
187-
//combine body in responseList
188-
for (ResponseEntity<String> response:responseList) {
189-
HttpStatus statuscode = response.getStatusCode();
190-
headers = response.getHeaders();
191-
bodyList.add(Lists.newArrayList(response.getBody()));
192-
}
193-
194-
return ResponseEntity.ok().headers(headers).body(bodyList.toString());
195181
} catch (final Exception ex) {
182+
196183
LOG.error("Error in processing current request", ex);
197184
return ResponseEntity.badRequest().body(ex.getMessage());
198185
}
@@ -231,6 +218,72 @@ private ResponseEntity<String> processInputData(final List<Object> inputData, fi
231218

232219
}
233220

221+
/**
222+
* Helper method to interpret the JSONLines input and return the response in the expected output format.
223+
*
224+
* @param jsonLinesAsString
225+
* The JSON lines input.
226+
*
227+
* @param acceptVal
228+
* The output format in which the response is to be returned.
229+
*
230+
* @return
231+
* The transformed output for the JSONlines input.
232+
*
233+
* @throws IOException
234+
* If there is an exception during object mapping and validation.
235+
*
236+
*/
237+
ResponseEntity<String> processInputDataForJsonLines(
238+
final String jsonLinesAsString, final String acceptVal) throws IOException {
239+
240+
final String lines[] = jsonLinesAsString.split("\\r?\\n");
241+
final ObjectMapper mapper = new ObjectMapper();
242+
243+
// first line is special since it could contain the schema as well. Extract the schema.
244+
final SageMakerRequestObject firstLine = mapper.readValue(lines[0], SageMakerRequestObject.class);
245+
final DataSchema schema = this.retrieveAndVerifySchema(firstLine.getSchema(), mapper);
246+
247+
List<List<Object>> inputDatas = Lists.newArrayList();
248+
249+
for(String jsonStringLine : lines) {
250+
try {
251+
252+
final SageMakerRequestListObject sro = mapper.readValue(jsonStringLine, SageMakerRequestListObject.class);
253+
254+
for(int idx = 0; idx < sro.getData().size(); ++idx) {
255+
inputDatas.add(sro.getData().get(idx));
256+
}
257+
258+
} catch (final JsonMappingException ex) {
259+
260+
final SageMakerRequestObject sro = mapper.readValue(jsonStringLine, SageMakerRequestObject.class);
261+
inputDatas.add(sro.getData());
262+
}
263+
}
264+
265+
List<ResponseEntity<String>> responseList = Lists.newArrayList();
266+
267+
// Process each input separately and add response to a list
268+
for (int idx = 0; idx < inputDatas.size(); ++idx) {
269+
responseList.add(this.processInputData(inputDatas.get(idx), schema, acceptVal));
270+
}
271+
272+
// Merge response body to a new output response
273+
List<List<String>> bodyList = Lists.newArrayList();
274+
275+
// All response should be valid if no exception got catch
276+
// which all headers should be the same and extract the first one to construct responseEntity
277+
HttpHeaders headers = responseList.get(0).getHeaders();
278+
279+
//combine body in responseList
280+
for (ResponseEntity<String> response: responseList) {
281+
bodyList.add(Lists.newArrayList(response.getBody()));
282+
}
283+
284+
return ResponseEntity.ok().headers(headers).body(bodyList.toString());
285+
}
286+
234287
private boolean checkEmptyAccept(final String acceptFromRequest) {
235288
//Spring may send the Accept as "*\/*" (star/star) in case accept is not passed via request
236289
return (StringUtils.isBlank(acceptFromRequest) || StringUtils.equals(acceptFromRequest, MediaType.ALL_VALUE));

src/main/java/com/amazonaws/sagemaker/dto/SageMakerDataListObject.java renamed to src/main/java/com/amazonaws/sagemaker/dto/SageMakerRequestListObject.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,23 @@
2626
* Request object POJO to which data field of input request in JSONLINES format will be mapped to by Spring (using Jackson).
2727
* For sample input, please see test/resources/com/amazonaws/sagemaker/dto
2828
*/
29-
public class SageMakerDataListObject {
29+
public class SageMakerRequestListObject {
3030

31+
private DataSchema schema;
3132
private List<List<Object>> data;
3233

3334
@JsonCreator
34-
public SageMakerDataListObject(
35-
@JsonProperty("data") final List<List<Object>> data) {
35+
public SageMakerRequestListObject(@JsonProperty("schema") final DataSchema schema,
36+
@JsonProperty("data") final List<List<Object>> data) {
37+
// schema can be retrieved from environment variable as well, hence it is not enforced to be null
38+
this.schema = schema;
3639
this.data = Preconditions.checkNotNull(data);
3740
}
3841

42+
public DataSchema getSchema() {
43+
return schema;
44+
}
45+
3946
public List<List<Object>> getData() {
4047
return data;
4148
}

0 commit comments

Comments
 (0)