
Commit c31164b

Merge pull request opencv#18126 from danielenricocahall:add-oob-error-sample-weighting

Account for sample weights in calculating OOB Error

* account for sample weights in oob error calculation
* redefine oob error functions
* fix ABI compatibility
1 parent 3835ab3 commit c31164b
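
For context, a minimal usage sketch of the new accessor (assuming an OpenCV 4.x build, where getOOBError() is exposed on cv::ml::RTrees; the toy data mirrors the new regression test and is illustrative only):

    #include <opencv2/ml.hpp>

    int main()
    {
        using namespace cv;
        using namespace cv::ml;

        // Toy regression problem x -> 2x, with uniform per-sample weights of 10.
        Mat data    = (Mat_<float>(3, 1) << 1, 2, 3);
        Mat values  = (Mat_<float>(3, 1) << 2, 4, 6);
        Mat weights = (Mat_<float>(3, 1) << 10, 10, 10);

        Ptr<RTrees> rt = RTrees::create();
        Ptr<TrainData> td = TrainData::create(data, ROW_SAMPLE, values,
                                              noArray(), noArray(), weights);
        rt->train(td);

        // OOB error accumulated during training; with this commit it is
        // scaled by the per-sample weights supplied above.
        double oobError = rt->getOOBError();
        (void)oobError;
        return 0;
    }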

File tree: 3 files changed, +79 / -2 lines


modules/ml/include/opencv2/ml.hpp

Lines changed: 9 additions & 0 deletions
@@ -1294,6 +1294,15 @@ class CV_EXPORTS_W RTrees : public DTrees
      */
     CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
 
+    /** Returns the OOB error value, computed at the training stage when calcOOBError is set to true.
+     * If this flag was set to false, 0 is returned. The OOB error is also scaled by sample weighting.
+     */
+#if CV_VERSION_MAJOR == 3
+    CV_WRAP double getOOBError() const;
+#else
+    /*CV_WRAP*/ virtual double getOOBError() const = 0;
+#endif
+
     /** Creates the empty model.
     Use StatModel::train to train the model, StatModel::train to create and train the model,
     Algorithm::load to load the pre-trained model.

modules/ml/src/rtrees.cpp

Lines changed: 24 additions & 2 deletions
@@ -216,13 +216,14 @@ class DTreesImplForRTrees CV_FINAL : public DTreesImpl
                 sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
 
                 double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
+                double sample_weight = w->sample_weights[w->sidx[j]];
                 if( !_isClassifier )
                 {
                     oobres[j] += val;
                     oobcount[j]++;
                     double true_val = w->ord_responses[w->sidx[j]];
                     double a = oobres[j]/oobcount[j] - true_val;
-                    oobError += a*a;
+                    oobError += sample_weight * a*a;
                     val = (val - true_val)/max_response;
                     ncorrect_responses += std::exp( -val*val );
                 }
@@ -237,7 +238,7 @@ class DTreesImplForRTrees CV_FINAL : public DTreesImpl
                     if( votes[best_class] < votes[k] )
                         best_class = k;
                     int diff = best_class != w->cat_responses[w->sidx[j]];
-                    oobError += diff;
+                    oobError += sample_weight * diff;
                     ncorrect_responses += diff == 0;
                 }
             }
@@ -421,6 +422,10 @@ class DTreesImplForRTrees CV_FINAL : public DTreesImpl
         }
     }
 
+    double getOOBError() const {
+        return oobError;
+    }
+
     RTreeParams rparams;
     double oobError;
     vector<float> varImportance;
@@ -505,6 +510,12 @@ class RTreesImpl CV_FINAL : public RTrees
     const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
     const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
     const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
+#if CV_VERSION_MAJOR == 3
+    double getOOBError_() const { return impl.getOOBError(); }
+#else
+    double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
+#endif
+
 
     DTreesImplForRTrees impl;
 };
@@ -532,6 +543,17 @@ void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
     return this_->getVotes_(input, output, flags);
 }
 
+#if CV_VERSION_MAJOR == 3
+double RTrees::getOOBError() const
+{
+    CV_TRACE_FUNCTION();
+    const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
+    if(!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
+    return this_->getOOBError_();
+}
+#endif
+
 }}
 
 // End of file.
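
To read the weighting idea in isolation, here is a standalone sketch of the classification branch; the helper name and containers are illustrative only and are not part of OpenCV:

    #include <cstddef>
    #include <vector>

    // Illustrative helper (not OpenCV code): weighted OOB classification error.
    // A misclassified out-of-bag sample contributes its sample weight instead
    // of a flat 1, mirroring `oobError += sample_weight * diff;` above.
    static double weightedOOBClassificationError(const std::vector<int>& predicted,
                                                 const std::vector<int>& truth,
                                                 const std::vector<double>& weights)
    {
        double oobError = 0.0;
        for (std::size_t j = 0; j < predicted.size(); ++j)
        {
            int diff = (predicted[j] != truth[j]);  // 1 if misclassified, else 0
            oobError += weights[j] * diff;          // weighted contribution
        }
        return oobError;
    }

The regression branch applies the same weight to the squared residual, i.e. sample_weight * a*a.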

modules/ml/test/test_rtrees.cpp

Lines changed: 46 additions & 0 deletions
@@ -51,4 +51,50 @@ TEST(ML_RTrees, getVotes)
     EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
 }
 
+TEST(ML_RTrees, 11142_sample_weights_regression)
+{
+    int n = 3;
+    // RTrees for regression
+    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
+    // simple regression problem of x -> 2x
+    Mat data = (Mat_<float>(n,1) << 1, 2, 3);
+    Mat values = (Mat_<float>(n,1) << 2, 4, 6);
+    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10);
+
+    Ptr<TrainData> trainData = TrainData::create(data, ml::ROW_SAMPLE, values);
+    rt->train(trainData);
+    double error_without_weights = round(rt->getOOBError());
+    rt->clear();
+    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, values, Mat(), Mat(), weights);
+    rt->train(trainDataWithWeights);
+    double error_with_weights = round(rt->getOOBError());
+    // error with weights should be larger than error without weights
+    EXPECT_GE(error_with_weights, error_without_weights);
+}
+
+TEST(ML_RTrees, 11142_sample_weights_classification)
+{
+    int n = 12;
+    // RTrees for classification
+    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
+
+    Mat data(n, 4, CV_32F);
+    randu(data, 0, 10);
+    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
+    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);
+
+    rt->train(data, ml::ROW_SAMPLE, labels);
+    double error_without_weights = round(rt->getOOBError());
+    rt->clear();
+    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights);
+    rt->train(trainDataWithWeights);
+    double error_with_weights = round(rt->getOOBError());
+    std::cout << error_without_weights << std::endl;
+    std::cout << error_with_weights << std::endl;
+    // error with weights should be larger than error without weights
+    EXPECT_GE(error_with_weights, error_without_weights);
+}
+
 }} // namespace
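
Why the EXPECT_GE assertions are expected to hold: with the uniform weights of 10 used in both tests, each out-of-bag mistake (or squared residual, in the regression case) contributes roughly ten times its unweighted value. For example, three misclassified OOB samples give oobError = 3 without weights but 10 + 10 + 10 = 30 with them. The tests only assert a greater-or-equal relationship because the forest retrained with weights need not make exactly the same mistakes.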
