26
26
import java .io .InputStream ;
27
27
import java .io .InputStreamReader ;
28
28
import java .io .Reader ;
29
+ import java .lang .reflect .Constructor ;
29
30
import java .util .ArrayList ;
30
31
import java .util .Arrays ;
31
32
import java .util .Collection ;
48
49
import ml .shifu .shifu .container .obj .ColumnConfig .ColumnFlag ;
49
50
import ml .shifu .shifu .container .obj .ColumnType ;
50
51
import ml .shifu .shifu .container .obj .EvalConfig ;
52
+ import ml .shifu .shifu .container .obj .GenericModelConfig ;
51
53
import ml .shifu .shifu .container .obj .ModelConfig ;
52
54
import ml .shifu .shifu .container .obj .ModelTrainConf .ALGORITHM ;
53
55
import ml .shifu .shifu .container .obj .RawSourceData .SourceType ;
56
+ import ml .shifu .shifu .core .Computable ;
57
+ import ml .shifu .shifu .core .GenericModel ;
54
58
import ml .shifu .shifu .core .LR ;
55
59
import ml .shifu .shifu .core .NNModel ;
56
60
import ml .shifu .shifu .core .Normalizer ;
@@ -969,7 +973,44 @@ public static List<BasicML> loadBasicModels(ModelConfig modelConfig, EvalConfig
969
973
boolean gbtConvertToProb , String gbtScoreConvertStrategy ) throws IOException {
970
974
List <BasicML > models = new ArrayList <BasicML >();
971
975
FileSystem fs = ShifuFileUtils .getFileSystemBySourceType (sourceType );
972
-
976
+ List <FileStatus > genericModelConfigs = findGenericModels (modelConfig , evalConfig , sourceType );
977
+ if (!genericModelConfigs .isEmpty ()) {
978
+ for (FileStatus f : genericModelConfigs ) {
979
+ GenericModelConfig gmc = loadJSON (f .getPath ().toString (), sourceType , GenericModelConfig .class );
980
+
981
+ if (SourceType .HDFS .equals (sourceType )) {
982
+
983
+ FileSystem hdfs = HDFSUtils .getFS ();
984
+ PathFinder pathFinder = new PathFinder (modelConfig );
985
+ String alg = (String )gmc .getProperties ().get ("algorithm" );
986
+ String src = pathFinder .getModelsPath (sourceType );
987
+ hdfs .copyToLocalFile (false , new Path (src ), new Path (System .getProperty ("user.dir" )), true );
988
+ gmc .getProperties ().put ("modelpath" , System .getProperty ("user.dir" ) + "/models" );
989
+ File file = new File (System .getProperty ("user.dir" ) + "/models" );
990
+ for (String str : file .list ()) {
991
+ log .error ("list file in " + file .getAbsolutePath () + " : " + str );
992
+ }
993
+ log .error ("gmc model path is : " + gmc .getProperties ().get ("modelpath" ));
994
+ if ("tensorflow" .equals (alg )) {
995
+
996
+ try {
997
+ Class c = Class .forName ("ml.shifu.shifu.tensorflow.TensorflowModel" );
998
+ Computable computable = (Computable )c .newInstance ();
999
+ computable .init (gmc );
1000
+ GenericModel genericModel = new GenericModel (computable , gmc .getProperties ());
1001
+ models .add (genericModel );
1002
+ log .error ("load generic model" );
1003
+ } catch (Exception e ) {
1004
+ log .error ("" , e );
1005
+ throw new RuntimeException ("Get real model fail" );
1006
+ }
1007
+ }
1008
+ }
1009
+ }
1010
+ log .error ("return generic model " + models .size ());
1011
+ return models ;
1012
+ }
1013
+
973
1014
List <FileStatus > modelFileStats = locateBasicModels (modelConfig , evalConfig , sourceType );
974
1015
if (CollectionUtils .isNotEmpty (modelFileStats )) {
975
1016
for (FileStatus f : modelFileStats ) {
@@ -995,6 +1036,7 @@ public static List<FileStatus> locateBasicModels(ModelConfig modelConfig, EvalCo
995
1036
if (CollectionUtils .isEmpty (listStatus )) {
996
1037
// throw new ShifuException(ShifuErrorCode.ERROR_MODEL_FILE_NOT_FOUND);
997
1038
// disable exception, since we there maybe sub-models
1039
+ listStatus = findGenericModels (modelConfig , evalConfig , sourceType );
998
1040
return listStatus ;
999
1041
}
1000
1042
@@ -1244,6 +1286,33 @@ public static List<FileStatus> findModels(ModelConfig modelConfig, EvalConfig ev
1244
1286
1245
1287
return fileList ;
1246
1288
}
1289
+
1290
+ public static List <FileStatus > findGenericModels (ModelConfig modelConfig , EvalConfig evalConfig , SourceType sourceType )
1291
+ throws IOException {
1292
+ FileSystem fs = ShifuFileUtils .getFileSystemBySourceType (sourceType );
1293
+ PathFinder pathFinder = new PathFinder (modelConfig );
1294
+
1295
+ // If the algorithm in ModelConfig is NN, we only load NN models
1296
+ // the same as SVM, LR
1297
+ String modelSuffix = ".json" ;
1298
+
1299
+ List <FileStatus > fileList = new ArrayList <FileStatus >();
1300
+ if (null == evalConfig || StringUtils .isBlank (evalConfig .getModelsPath ())) {
1301
+ Path path = new Path (pathFinder .getModelsPath (sourceType ));
1302
+ fileList .addAll (Arrays .asList (fs .listStatus (path , new FileSuffixPathFilter (modelSuffix ))));
1303
+ } else {
1304
+ String modelsPath = evalConfig .getModelsPath ();
1305
+ FileStatus [] expandedPaths = fs .globStatus (new Path (modelsPath ));
1306
+ if (ArrayUtils .isNotEmpty (expandedPaths )) {
1307
+ for (FileStatus epath : expandedPaths ) {
1308
+ fileList .addAll (
1309
+ Arrays .asList (fs .listStatus (epath .getPath (), new FileSuffixPathFilter (modelSuffix ))));
1310
+ }
1311
+ }
1312
+ }
1313
+
1314
+ return fileList ;
1315
+ }
1247
1316
1248
1317
public static List <ModelSpec > loadSubModels (ModelConfig modelConfig , List <ColumnConfig > columnConfigList ,
1249
1318
EvalConfig evalConfig , SourceType sourceType , Boolean gbtConvertToProb ) {
@@ -1494,7 +1563,7 @@ public int compare(File from, File to) {
1494
1563
throw new IOException (String .format ("Failed to list files in %s" , modelsPathDir .getAbsolutePath ()));
1495
1564
}
1496
1565
}
1497
-
1566
+
1498
1567
/**
1499
1568
* Return one HashMap Object contains keys in the first parameter, values in the second parameter. Before calling
1500
1569
* this method, you should be aware that headers should be unique.
@@ -3002,4 +3071,4 @@ public static String[] splitString(String str, String delimiter) {
3002
3071
return categories .toArray (new String [0 ]);
3003
3072
}
3004
3073
3005
- }
3074
+ }
0 commit comments