Skip to content

Commit 8d3885b

Browse files
committed
Impl of Generic model and tensorflow eval
1 parent 14b92bc commit 8d3885b

File tree

6 files changed

+100
-11
lines changed

6 files changed

+100
-11
lines changed

src/main/java/ml/shifu/shifu/core/Scorer.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,16 @@ public MLData call() {
248248
return result;
249249
}
250250
}.call());
251+
} else if(model instanceof GenericModel) {
252+
modelResults.add(new Callable<MLData>() {
253+
@Override
254+
public MLData call() {
255+
log.error("model is " + ((GenericModel)model).getModel().getClass() + " " +
256+
((GenericModel)model).getGMProperties().toString());
257+
MLData md = pair.getInput();
258+
return ((GenericModel) model).compute(pair.getInput());
259+
}
260+
}.call());
251261
} else {
252262
throw new RuntimeException("unsupport models");
253263
}
@@ -334,6 +344,8 @@ public int compare(String o1, String o2) {
334344
if(!tm.isClassfication() && !tm.isGBDT()) {
335345
rfTreeSizeList.add(tm.getTrees().size());
336346
}
347+
} else if(model instanceof GenericModel) {
348+
scores.add(toScore(score.getData(0)));
337349
} else {
338350
throw new RuntimeException("unsupport models");
339351
}
@@ -366,4 +378,4 @@ public void setScale(int scale) {
366378
this.scale = scale;
367379
}
368380
}
369-
}
381+
}

src/main/java/ml/shifu/shifu/core/processor/BasicModelProcessor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ public void checkAlgorithmParam() throws Exception {
457457
modelConfig.getTrain().setNumTrainEpochs(10000);
458458
saveModelConfig();
459459
}
460+
} else if("tensorflow".equalsIgnoreCase(alg)) {
461+
//do nothing
460462
} else {
461463
throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG);
462464
}

