Skip to content

Commit ca5f813

Browse files
committed
Only pass what is needed to the loss derivative functions
1 parent 8ccd75e commit ca5f813

File tree

3 files changed

+46
-28
lines changed

3 files changed

+46
-28
lines changed

R-package/inst/include/ensemble.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,4 @@ class ENSEMBLE
5858
};
5959

6060

61-
62-
#endif
61+
#endif

R-package/inst/include/loss_functions.hpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
#include "external_rcpp.hpp"
77

88
// ----------- LOSS --------------
9-
double loss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, Tvec<double> &w, ENSEMBLE* ens_ptr){
9+
double loss(
10+
Tvec<double> &y,
11+
Tvec<double> &pred,
12+
std::string loss_type,
13+
Tvec<double> &w,
14+
double extra_param=0.0
15+
){
16+
// Evaluates the loss function at pred
1017
int n = y.size();
1118
double res = 0;
1219

@@ -38,7 +45,7 @@ double loss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, Tvec<dou
3845
res += y[i]*w[i]*exp(-pred[i]) + pred[i];
3946
}
4047
}else if(loss_type=="negbinom"){
41-
double dispersion = ens_ptr -> extra_param;
48+
double dispersion = extra_param;
4249
for(int i=0; i<n; i++){
4350
// log-link, mu=exp(pred[i])
4451
res += -y[i]*pred[i] + (y[i]*dispersion)*log(1.0+exp(pred[i])/dispersion); // Keep only relevant part
@@ -75,8 +82,13 @@ double loss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, Tvec<dou
7582
}
7683

7784

78-
Tvec<double> dloss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, ENSEMBLE* ens_ptr){
79-
85+
Tvec<double> dloss(
86+
Tvec<double> &y,
87+
Tvec<double> &pred,
88+
std::string loss_type,
89+
double extra_param=0.0
90+
){
91+
// Returns the first order derivative of the loss function at pred
8092
int n = y.size();
8193
Tvec<double> g(n);
8294

@@ -107,7 +119,7 @@ Tvec<double> dloss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, E
107119
}
108120
}else if(loss_type == "negbinom"){
109121
// NEGATIVE BINOMIAL, LOG LINK
110-
double dispersion = ens_ptr->extra_param;
122+
double dispersion = extra_param;
111123
for(int i=0; i<n; i++){
112124
g[i] = -y[i] + (y[i]+dispersion)*exp(pred[i]) / (dispersion + exp(pred[i]));
113125
}
@@ -143,7 +155,13 @@ Tvec<double> dloss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, E
143155
}
144156

145157

