Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 884f066

Browse files
author
mahithsuresh
authored
Merge pull request #11 from himanishk/update-version-fix-vector
Update version fix vector
2 parents 5b5b3e6 + 3e64d69 commit 884f066

File tree

5 files changed

+48
-37
lines changed

5 files changed

+48
-37
lines changed

pom.xml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@
154154
<dependency>
155155
<groupId>ml.combust.mleap</groupId>
156156
<artifactId>mleap-runtime_2.11</artifactId>
157-
<version>0.14.0</version>
157+
<version>0.15.0</version>
158+
</dependency>
159+
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib-local -->
160+
<dependency>
161+
<groupId>org.apache.spark</groupId>
162+
<artifactId>spark-mllib-local_2.11</artifactId>
163+
<version>2.4.5</version>
158164
</dependency>
159165
<dependency>
160166
<groupId>org.apache.commons</groupId>

src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,14 @@
1616

1717
package com.amazonaws.sagemaker.helper;
1818

19-
import com.amazonaws.sagemaker.dto.DataSchema;
2019
import com.amazonaws.sagemaker.dto.ColumnSchema;
20+
import com.amazonaws.sagemaker.dto.DataSchema;
2121
import com.amazonaws.sagemaker.type.BasicDataType;
2222
import com.amazonaws.sagemaker.type.DataStructureType;
2323
import com.google.common.annotations.VisibleForTesting;
2424
import com.google.common.base.Preconditions;
2525
import com.google.common.collect.Lists;
26-
import java.io.IOException;
27-
import java.io.StringReader;
28-
import java.util.List;
29-
import java.util.stream.Collectors;
30-
import ml.combust.mleap.core.types.BasicType;
31-
import ml.combust.mleap.core.types.DataType;
32-
import ml.combust.mleap.core.types.ListType;
33-
import ml.combust.mleap.core.types.ScalarType;
34-
import ml.combust.mleap.core.types.StructField;
35-
import ml.combust.mleap.core.types.StructType;
36-
import ml.combust.mleap.core.types.TensorType;
26+
import ml.combust.mleap.core.types.*;
3727
import ml.combust.mleap.runtime.frame.ArrayRow;
3828
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
3929
import ml.combust.mleap.runtime.frame.Row;
@@ -43,9 +33,15 @@
4333
import org.apache.commons.csv.CSVParser;
4434
import org.apache.commons.csv.CSVRecord;
4535
import org.apache.commons.lang3.StringUtils;
36+
import org.apache.spark.ml.linalg.Vectors;
4637
import org.springframework.beans.factory.annotation.Autowired;
4738
import org.springframework.stereotype.Component;
4839