src/main/java/ml/shifu/shifu/pig/PigExecutor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,12 @@ private PigServer createPigServer(SourceType sourceType) throws IOException {
183183
pigServer.getPigContext().addJar(HDPUtils.findContainingFile("core-site.xml"));
184184
pigServer.getPigContext().addJar(HDPUtils.findContainingFile("mapred-site.xml"));
185185
pigServer.getPigContext().addJar(HDPUtils.findContainingFile("yarn-site.xml"));
186+
pigServer.getPigContext().addJar(HDPUtils.findContainingFile("libstdc++.so.6"));
186187
}
188+
pigServer.getPigContext().getConf().put("mapreduce.admin.user.env",
189+
"JAVA_HOME=./jdk1.8.zip/jdk1.8/");
190+
pigServer.getPigContext().getConf().put("mapred.cache.archives",
191+
"hdfs:///user/wzhu1/glibc_2.17.zip,hdfs:///user/wzhu1/jdk1.8.zip");
187192
} else {
188193
log.info("ExecType: LOCAL");
189194
pigServer = new ShifuPigServer(ExecType.LOCAL);

src/main/java/ml/shifu/shifu/pig/ShifuPigStorage.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright [2013-2017] PayPal Software Foundation
2+
* Copyright [2013-2018] PayPal Software Foundation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,21 +15,21 @@
1515
*/
1616
package ml.shifu.shifu.pig;
1717

18-
import org.apache.pig.piggybank.storage.CSVExcelStorage;
19-
import org.apache.pig.builtin.PigStorage;
20-
import org.apache.pig.data.Tuple;
21-
import org.apache.pig.ResourceSchema;
2218
import org.apache.hadoop.mapreduce.RecordReader;
2319
import org.apache.hadoop.mapreduce.Job;
2420
import org.apache.hadoop.mapreduce.InputFormat;
2521
import org.apache.hadoop.mapreduce.RecordWriter;
22+
23+
import org.apache.pig.piggybank.storage.CSVExcelStorage;
24+
import org.apache.pig.builtin.PigStorage;
25+
import org.apache.pig.data.Tuple;
26+
import org.apache.pig.ResourceSchema;
2627
import org.apache.pig.impl.logicalLayer.FrontendException;
2728
import org.apache.pig.backend.hadoop.executionengine.mapReduceLayer.PigSplit;
2829

2930
import java.io.IOException;
3031
import java.util.List;
3132

32-
3333
public class ShifuPigStorage extends PigStorage {
3434

3535
private PigStorage shifuStorage;

src/main/java/ml/shifu/shifu/util/CommonUtils.java

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.io.InputStream;
2727
import java.io.InputStreamReader;
2828
import java.io.Reader;
29+
import java.lang.reflect.Constructor;
2930
import java.util.ArrayList;
3031
import java.util.Arrays;
3132
import java.util.Collection;
@@ -48,9 +49,12 @@
4849
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag;
4950
import ml.shifu.shifu.container.obj.ColumnType;
5051
import ml.shifu.shifu.container.obj.EvalConfig;
52+
import ml.shifu.shifu.container.obj.GenericModelConfig;
5153
import ml.shifu.shifu.container.obj.ModelConfig;
5254
import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM;
5355
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
56+
import ml.shifu.shifu.core.Computable;
57+
import ml.shifu.shifu.core.GenericModel;
5458
import ml.shifu.shifu.core.LR;
5559
import ml.shifu.shifu.core.NNModel;
5660
import ml.shifu.shifu.core.Normalizer;
@@ -969,7 +973,44 @@ public static List<BasicML> loadBasicModels(ModelConfig modelConfig, EvalConfig
969973
boolean gbtConvertToProb, String gbtScoreConvertStrategy) throws IOException {
970974
List<BasicML> models = new ArrayList<BasicML>();
971975
FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
972-
976+
List<FileStatus> genericModelConfigs = findGenericModels(modelConfig, evalConfig, sourceType);
977+
if(!genericModelConfigs.isEmpty()) {
978+
for(FileStatus f : genericModelConfigs) {
979+
GenericModelConfig gmc = loadJSON(f.getPath().toString(), sourceType, GenericModelConfig.class);
980+
981+
if(SourceType.HDFS.equals(sourceType)) {
982+
983+
FileSystem hdfs = HDFSUtils.getFS();
984+
PathFinder pathFinder = new PathFinder(modelConfig);
985+
String alg = (String)gmc.getProperties().get("algorithm");
986+
String src = pathFinder.getModelsPath(sourceType);
987+
hdfs.copyToLocalFile(false, new Path(src), new Path(System.getProperty("user.dir")), true);
988+
gmc.getProperties().put("modelpath", System.getProperty("user.dir") + "/models");
989+
File file = new File(System.getProperty("user.dir") + "/models");
990+
for(String str : file.list()) {
991+
log.error("list file in " + file.getAbsolutePath() + " : " + str);
992+
}
993+
log.error("gmc model path is : " + gmc.getProperties().get("modelpath"));
994+
if("tensorflow".equals(alg)) {
995+
996+
try {
997+
Class c = Class.forName("ml.shifu.shifu.tensorflow.TensorflowModel");
998+
Computable computable = (Computable)c.newInstance();
999+
computable.init(gmc);
1000+
GenericModel genericModel = new GenericModel(computable, gmc.getProperties());
1001+
models.add(genericModel);
1002+
log.error("load generic model");
1003+
} catch (Exception e) {
1004+
log.error("", e);
1005+
throw new RuntimeException("Get real model fail");
1006+
}
1007+
}
1008+
}
1009+
}
1010+
log.error("return generic model " + models.size());
1011+
return models;
1012+
}
1013+
9731014
List<FileStatus> modelFileStats = locateBasicModels(modelConfig, evalConfig, sourceType);
9741015
if(CollectionUtils.isNotEmpty(modelFileStats)) {
9751016
for(FileStatus f: modelFileStats) {
@@ -995,6 +1036,7 @@ public static List<FileStatus> locateBasicModels(ModelConfig modelConfig, EvalCo
9951036
if(CollectionUtils.isEmpty(listStatus)) {
9961037
// throw new ShifuException(ShifuErrorCode.ERROR_MODEL_FILE_NOT_FOUND);
9971038
// disable exception, since there may be sub-models
1039+
listStatus = findGenericModels(modelConfig, evalConfig, sourceType);
9981040
return listStatus;
9991041
}
10001042

@@ -1244,6 +1286,33 @@ public static List<FileStatus> findModels(ModelConfig modelConfig, EvalConfig ev
12441286

12451287
return fileList;
12461288
}
1289+
1290+
public static List<FileStatus> findGenericModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType)
1291+
throws IOException {
1292+
FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
1293+
PathFinder pathFinder = new PathFinder(modelConfig);
1294+
1295+
// Generic models are located by their JSON config files, so we list
1296+
// files ending in ".json" without checking the algorithm in ModelConfig
1297+
String modelSuffix = ".json";
1298+
1299+
List<FileStatus> fileList = new ArrayList<FileStatus>();
1300+
if(null == evalConfig || StringUtils.isBlank(evalConfig.getModelsPath())) {
1301+
Path path = new Path(pathFinder.getModelsPath(sourceType));
1302+
fileList.addAll(Arrays.asList(fs.listStatus(path, new FileSuffixPathFilter(modelSuffix))));
1303+
} else {
1304+
String modelsPath = evalConfig.getModelsPath();
1305+
FileStatus[] expandedPaths = fs.globStatus(new Path(modelsPath));
1306+
if(ArrayUtils.isNotEmpty(expandedPaths)) {
1307+
for(FileStatus epath: expandedPaths) {
1308+
fileList.addAll(
1309+
Arrays.asList(fs.listStatus(epath.getPath(), new FileSuffixPathFilter(modelSuffix))));
1310+
}
1311+
}
1312+
}
1313+
1314+
return fileList;
1315+
}
12471316

12481317
public static List<ModelSpec> loadSubModels(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
12491318
EvalConfig evalConfig, SourceType sourceType, Boolean gbtConvertToProb) {
@@ -1494,7 +1563,7 @@ public int compare(File from, File to) {
14941563
throw new IOException(String.format("Failed to list files in %s", modelsPathDir.getAbsolutePath()));
14951564
}
14961565
}
1497-
1566+
14981567
/**
14991568
* Return one HashMap Object contains keys in the first parameter, values in the second parameter. Before calling
15001569
* this method, you should be aware that headers should be unique.
@@ -3002,4 +3071,4 @@ public static String[] splitString(String str, String delimiter) {
30023071
return categories.toArray(new String[0]);
30033072
}
30043073

3005-
}
3074+
}

src/main/resources/store/ModelConfigMeta.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,8 @@
531531
{"label": "Support Vector Machine", "value":"SVM"},
532532
{"label": "Decision Tree", "value":"DT"},
533533
{"label": "Random Forest", "value":"RF"},
534-
{"label": "Gradient Boost Decision Tree", "value":"GBT"}
534+
{"label": "Gradient Boost Decision Tree", "value":"GBT"},
535+
{"label": "Tensorflow", "value":"Tensorflow"}
535536
]
536537
}, {
537538
"name": "gridConfigFile",

0 commit comments

Comments
 (0)