Skip to content

Add GBDT transfer learning #545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ Changes for Shifu-0.2.4
c) If do variable selection again after a model, current work flow no need do normalize step, after variable selection then do training step.
* https://github.com/ShifuML/shifu/issues/49: Add distributed sensitivity analysis variable selection.
a) 'varSelect.wrapperEnabled=true' and 'wrapperBy=SE' in ModelConfig.json#varSelect part to enable sensitivity variable selection.
b) 'wrapperRatio' in ModelConfig.json#varSelect part is a percent to set how many variables will be removed.
b) 'filterOutRatio' in ModelConfig.json#varSelect part is a percent to set how many variables will be removed.
c) To continue variable selection by sensitivity method, run 'shifu varselect' again.
d) With 20 million of records and 1600 variables, 70 minutes (45 minutes for 200 epoch training and 25 minutes for sensitivity variable selection).
* https://github.com/ShifuML/shifu/issues/38: Improve scalability in stats step.
Expand Down
31 changes: 30 additions & 1 deletion src/main/java/ml/shifu/shifu/core/Scorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public ScoreObject score(final MLDataPair pair, Map<String, String> rawDataMap)
return scoreNsData(pair, CommonUtils.convertRawMapToNsDataMap(rawDataMap));
}

public ScoreObject scoreNsData(MLDataPair inputPair, Map<NSColumn, String> rawNsDataMap) {
public ScoreObject scoreNsData(MLDataPair inputPair, final Map<NSColumn, String> rawNsDataMap) {
if(inputPair == null && !this.alg.equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
inputPair = CommonUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig,
selectedColumnConfigList, rawNsDataMap, cutoff, alg);
Expand Down Expand Up @@ -377,6 +377,25 @@ public MLData call() {
log.error("error in model evaluation", e);
}
}
} else if(model instanceof TransferLearningTreeModel) {
final TransferLearningTreeModel tltm = (TransferLearningTreeModel) model;

Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
MLData result = tltm.compute(rawNsDataMap);
return result;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you change to one line like 'return tltm.compute(rawNsDataMap);'

}
};
if(multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else if(model instanceof GenericModel) {
modelResults.add(new Callable<MLData>() {
@Override
Expand Down Expand Up @@ -477,6 +496,16 @@ public int compare(String o1, String o2) {
if(!tm.isClassfication() && !tm.isGBDT()) {
rfTreeSizeList.add(tm.getTrees().size());
}
} else if (model instanceof TransferLearningTreeModel) {
if(modelConfig.isClassification() && !modelConfig.getTrain().isOneVsAll()) {
double[] scoreArray = score.getData();
for(double sc: scoreArray) {
scores.add(sc);
}
} else {
// if one vs all multiple classification or regression
scores.add(toScore(score.getData(0)));
}
} else if(model instanceof GenericModel) {
scores.add(toScore(score.getData(0)));
} else {
Expand Down
69 changes: 69 additions & 0 deletions src/main/java/ml/shifu/shifu/core/TransferLearningTreeModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package ml.shifu.shifu.core;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.encog.ml.BasicML;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;

import ml.shifu.shifu.column.NSColumn;
import ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel;
import ml.shifu.shifu.util.CommonUtils;

/**
 * A GBDT model wrapper used for transfer-learning scoring: the final score is the
 * prediction of the newly trained (transfer-learned) tree model merged with the
 * predictions of a list of pre-trained base tree models.
 *
 * <p>Per-model results are combined via {@code CommonUtils.merge(double[], double[])}.
 * Both model fields are {@code transient}; after Java deserialization they will be
 * null and the instance must be rebuilt before scoring.
 */
public class TransferLearningTreeModel extends BasicML {

    private static final long serialVersionUID = -8269024520194949153L;

    /**
     * Tree model instance without dependency on encog.
     */
    private transient IndependentTreeModel independentTreeModel;

    /**
     * Pre-trained base models whose predictions are merged into the final score;
     * may be null or empty, in which case only {@link #independentTreeModel} is used.
     */
    private transient List<IndependentTreeModel> baseTreeModels;

    /**
     * Constructor.
     *
     * @param independentTreeModel
     *            the newly trained (transfer-learned) tree model; must not be null
     * @param baseTreeModels
     *            the pre-trained base tree models; may be null or empty
     */
    public TransferLearningTreeModel(IndependentTreeModel independentTreeModel, List<IndependentTreeModel> baseTreeModels) {
        this.independentTreeModel = independentTreeModel;
        this.baseTreeModels = baseTreeModels;
    }

    /**
     * Computes the merged score for one record.
     *
     * @param rawNsDataMap
     *            raw input keyed by namespaced column; the simple column name is used
     *            as the key when feeding the underlying tree models
     * @return the merged prediction wrapped as {@code BasicMLData}
     */
    public final MLData compute(Map<NSColumn, String> rawNsDataMap) {
        // Tree models expect simple column names, so strip the namespace first.
        HashMap<String, Object> rawDataMap = new HashMap<String, Object>(rawNsDataMap.size());
        for (Map.Entry<NSColumn, String> entry : rawNsDataMap.entrySet()) {
            rawDataMap.put(entry.getKey().getSimpleName(), entry.getValue());
        }

        double[] res = this.getIndependentTreeModel().compute(rawDataMap);

        // Guard against a null list so a model built without base models still scores.
        if (this.baseTreeModels != null) {
            for (IndependentTreeModel baseTreeModel : this.baseTreeModels) {
                res = CommonUtils.merge(res, baseTreeModel.compute(rawDataMap));
            }
        }

        return new BasicMLData(res);
    }

    /**
     * Returns the trees of each base model followed by the transfer-learned model's
     * trees, joined by '&amp;'.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (this.baseTreeModels != null) {
            for (IndependentTreeModel baseTreeModel : baseTreeModels) {
                sb.append(baseTreeModel.getTrees().toString()).append("&");
            }
        }
        sb.append(this.getIndependentTreeModel().getTrees().toString());

        return sb.toString();
    }

    /**
     * @return the transfer-learned tree model
     */
    public IndependentTreeModel getIndependentTreeModel() {
        return independentTreeModel;
    }

    @Override
    public void updateProperties() {
        // No need implementation
    }
}
5 changes: 5 additions & 0 deletions src/main/java/ml/shifu/shifu/core/dtrain/CommonConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public interface CommonConstants {
public static final String SHIFU_UPDATEBINNING_REDUCER = "shifu.updatebinning.reducer";

public static final String FIXED_LAYERS = "FixedLayers";

// For GBDT Transfer learning
public static final String FIRST_TREE_LEARNING_RATE = "FirstTreeLearningRate";

public static final String GBDT_BASE_MODEL_PATHS = "GBDTBaseModelPaths";

public static final String FIXED_BIAS = "FixedBias";
}
16 changes: 13 additions & 3 deletions src/main/java/ml/shifu/shifu/core/dtrain/dt/DTMaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
Expand Down Expand Up @@ -268,6 +269,11 @@ public class DTMaster extends AbstractMasterComputable<DTMasterParams, DTWorkerP
* TreeNodes needed to be collected statistics from workers.
*/
private Queue<TreeNode> toDoQueue;

/**
* Flag of this model is transfered from base model
*/
private boolean isTransferLearning;

@Override
public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
Expand Down Expand Up @@ -1012,7 +1018,10 @@ public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
// learning rate only effective in gbdt
this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
}

this.isTransferLearning = validParams.get(CommonConstants.GBDT_BASE_MODEL_PATHS) != null;
if (isTransferLearning)
LOG.info("Config of GBDT base model is " + ((List<String>)validParams.get(CommonConstants.GBDT_BASE_MODEL_PATHS)).toString());

// initialize impurity type according to regression or classfication
String imStr = validParams.get("Impurity").toString();
int numClasses = 2;
Expand Down Expand Up @@ -1087,14 +1096,15 @@ public int compare(TreeNode o1, TreeNode o2) {
if(existingModel == null) {
// null means no existing model file or model file is in wrong format
this.trees = new CopyOnWriteArrayList<TreeNode>();
this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));// learning rate is 1 for 1st
this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX),
this.isTransferLearning ? this.learningRate : 1d));// learning rate is 1 for 1st if it is not transfered
LOG.info("Starting to train model from scratch and existing model is empty.");
} else {
this.trees = existingModel.getTrees();
this.existingTreeSize = this.trees.size();
// starting from existing models, first tree learning rate is current learning rate
this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX),
this.existingTreeSize == 0 ? 1d : this.learningRate));
(this.existingTreeSize == 0 && !this.isTransferLearning) ? 1d : this.learningRate));
LOG.info("Starting to train model from existing model {} with existing trees {}.",
modelPath, existingTreeSize);
}
Expand Down
68 changes: 57 additions & 11 deletions src/main/java/ml/shifu/shifu/core/dtrain/dt/DTWorker.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -332,6 +333,11 @@ public class DTWorker extends
*/
private Random sampelNegOnlyRandom = new Random(System.currentTimeMillis() + 1000L);