146-
Tvec<double> ddloss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type, ENSEMBLE* ens_ptr){
158+
Tvec<double> ddloss(
159+
Tvec<double> &y,
160+
Tvec<double> &pred,
161+
std::string loss_type,
162+
double extra_param=0.0
163+
){
164+
// Returns the second order derivative of the loss function at pred
147165
int n = y.size();
148166
Tvec<double> h(n);
149167

@@ -174,7 +192,7 @@ Tvec<double> ddloss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type,
174192
}
175193
}else if( loss_type == "negbinom" ){
176194
// NEGATIVE BINOMIAL, LOG LINK
177-
double dispersion = ens_ptr->extra_param;
195+
double dispersion = extra_param;
178196
for(int i=0; i<n; i++){
179197
h[i] = (y[i]+dispersion)*dispersion*exp(pred[i]) /
180198
( (dispersion + exp(pred[i]))*(dispersion + exp(pred[i])) );
@@ -217,4 +235,5 @@ Tvec<double> ddloss(Tvec<double> &y, Tvec<double> &pred, std::string loss_type,
217235
return h;
218236
}
219237

220-
#endif
238+
239+
#endif

R-package/src/agtboost.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,11 @@ void ENSEMBLE::train(
173173
}
174174
pred.setConstant(this->initialPred);
175175
pred += offset;
176-
this->initial_score = loss(y, pred, loss_function, w, this);
176+
this->initial_score = loss(y, pred, loss_function, w, extra_param);
177177

178178
// First tree
179-
g = dloss(y, pred, loss_function, this) * w;
180-
h = ddloss(y, pred, loss_function, this) * w;
179+
g = dloss(y, pred, loss_function, extra_param) * w;
180+
h = ddloss(y, pred, loss_function, extra_param) * w;
181181
this->first_tree = new GBTREE;
182182
this->first_tree->train(g, h, X, cir_sim, greedy_complexities, learning_rate);
183183
GBTREE* current_tree = this->first_tree;
@@ -187,7 +187,7 @@ void ENSEMBLE::train(
187187
verbose,
188188
1,
189189
current_tree->getNumLeaves(),
190-
loss(y, pred, loss_function, w, this),
190+
loss(y, pred, loss_function, w, extra_param),
191191
this->estimate_generalization_loss(1)
192192
);
193193

@@ -197,8 +197,8 @@ void ENSEMBLE::train(
197197
if (i % 1 == 0)
198198
Rcpp::checkUserInterrupt();
199199
// Calculate gradients
200-
g = dloss(y, pred, loss_function, this) * w;
201-
h = ddloss(y, pred, loss_function, this) * w;
200+
g = dloss(y, pred, loss_function, extra_param) * w;
201+
h = ddloss(y, pred, loss_function, extra_param) * w;
202202
// Check for perfect fit
203203
if(((g.array())/h.array()).matrix().maxCoeff() < 1e-12){
204204
// Every perfect step is below tresh
@@ -212,7 +212,7 @@ void ENSEMBLE::train(
212212
// Calculate expected generalization loss for tree
213213
expected_loss = tree_expected_test_reduction(new_tree, learning_rate);
214214
// Update ensemble training loss and ensemble optimism for iteration k-1
215-
ensemble_training_loss = loss(y, pred, loss_function, w, this);
215+
ensemble_training_loss = loss(y, pred, loss_function, w, extra_param);
216216
ensemble_approx_training_loss = this->estimate_training_loss(i-1) +
217217
new_tree->getTreeScore() * (-2)*learning_rate*(learning_rate/2 - 1);
218218
ensemble_optimism = this->estimate_optimism(i-1) +
@@ -253,18 +253,18 @@ void ENSEMBLE::train_from_preds(Tvec<double> &pred, Tvec<double> &y, Tmat<double
253253
Tvec<double> g(n), h(n);
254254

255255
// Initial prediction
256-
g = dloss(y, pred, loss_function, this)*w;
257-
h = ddloss(y, pred, loss_function, this)*w;
256+
g = dloss(y, pred, loss_function, extra_param)*w;
257+
h = ddloss(y, pred, loss_function, extra_param)*w;
258258
this->initialPred = - g.sum() / h.sum();
259259
pred = pred.array() + this->initialPred;
260-
this->initial_score = loss(y, pred, loss_function, w, this); //(y - pred).squaredNorm() / n;
260+
this->initial_score = loss(y, pred, loss_function, w, extra_param); //(y - pred).squaredNorm() / n;
261261

262262
// Prepare cir matrix
263263
Tmat<double> cir_sim = cir_sim_mat(100, 100);
264264

265265
// First tree
266-
g = dloss(y, pred, loss_function, this)*w;
267-
h = ddloss(y, pred, loss_function, this)*w;
266+
g = dloss(y, pred, loss_function, extra_param)*w;
267+
h = ddloss(y, pred, loss_function, extra_param)*w;
268268
this->first_tree = new GBTREE;
269269
this->first_tree->train(g, h, X, cir_sim, greedy_complexities, learning_rate_set);
270270
GBTREE* current_tree = this->first_tree;
@@ -277,7 +277,7 @@ void ENSEMBLE::train_from_preds(Tvec<double> &pred, Tvec<double> &y, Tmat<double
277277
std::setprecision(4) <<
278278
"it: " << 1 <<
279279
" | n-leaves: " << current_tree->getNumLeaves() <<
280-
" | tr loss: " << loss(y, pred, loss_function, w, this) <<
280+
" | tr loss: " << loss(y, pred, loss_function, w, extra_param) <<
281281
" | gen loss: " << this->estimate_generalization_loss(1) <<
282282
std::endl;
283283
}
@@ -292,8 +292,8 @@ void ENSEMBLE::train_from_preds(Tvec<double> &pred, Tvec<double> &y, Tmat<double
292292

293293
// TRAINING
294294
GBTREE* new_tree = new GBTREE();
295-
g = dloss(y, pred, loss_function, this)*w;
296-
h = ddloss(y, pred, loss_function, this)*w;
295+
g = dloss(y, pred, loss_function, extra_param)*w;
296+
h = ddloss(y, pred, loss_function, extra_param)*w;
297297
new_tree->train(g, h, X, cir_sim, greedy_complexities, learning_rate_set);
298298

299299
// EXPECTED LOSS
@@ -310,7 +310,7 @@ void ENSEMBLE::train_from_preds(Tvec<double> &pred, Tvec<double> &y, Tmat<double
310310
std::setprecision(4) <<
311311
"it: " << i <<
312312
" | n-leaves: " << current_tree->getNumLeaves() <<
313-
" | tr loss: " << loss(y, pred, loss_function, w, this) <<
313+
" | tr loss: " << loss(y, pred, loss_function, w, extra_param) <<
314314
" | gen loss: " << this->estimate_generalization_loss(i-1) + expected_loss <<
315315
std::endl;
316316

@@ -504,7 +504,7 @@ Tvec<double> ENSEMBLE::convergence(Tvec<double> &y, Tmat<double> &X){
504504
w.setOnes();
505505

506506
// After each update (tree), compute loss
507-
loss_val[0] = loss(y, pred, this->loss_function, w, this);
507+
loss_val[0] = loss(y, pred, this->loss_function, w, extra_param);
508508

509509
GBTREE* current = this->first_tree;
510510
for(int k=1; k<(K+1); k++)
@@ -513,7 +513,7 @@ Tvec<double> ENSEMBLE::convergence(Tvec<double> &y, Tmat<double> &X){
513513
pred = pred + (this->learning_rate) * (current->predict_data(X));
514514

515515
// Compute loss
516-
loss_val[k] = loss(y, pred, this->loss_function, w, this);
516+
loss_val[k] = loss(y, pred, this->loss_function, w, extra_param);
517517

518518
// Update to next tree
519519
current = current->next_tree;

0 commit comments

Comments (0)