21
21
22
22
import com .amazonaws .sagemaker .dto .BatchExecutionParameter ;
23
23
import com .amazonaws .sagemaker .dto .DataSchema ;
24
- import com .amazonaws .sagemaker .dto .SageMakerDataListObject ;
24
+ import com .amazonaws .sagemaker .dto .SageMakerRequestListObject ;
25
25
import com .amazonaws .sagemaker .dto .SageMakerRequestObject ;
26
26
import com .amazonaws .sagemaker .helper .DataConversionHelper ;
27
27
import com .amazonaws .sagemaker .helper .ResponseHelper ;
30
30
import com .amazonaws .sagemaker .utils .ScalaUtils ;
31
31
import com .amazonaws .sagemaker .utils .SystemUtils ;
32
32
import com .fasterxml .jackson .core .JsonProcessingException ;
33
+ import com .fasterxml .jackson .databind .JsonMappingException ;
33
34
import com .fasterxml .jackson .databind .ObjectMapper ;
34
35
import com .google .common .annotations .VisibleForTesting ;
35
36
import com .google .common .base .Preconditions ;
36
37
import com .google .common .collect .Lists ;
37
38
import java .io .IOException ;
39
+ import java .util .Arrays ;
38
40
import java .util .List ;
39
41
import ml .combust .mleap .runtime .frame .ArrayRow ;
40
42
import ml .combust .mleap .runtime .frame .DefaultLeapFrame ;
44
46
import org .slf4j .LoggerFactory ;
45
47
import org .springframework .beans .factory .annotation .Autowired ;
46
48
import org .springframework .http .HttpHeaders ;
47
- import org .springframework .http .HttpStatus ;
48
49
import org .springframework .http .MediaType ;
49
50
import org .springframework .http .ResponseEntity ;
50
51
import org .springframework .web .bind .annotation .RequestBody ;
@@ -104,7 +105,7 @@ public ResponseEntity returnBatchExecutionParameter() throws JsonProcessingExcep
104
105
* Implements the invocations POST API for application/json input
105
106
*
106
107
* @param sro, the request object
107
- * @param accept, accept parameter from request
108
+ * @param accept, indicates the content types that the http method is able to understand
108
109
* @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input
109
110
*/
110
111
@ RequestMapping (path = "/invocations" , method = POST , consumes = MediaType .APPLICATION_JSON_VALUE )
@@ -125,10 +126,10 @@ public ResponseEntity<String> transformRequestJson(@RequestBody final SageMakerR
125
126
}
126
127
127
128
/**
128
- * Implements the invocations POST API for application/json input
129
+ * Implements the invocations POST API for text/csv input
129
130
*
130
131
* @param csvRow, data in row format in CSV
131
- * @param accept, accept parameter from request
132
+ * @param accept, indicates the content types that the http method is able to understand
132
133
* @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input
133
134
*/
134
135
@ RequestMapping (path = "/invocations" , method = POST , consumes = AdditionalMediaType .TEXT_CSV_VALUE )
@@ -154,45 +155,31 @@ public ResponseEntity<String> transformRequestCsv(@RequestBody final byte[] csvR
154
155
* Implements the invocations POST API for application/jsonlines input
155
156
*
156
157
* @param jsonLines, lines of json values
157
- * @param accept, accept parameter from request
158
+ * @param accept, indicates the content types that the http method is able to understand
158
159
* @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input
159
160
*/
160
- @ RequestMapping (path = "/invocations" , method = POST , consumes = AdditionalMediaType .APPLICATION_JSONLINES_VALUE_MULTIPLE )
161
- public ResponseEntity <String > transformRequestJsonLines (@ RequestBody final byte [] jsonLines ,
162
- @ RequestHeader (value = HttpHeaders .ACCEPT , required = false ) final String accept ) {
161
+ @ RequestMapping (path = "/invocations" , method = POST , consumes = AdditionalMediaType .APPLICATION_JSONLINES_VALUE )
162
+ public ResponseEntity <String > transformRequestJsonLines (
163
+ @ RequestBody final byte [] jsonLines ,
164
+ @ RequestHeader (value = HttpHeaders .ACCEPT , required = false )
165
+ final String accept ) {
166
+
163
167
if (jsonLines == null ) {
168
+ LOG .error ("Input passed to the request is null" );
169
+ return ResponseEntity .badRequest ().build ();
170
+
171
+ } else if (jsonLines .length == 0 ) {
172
+
164
173
LOG .error ("Input passed to the request is empty" );
165
174
return ResponseEntity .noContent ().build ();
166
175
}
176
+
167
177
try {
168
178
final String acceptVal = this .retrieveAndVerifyAccept (accept );
169
- final DataSchema schema = this .retrieveAndVerifySchema (null , mapper );
170
- final String jsonStringLine = new String (jsonLines );
171
-
172
- // Map list of inputs to DataList object
173
- final SageMakerDataListObject sro = mapper .readValue (jsonStringLine , SageMakerDataListObject .class );
174
- List <List <Object >> inputDatas = sro .getData ();
175
- List <ResponseEntity <String >> responseList = Lists .newArrayList ();
176
-
177
- // Process each input separately and add response to a list
178
- final int inputDatasSize = inputDatas .size ();
179
- for (int idx = 0 ; idx < inputDatasSize ; ++idx ) {
180
- ResponseEntity <String > response = this .processInputData (inputDatas .get (idx ), schema , acceptVal );
181
- responseList .add (response );
182
- }
179
+ return this .processInputDataForJsonLines (new String (jsonLines ), acceptVal );
183
180
184
- // Merge response body to a new output response
185
- List <List <String >> bodyList = Lists .newArrayList ();
186
- HttpHeaders headers = null ;
187
- //combine body in responseList
188
- for (ResponseEntity <String > response :responseList ) {
189
- HttpStatus statuscode = response .getStatusCode ();
190
- headers = response .getHeaders ();
191
- bodyList .add (Lists .newArrayList (response .getBody ()));
192
- }
193
-
194
- return ResponseEntity .ok ().headers (headers ).body (bodyList .toString ());
195
181
} catch (final Exception ex ) {
182
+
196
183
LOG .error ("Error in processing current request" , ex );
197
184
return ResponseEntity .badRequest ().body (ex .getMessage ());
198
185
}
@@ -231,6 +218,72 @@ private ResponseEntity<String> processInputData(final List<Object> inputData, fi
231
218
232
219
}
233
220
221
+ /**
222
+ * Helper method to interpret the JSONLines input and return the response in the expected output format.
223
+ *
224
+ * @param jsonLinesAsString
225
+ * The JSON lines input.
226
+ *
227
+ * @param acceptVal
228
+ * The output format in which the response is to be returned.
229
+ *
230
+ * @return
231
+ * The transformed output for the JSONlines input.
232
+ *
233
+ * @throws IOException
234
+ * If there is an exception during object mapping and validation.
235
+ *
236
+ */
237
+ ResponseEntity <String > processInputDataForJsonLines (
238
+ final String jsonLinesAsString , final String acceptVal ) throws IOException {
239
+
240
+ final String lines [] = jsonLinesAsString .split ("\\ r?\\ n" );
241
+ final ObjectMapper mapper = new ObjectMapper ();
242
+
243
+ // first line is special since it could contain the schema as well. Extract the schema.
244
+ final SageMakerRequestObject firstLine = mapper .readValue (lines [0 ], SageMakerRequestObject .class );
245
+ final DataSchema schema = this .retrieveAndVerifySchema (firstLine .getSchema (), mapper );
246
+
247
+ List <List <Object >> inputDatas = Lists .newArrayList ();
248
+
249
+ for (String jsonStringLine : lines ) {
250
+ try {
251
+
252
+ final SageMakerRequestListObject sro = mapper .readValue (jsonStringLine , SageMakerRequestListObject .class );
253
+
254
+ for (int idx = 0 ; idx < sro .getData ().size (); ++idx ) {
255
+ inputDatas .add (sro .getData ().get (idx ));
256
+ }
257
+
258
+ } catch (final JsonMappingException ex ) {
259
+
260
+ final SageMakerRequestObject sro = mapper .readValue (jsonStringLine , SageMakerRequestObject .class );
261
+ inputDatas .add (sro .getData ());
262
+ }
263
+ }
264
+
265
+ List <ResponseEntity <String >> responseList = Lists .newArrayList ();
266
+
267
+ // Process each input separately and add response to a list
268
+ for (int idx = 0 ; idx < inputDatas .size (); ++idx ) {
269
+ responseList .add (this .processInputData (inputDatas .get (idx ), schema , acceptVal ));
270
+ }
271
+
272
+ // Merge response body to a new output response
273
+ List <List <String >> bodyList = Lists .newArrayList ();
274
+
275
+ // All response should be valid if no exception got catch
276
+ // which all headers should be the same and extract the first one to construct responseEntity
277
+ HttpHeaders headers = responseList .get (0 ).getHeaders ();
278
+
279
+ //combine body in responseList
280
+ for (ResponseEntity <String > response : responseList ) {
281
+ bodyList .add (Lists .newArrayList (response .getBody ()));
282
+ }
283
+
284
+ return ResponseEntity .ok ().headers (headers ).body (bodyList .toString ());
285
+ }
286
+
234
287
private boolean checkEmptyAccept (final String acceptFromRequest ) {
235
288
//Spring may send the Accept as "*\/*" (star/star) in case accept is not passed via request
236
289
return (StringUtils .isBlank (acceptFromRequest ) || StringUtils .equals (acceptFromRequest , MediaType .ALL_VALUE ));
0 commit comments