
Mtl training #674


Open
wants to merge 29 commits into base: develop-mtl
Commits (29): changes from all commits
d7b4d8f
comment out 2 unit tests so the build succeeds.
RoastEgg Jul 9, 2019
3a3b224
comment out 2 unit tests so the build succeeds.
RoastEgg Jul 9, 2019
894050e
Merge branch 'master' of github.com:RoastEgg/shifu
RoastEgg Jul 16, 2019
b62034a
complete MultiTaskNN except some read and write methods.
RoastEgg Jul 17, 2019
f2c0053
complete MTNNMaster except 'loadModel'
RoastEgg Jul 18, 2019
07d238e
Merge branch 'master' into develop
RoastEgg Jul 18, 2019
ad39193
fix merge conflicts
RoastEgg Jul 19, 2019
8172870
some work on master and modifications to the model.
RoastEgg Jul 21, 2019
dfb3c37
complete MTNNParallelGradient and modify other classes.
RoastEgg Jul 22, 2019
bf68df8
init method and doCompute method of MTNNWorker.
RoastEgg Jul 23, 2019
f90a2b9
refactor params. Errors in load method of Worker!
RoastEgg Jul 24, 2019
8b2f896
fix the compile errors in MTNNWorker. (comment out sampling and update…
RoastEgg Jul 24, 2019
086867c
Merge branch 'develop-mtl' into develop
RoastEgg Jul 25, 2019
60ca53b
merge the branch develop-mtl. Begin to use MTNNtest. Modify the prope…
RoastEgg Jul 25, 2019
4f7784a
still can't reach the doCompute method in WDLMaster or MTNNMaster.
RoastEgg Jul 26, 2019
ad4607e
remove resource in test
RoastEgg Jul 26, 2019
ae207a7
a lot of bugs fixed and the test of MTNN passed.
RoastEgg Aug 1, 2019
d35cad5
MTNNOutput finished and local test passed.
RoastEgg Aug 5, 2019
a1c8383
fix bug (trainErrors and validationErrors initialization in MTNNParams) and run
RoastEgg Aug 7, 2019
e42d770
fix bug of ActivationFactory
RoastEgg Aug 7, 2019
e962a84
Evaluation code done. Waiting for debugging on Hadoop.
RoastEgg Aug 8, 2019
10a3efa
fix bug 'unsupport models' in eval.
RoastEgg Aug 12, 2019
b597155
rename classes about MultiTask Learning.
RoastEgg Aug 13, 2019
b1d39fb
First version of code to fit MTL (waiting for debugging).
RoastEgg Aug 16, 2019
f2f06c4
SampleWeight modified in Data in Worker.
RoastEgg Aug 20, 2019
c75f163
MTL training passes on Hadoop and locally.
RoastEgg Aug 23, 2019
ece3e3d
MTL training done on Hadoop.
RoastEgg Sep 3, 2019
8130716
refactor code.
RoastEgg Sep 3, 2019
20cf051
resource configs for unit tests.
RoastEgg Sep 4, 2019
@@ -41,7 +41,7 @@ public class ModelTrainConf {
* @author Zhang David (pengzhang@paypal.com)
*/
public static enum ALGORITHM {
NN, LR, SVM, DT, RF, GBT, TENSORFLOW, WDL
NN, LR, SVM, DT, RF, GBT, TENSORFLOW, WDL, MTL
}

/**
66 changes: 66 additions & 0 deletions src/main/java/ml/shifu/shifu/core/MTLModel.java
@@ -0,0 +1,66 @@
/*
* Copyright [2013-2019] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;

import ml.shifu.shifu.core.dtrain.mtl.IndependentMTLModel;
import org.encog.ml.BasicML;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;

import java.io.IOException;
import java.io.InputStream;

/**
* @author haillu
*/
public class MTLModel extends BasicML implements MLRegression {
    private IndependentMTLModel independentMTLModel;

    public MTLModel(IndependentMTLModel independentMTLModel) {
        this.independentMTLModel = independentMTLModel;
    }

    @Override
    public void updateProperties() {
        // No need to implement
    }

    @Override
    public MLData compute(MLData input) {
        double[] result = independentMTLModel.compute(input.getData());
        return new BasicMLData(result);
    }

    @Override
    public int getInputCount() {
        return independentMTLModel.getMtl().getInputSize();
    }

    @Override
    public int getOutputCount() {
        return independentMTLModel.getMtl().getTaskNumber();
    }

    public static MTLModel loadFromStream(InputStream input) throws IOException {
        return new MTLModel(IndependentMTLModel.loadFromStream(input));
    }

    public static MTLModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
        return new MTLModel(IndependentMTLModel.loadFromStream(input, isRemoveNameSpace));
    }

}
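Note (not part of the diff): a minimal, hypothetical sketch of how this wrapper is meant to be used, mirroring the Scorer changes below; the model path and the zero-filled input row are made-up placeholders.

import java.io.FileInputStream;
import java.io.InputStream;

import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;

import ml.shifu.shifu.core.MTLModel;

public class MTLModelUsageSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical path to a binary model written by BinaryMTLSerializer.save(...).
        try (InputStream in = new FileInputStream("model/mtl.b")) {
            MTLModel model = MTLModel.loadFromStream(in);
            // One normalized input row; its length must equal model.getInputCount().
            double[] row = new double[model.getInputCount()];
            MLData output = model.compute(new BasicMLData(row));
            // getOutputCount() equals the number of MTL tasks, one score per task.
            for(int i = 0; i < model.getOutputCount(); i++) {
                System.out.println("task " + i + " score = " + output.getData(i));
            }
        }
    }
}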
24 changes: 24 additions & 0 deletions src/main/java/ml/shifu/shifu/core/Scorer.java
@@ -428,6 +428,28 @@ public MLData call() {
                    log.error("error in model evaluation", e);
                }
            }
        } else if(model instanceof MTLModel) {
            final MTLModel mtl = (MTLModel) model;
            if(mtl.getInputCount() != pair.getInput().size()) {
                throw new RuntimeException("MTL and input size mismatch: mtl input Size = " + mtl.getInputCount()
                        + "; data input Size = " + pair.getInput().size());
            }

            Callable<MLData> callable = new Callable<MLData>() {
                @Override
                public MLData call() {
                    return new BasicMLData(mtl.compute(pair.getInput()));
                }
            };
            if(multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in MTL model evaluation", e);
                }
            }
        } else {
            throw new RuntimeException("unsupport models");
        }
@@ -525,6 +547,8 @@ public int compare(String o1, String o2) {
            scores.add(toScore(score.getData(0)));
        } else if(model instanceof WDLModel) {
            scores.add(toScore(score.getData(0)));
        } else if(model instanceof MTLModel) {
            scores.add(toScore(score.getData(0)));
        } else {
            throw new RuntimeException("unsupport models");
        }
2 changes: 2 additions & 0 deletions src/main/java/ml/shifu/shifu/core/dtrain/CommonConstants.java
@@ -146,6 +146,8 @@ public interface CommonConstants {

    public static final int WDL_FORMAT_VERSION = 1;

    public static final int MTL_FORMAT_VERSION = 1;

    public static final int DEFAULT_EMBEDING_OUTPUT = 8;

    public static final String WIDE_ENABLE = "wideEnable";
27 changes: 27 additions & 0 deletions src/main/java/ml/shifu/shifu/core/dtrain/SerializationType.java
@@ -0,0 +1,27 @@
package ml.shifu.shifu.core.dtrain;

import java.util.Arrays;

/**
* @author haillu
*/
public enum SerializationType {
    /**
     * Serialization types; each type covers a different serialization scope.
     */
    WEIGHTS(0), GRADIENTS(1), MODEL_SPEC(2), ERROR(-1);

    int value;

    SerializationType(int type) {
        this.value = type;
    }

    public static SerializationType getSerializationType(int value) {
        return Arrays.stream(values()).filter(type -> type.value == value).findFirst().orElse(SerializationType.ERROR);
    }

    public int getValue() {
        return this.value;
    }
}
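Note (not part of the diff): a tiny, hypothetical illustration of the lookup, including the ERROR fallback for unknown values.

import ml.shifu.shifu.core.dtrain.SerializationType;

public class SerializationTypeSketch {
    public static void main(String[] args) {
        // 0 -> WEIGHTS, 1 -> GRADIENTS, 2 -> MODEL_SPEC; anything else falls back to ERROR.
        System.out.println(SerializationType.getSerializationType(2));  // MODEL_SPEC
        System.out.println(SerializationType.getSerializationType(99)); // ERROR
    }
}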
136 changes: 136 additions & 0 deletions src/main/java/ml/shifu/shifu/core/dtrain/mtl/BinaryMTLSerializer.java
@@ -0,0 +1,136 @@
/*
* Copyright [2013-2019] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.mtl;

import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.Normalizer;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.SerializationType;
import ml.shifu.shifu.core.dtrain.StringUtils;
import ml.shifu.shifu.core.dtrain.nn.NNColumnStats;
import ml.shifu.shifu.util.CommonUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;

import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPOutputStream;

/**
* @author haillu
*/
public class BinaryMTLSerializer {
    public static void save(ModelConfig modelConfig, List<List<ColumnConfig>> mtlColumnConfigLists,
            MultiTaskLearning mtl, FileSystem fs, Path output) throws IOException {
        DataOutputStream fos = null;
        try {
            fos = new DataOutputStream(new GZIPOutputStream(fs.create(output)));

            // version
            fos.writeInt(CommonConstants.MTL_FORMAT_VERSION);
            // Reserved fields: three doubles and one string
            fos.writeDouble(0.0f);
            fos.writeDouble(0.0f);
            fos.writeDouble(0.0d);
            fos.writeUTF("Reserved field");

            // write normStr
            String normStr = modelConfig.getNormalize().getNormType().toString();
            StringUtils.writeString(fos, normStr);

            // write task number.
            fos.writeInt(mtlColumnConfigLists.size());

            for(List<ColumnConfig> ccs: mtlColumnConfigLists) {
                // compute columns needed
                Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(ccs);

                // write column stats to output
                List<NNColumnStats> csList = new ArrayList<>();
                for(ColumnConfig cc: ccs) {
                    if(columnIndexNameMapping.containsKey(cc.getColumnNum())) {
                        NNColumnStats cs = new NNColumnStats();
                        cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
                        cs.setColumnType(cc.getColumnType());
                        cs.setMean(cc.getMean());
                        cs.setStddev(cc.getStdDev());
                        cs.setColumnNum(cc.getColumnNum());
                        cs.setColumnName(cc.getColumnName());
                        cs.setBinCategories(cc.getBinCategory());
                        cs.setBinBoundaries(cc.getBinBoundary());
                        cs.setBinPosRates(cc.getBinPosRate());
                        cs.setBinCountWoes(cc.getBinCountWoe());
                        cs.setBinWeightWoes(cc.getBinWeightedWoe());

                        // TODO cache such computation
                        double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
                        cs.setWoeMean(meanAndStdDev[0]);
                        cs.setWoeStddev(meanAndStdDev[1]);
                        double[] weightMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
                        cs.setWoeWgtMean(weightMeanAndStdDev[0]);
                        cs.setWoeWgtStddev(weightMeanAndStdDev[1]);

                        csList.add(cs);
                    }
                }

                fos.writeInt(csList.size());
                for(NNColumnStats cs: csList) {
                    cs.write(fos);
                }

                Map<Integer, Integer> columnMapping = DTrainUtils.getColumnMapping(ccs);
                fos.writeInt(columnMapping.size());
                for(Map.Entry<Integer, Integer> entry: columnMapping.entrySet()) {
                    fos.writeInt(entry.getKey());
                    fos.writeInt(entry.getValue());
                }

            }

            // persist the multi-task learning model
            mtl.write(fos, SerializationType.MODEL_SPEC);
        } finally {
            IOUtils.closeStream(fos);
        }
    }

    private static Map<Integer, String> getIndexNameMapping(List<ColumnConfig> columnConfigList) {
        Map<Integer, String> columnIndexNameMapping = new HashMap<>(columnConfigList.size());
        for(ColumnConfig columnConfig: columnConfigList) {
            if(columnConfig.isFinalSelect()) {
                columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
        }

        if(columnIndexNameMapping.size() == 0) {
            boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
            for(ColumnConfig columnConfig: columnConfigList) {
                if(CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                    columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName());
                }
            }
        }
        return columnIndexNameMapping;
    }
}
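Note (not part of the diff): the write order above implies a header layout of format version, three reserved doubles, one reserved string, the normalization type, and the task count, followed by per-task column stats and column mappings, and finally the model spec. The reader sketch below is hypothetical and only mirrors the header; it assumes a StringUtils.readString counterpart to the writeString call used above, and the actual loader presumably lives in IndependentMTLModel.loadFromStream.

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.util.zip.GZIPInputStream;

import ml.shifu.shifu.core.dtrain.StringUtils;

public class MTLHeaderReaderSketch {
    public static void main(String[] args) throws Exception {
        // Mirrors BinaryMTLSerializer.save: version, three reserved doubles, a reserved string,
        // the normalization type, then the number of tasks.
        try (DataInputStream dis = new DataInputStream(
                new GZIPInputStream(new FileInputStream("model/mtl.b")))) {
            int version = dis.readInt();     // CommonConstants.MTL_FORMAT_VERSION
            dis.readDouble();                // reserved
            dis.readDouble();                // reserved
            dis.readDouble();                // reserved
            dis.readUTF();                   // "Reserved field"
            String normStr = StringUtils.readString(dis); // assumed counterpart of writeString
            int taskNumber = dis.readInt();  // per-task column stats and mappings follow
            System.out.println("version=" + version + ", norm=" + normStr + ", tasks=" + taskNumber);
        }
    }
}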