40+
import java.io.IOException;
41+
import java.io.StringReader;
42+
import java.util.List;
43+
import java.util.stream.Collectors;
44+
4945
/**
5046
* Converter class that converts input data to the types MLeap expects and converts MLeap results back to Java types
5147
* for output.
@@ -168,12 +164,12 @@ protected Object convertInputDataToJavaType(final String type, final String stru
168164
default:
169165
throw new IllegalArgumentException("Given type is not supported");
170166
}
171-
} else {
167+
} else if (!StringUtils.isBlank(structure) && StringUtils.equals(structure, DataStructureType.ARRAY)) {
172168
List<Object> listOfObjects;
173169
try {
174170
listOfObjects = (List<Object>) value;
175171
} catch (ClassCastException cce) {
176-
throw new IllegalArgumentException("Input val is not a list but struct passed is vector or array");
172+
throw new IllegalArgumentException("Input val is not a list but struct passed is array");
177173
}
178174
switch (type) {
179175
case BasicDataType.INTEGER:
@@ -194,7 +190,17 @@ protected Object convertInputDataToJavaType(final String type, final String stru
194190
default:
195191
throw new IllegalArgumentException("Given type is not supported");
196192
}
197-
193+
} else {
194+
if(!type.equals(BasicDataType.DOUBLE))
195+
throw new IllegalArgumentException("Only Double type is supported for vector");
196+
List<Double> vectorValues;
197+
try {
198+
vectorValues = (List<Double>)value;
199+
} catch (ClassCastException cce) {
200+
throw new IllegalArgumentException("Input val is not a list but struct passed is vector");
201+
}
202+
double[] primitiveVectorValues = vectorValues.stream().mapToDouble(d -> d).toArray();
203+
return Vectors.dense(primitiveVectorValues);
198204
}
199205
}
200206

src/test/java/com/amazonaws/sagemaker/dto/SageMakerRequestObjectTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818

1919
import com.fasterxml.jackson.databind.ObjectMapper;
2020
import com.google.common.collect.Lists;
21-
import java.io.IOException;
2221
import org.apache.commons.io.IOUtils;
2322
import org.junit.Assert;
2423
import org.junit.Test;
2524

25+
import java.io.IOException;
26+
2627
public class SageMakerRequestObjectTest {
2728

2829
private ObjectMapper mapper = new ObjectMapper();
@@ -80,14 +81,14 @@ public void testParseCompleteInputJson() throws IOException {
8081
Assert.assertEquals(sro.getSchema().getInput().get(0).getName(), "name_1");
8182
Assert.assertEquals(sro.getSchema().getInput().get(1).getName(), "name_2");
8283
Assert.assertEquals(sro.getSchema().getInput().get(2).getName(), "name_3");
83-
Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "int");
84+
Assert.assertEquals(sro.getSchema().getInput().get(0).getType(), "double");
8485
Assert.assertEquals(sro.getSchema().getInput().get(1).getType(), "string");
8586
Assert.assertEquals(sro.getSchema().getInput().get(2).getType(), "double");
8687
Assert.assertEquals(sro.getSchema().getInput().get(0).getStruct(), "vector");
8788
Assert.assertEquals(sro.getSchema().getInput().get(1).getStruct(), "basic");
8889
Assert.assertEquals(sro.getSchema().getInput().get(2).getStruct(), "array");
8990
Assert.assertEquals(sro.getData(),
90-
Lists.newArrayList(Lists.newArrayList(1, 2, 3), "C", Lists.newArrayList(38.0, 24.0)));
91+
Lists.newArrayList(Lists.newArrayList(1.0, 2.0, 3.0), "C", Lists.newArrayList(38.0, 24.0)));
9192
Assert.assertEquals(sro.getSchema().getOutput().getName(), "features");
9293
Assert.assertEquals(sro.getSchema().getOutput().getType(), "double");
9394
Assert.assertEquals(sro.getSchema().getOutput().getStruct(), "vector");

src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import com.amazonaws.sagemaker.type.DataStructureType;
2323
import com.fasterxml.jackson.databind.ObjectMapper;
2424
import com.google.common.collect.Lists;
25-
import java.io.IOException;
26-
import java.util.List;
2725
import ml.combust.mleap.core.types.ListType;
2826
import ml.combust.mleap.core.types.ScalarType;
2927
import ml.combust.mleap.core.types.TensorType;
@@ -32,9 +30,13 @@
3230
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder;
3331
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilderSupport;
3432
import org.apache.commons.io.IOUtils;
33+
import org.apache.spark.ml.linalg.Vectors;
3534
import org.junit.Assert;
3635
import org.junit.Test;
3736

37+
import java.io.IOException;
38+
import java.util.List;
39+
3840
public class DataConversionHelperTest {
3941

4042
private ObjectMapper mapper = new ObjectMapper();
@@ -143,21 +145,11 @@ public void testCastingInputToJavaTypeSingle() {
143145

144146
@Test
145147
public void testCastingInputToJavaTypeList() {
146-
Assert.assertEquals(Lists.newArrayList(1, 2), dataConversionHelper
147-
.convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR,
148-
Lists.newArrayList(new Integer("1"), new Integer("2"))));
149-
150-
Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper
151-
.convertInputDataToJavaType(BasicDataType.FLOAT, DataStructureType.VECTOR,
152-
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));
153148

154-
Assert.assertEquals(Lists.newArrayList(1.0, 2.0), dataConversionHelper
155-
.convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR,
156-
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));
157-
158-
Assert.assertEquals(Lists.newArrayList(new Byte("1")), dataConversionHelper
159-
.convertInputDataToJavaType(BasicDataType.BYTE, DataStructureType.VECTOR,
160-
Lists.newArrayList(new Byte("1"))));
149+
// Check that a vector struct with double type is converted to a Spark vector
150+
Assert.assertEquals(Vectors.dense(new double[]{1.0, 2.0}),dataConversionHelper
151+
.convertInputDataToJavaType(BasicDataType.DOUBLE, DataStructureType.VECTOR,
152+
Lists.newArrayList(new Double("1.0"), new Double("2.0"))));
161153

162154
Assert.assertEquals(Lists.newArrayList(1L, 2L), dataConversionHelper
163155
.convertInputDataToJavaType(BasicDataType.LONG, DataStructureType.ARRAY,
@@ -175,6 +167,12 @@ public void testCastingInputToJavaTypeList() {
175167
Lists.newArrayList(Boolean.valueOf("1"))));
176168
}
177169

170+
@Test(expected = IllegalArgumentException.class)
171+
public void testConvertInputToJavaTypeNonDoibleVector() {
172+
dataConversionHelper
173+
.convertInputDataToJavaType(BasicDataType.INTEGER, DataStructureType.VECTOR, new Integer("1"));
174+
}
175+
178176
@Test(expected = IllegalArgumentException.class)
179177
public void testCastingInputToJavaTypeNonList() {
180178
dataConversionHelper

src/test/resources/com/amazonaws/sagemaker/dto/complete_input.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"input": [
44
{
55
"name": "name_1",
6-
"type": "int",
6+
"type": "double",
77
"struct": "vector"
88
},
99
{
@@ -23,5 +23,5 @@
2323
"struct": "vector"
2424
}
2525
},
26-
"data": [[1, 2, 3], "C", [38.0, 24.0]]
26+
"data": [[1.0, 2.0, 3.0], "C", [38.0, 24.0]]
2727
}

0 commit comments

Comments
 (0)