/**
* Used for transfer learning, model will be init in preload method
*/
private List<IndependentTreeModel> baseModels = null;

@Override
public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
super.setRecordReader(new GuaguaLineRecordReader(fileSplit));
Expand Down Expand Up @@ -623,7 +629,7 @@ public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> co
double predict = predictNode.getPredict().getPredict();
// first tree logic, master must set it to first tree even second tree with ROOT is
// sending
if(context.getLastMasterResult().isFirstTree()) {
if(context.getLastMasterResult().isFirstTree() && !isTransferLearning()) {
data.predict = (float) predict;
} else {
// random drop
Expand Down Expand Up @@ -665,7 +671,8 @@ public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> co
}
}

if(context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
if(context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()
&& !isTransferLearning()) {
Node currTree = trees.get(currTreeIndex).getNode();
Node predictNode = predictNodeIndex(currTree, data, true);
if(predictNode.getPredict() != null) {
Expand Down Expand Up @@ -720,7 +727,7 @@ public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> co
Node predictNode = predictNodeIndex(node, data, false);
if(predictNode.getPredict() != null) {
double predict = predictNode.getPredict().getPredict();
if(context.getLastMasterResult().isFirstTree()) {
if(context.getLastMasterResult().isFirstTree() && !isTransferLearning()) {
data.predict = (float) predict;
} else {
data.predict += (float) (this.learningRate * predict);
Expand All @@ -729,7 +736,8 @@ public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> co
}
}
}
if(context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
if(context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()
&& !isTransferLearning()) {
Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, true);
if(predictNode.getPredict() != null) {
validationError += data.significance * loss
Expand Down Expand Up @@ -1090,15 +1098,29 @@ private Node predictNodeIndex(Node node, Data data, boolean isForErr) {
}
return predictNodeIndex(nextNode, data, isForErr);
}


@Override
/**
* Preload GBDT base model for transfer learning
*/
public void preLoad(WorkerContext<DTMasterParams, DTWorkerParams> context) {
List<String> baseModelPaths = (List<String>) this.modelConfig.getTrain().getParams().get(CommonConstants.GBDT_BASE_MODEL_PATHS);

if (!this.isGBDT || baseModelPaths == null || baseModelPaths.isEmpty()) {
return;
}

this.baseModels = CommonUtils.loadGBDTBaseModels(baseModelPaths);
}

@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue,
WorkerContext<DTMasterParams, DTWorkerParams> context) {
this.count += 1;
if((this.count) % 5000 == 0) {
LOG.info("Read {} records.", this.count);
}

// hashcode for fixed input split in train and validation
long hashcode = 0;

Expand All @@ -1110,6 +1132,8 @@ public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableA
// the function in akka mode.
int index = 0, inputIndex = 0;
boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
Map<String, Object> baseModelInput = new HashMap<String, Object>();

for(String input: this.splitter.split(currentValue.getWritable().toString())) {
if(index == this.columnConfigList.size()) {
// do we need to check if not weighted directly set to 1f; if such logic non-weight at first, then
Expand All @@ -1133,6 +1157,11 @@ public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableA
if(columnConfig != null && columnConfig.isTarget()) {
ideal = getFloatValue(input);
} else {
// put header and value into baseModelInput for transfer learning only
if (this.isTransferLearning()) {
baseModelInput.put(columnConfig.getColumnName(), input);
}

if(!isAfterVarSelect) {
// no variable selected, good candidate but not meta and not target chose
if(!columnConfig.isMeta() && !columnConfig.isTarget()
Expand Down Expand Up @@ -1226,7 +1255,18 @@ public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableA
throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: "
+ inputs.length + ", parsing size:" + inputIndex + ", delimiter:" + delimiter + ".");
}


// calculate new label by inputing train records (title and value)
float predict = ideal;
float output = ideal;
if (this.isTransferLearning()) {
predict = 0f;
for (IndependentTreeModel baseModel : this.baseModels) {
double[] result = baseModel.compute(baseModelInput);
predict += result[0];
}
}

if(this.isOneVsAll) {
// if one vs all, update target value according to index of target
ideal = updateOneVsAllTargetValue(ideal);
Expand Down Expand Up @@ -1258,16 +1298,14 @@ && isInRange(hashcode, startHashCode, endHashCode)) {
}
}

float output = ideal;
float predict = ideal;

// up sampling logic, just add more weights while bagging sampling rate is still not changed
if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal, 1d) == 0) {
// Double.compare(ideal, 1d) == 0 means positive tags; sample + 1 to avoid sample count to 0
significance = significance * (this.upSampleRng.sample() + 1);
}

Data data = new Data(inputs, predict, output, output, significance);
Data data = new Data(inputs, predict, output, ideal, significance);

boolean isValidation = false;
if(context.getAttachment() != null && context.getAttachment() instanceof Boolean) {
Expand Down Expand Up @@ -1465,7 +1503,7 @@ private void recoverGBTData(WorkerContext<DTMasterParams, DTWorkerParams> contex
int iterLen = isFailoverOrContinuous ? trees.size() - 1 : trees.size();
for(int i = 0; i < iterLen; i++) {
TreeNode currTree = trees.get(i);
if(i == 0) {
if(i == 0 && !isTransferLearning()) {
double oldPredict = predictNodeIndex(currTree.getNode(), data, false).getPredict().getPredict();
predict = (float) oldPredict;
output = -1f * loss.computeGradient(predict, data.label);
Expand Down Expand Up @@ -1578,6 +1616,14 @@ private float updateOneVsAllTargetValue(float ideal) {
return Float.compare(ideal, trainerId) == 0 ? 1f : 0f;
}

/**
* If user give valid base model which means he need us to do transfer learning
* @return
*/
private boolean isTransferLearning() {
return (this.baseModels != null && !this.baseModels.isEmpty());
}

static class Data implements Serializable, Bytable {

private static final long serialVersionUID = 903201066309036170L;
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/ml/shifu/shifu/core/dtrain/gs/GridSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ private void parseParams(Map<String, Object> params) {

// stats on hyper parameters
for(Entry<String, Object> entry: sortedMap.entrySet()) {
if(entry.getKey().equals("ActivationFunc") || entry.getKey().equals("NumHiddenNodes") || entry.getKey().equals("FixedLayers")) {
if(entry.getKey().equals("ActivationFunc") || entry.getKey().equals("NumHiddenNodes")
|| entry.getKey().equals("FixedLayers") || entry.getKey().equals("GBDTBaseModelPaths")) {
if(entry.getValue() instanceof List) {
if(((List) (entry.getValue())).size() > 0 && ((List) (entry.getValue())).get(0) instanceof List) {
// ActivationFunc and NumHiddenNodes in NN is already List, so as hyper parameter they should be
Expand Down
Loading