Skip to content

Commit a1348af

Browse files
committed
fix incorrect CalcPartialGainFromUpdate which was used in calculating the purified interaction score
1 parent c5fb602 commit a1348af

File tree

2 files changed

+27
-60
lines changed

2 files changed

+27
-60
lines changed

shared/libebm/PartitionTwoDimensionalInteraction.cpp

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -230,23 +230,7 @@ template<bool bHessian, size_t cCompilerScores> class PartitionTwoDimensionalInt
230230
hess10 = w10;
231231
hess11 = w11;
232232
}
233-
if(hess00 < hessianMin) {
234-
goto next;
235-
}
236-
if(hess01 < hessianMin) {
237-
goto next;
238-
}
239-
if(hess10 < hessianMin) {
240-
goto next;
241-
}
242-
if(hess11 < hessianMin) {
243-
goto next;
244-
}
245233

246-
const FloatCalc d00 = hess00;
247-
const FloatCalc d01 = hess01;
248-
const FloatCalc d10 = hess10;
249-
const FloatCalc d11 = hess11;
250234
if(CalcInteractionFlags_Purify & flags) {
251235
// purified gain
252236

@@ -326,54 +310,33 @@ template<bool bHessian, size_t cCompilerScores> class PartitionTwoDimensionalInt
326310
// (1 + weight00 / weight01 + weight00 / weight10 + weight00 / weight11)
327311
// The other pure effects can be derived the same way.
328312

329-
if(FloatCalc{0} != d00 && FloatCalc{0} != d01 && FloatCalc{0} != d10 && FloatCalc{0} != d11) {
313+
// if any of the weights are zero then the purified gain will be zero
314+
if(FloatCalc{0} != w00 && FloatCalc{0} != w01 && FloatCalc{0} != w10 && FloatCalc{0} != w11) {
330315

331316
// TODO: instead of checking the denominators for zero above, can we do it earlier?
332317
// If we're using hessians then we'd need it here, but we aren't using them yet
333318

334-
// calculate what the full updates would be for non-purified:
335-
// u = update (non-purified)
336-
const FloatCalc negUpdate00 = grad00 / hess00;
337-
const FloatCalc negUpdate01 = grad01 / hess01;
338-
const FloatCalc negUpdate10 = grad10 / hess10;
339-
const FloatCalc negUpdate11 = grad11 / hess11;
319+
// Calculate the unpurified updates. Purification is invariant to the sign,
320+
// so we can purify the negative updates and get the same result.
321+
const FloatCalc negUpdate00 = CalcNegUpdate(grad00, hess00);
322+
const FloatCalc negUpdate01 = CalcNegUpdate(grad01, hess01);
323+
const FloatCalc negUpdate10 = CalcNegUpdate(grad10, hess10);
324+
const FloatCalc negUpdate11 = CalcNegUpdate(grad11, hess11);
340325

341326
// common part of equations (positive for 00 & 11 equations, negative for 01 and 10)
342327
const FloatCalc common = negUpdate00 - negUpdate01 - negUpdate10 + negUpdate11;
343328

344-
// p = purified NEGATIVE update.
345-
const FloatCalc negPure00 = common / (FloatCalc{1} + d00 / d01 + d00 / d10 + d00 / d11);
346-
const FloatCalc negPure01 = common / (FloatCalc{-1} - d01 / d00 - d01 / d10 - d01 / d11);
347-
const FloatCalc negPure10 = common / (FloatCalc{-1} - d10 / d00 - d10 / d01 - d10 / d11);
348-
const FloatCalc negPure11 = common / (FloatCalc{1} + d11 / d00 + d11 / d01 + d11 / d10);
329+
const FloatCalc negPure00 = common / (FloatCalc{1} + w00 / w01 + w00 / w10 + w00 / w11);
330+
const FloatCalc negPure01 = common / (FloatCalc{-1} - w01 / w00 - w01 / w10 - w01 / w11);
331+
const FloatCalc negPure10 = common / (FloatCalc{-1} - w10 / w00 - w10 / w01 - w10 / w11);
332+
const FloatCalc negPure11 = common / (FloatCalc{1} + w11 / w00 + w11 / w01 + w11 / w10);
349333

350334
// g = partial gain
351-
const FloatCalc g00 = CalcPartialGainFromUpdate(hess00, negPure00);
352-
const FloatCalc g01 = CalcPartialGainFromUpdate(hess01, negPure01);
353-
const FloatCalc g10 = CalcPartialGainFromUpdate(hess10, negPure10);
354-
const FloatCalc g11 = CalcPartialGainFromUpdate(hess11, negPure11);
355-
#ifndef NDEBUG
356-
// r = reconstituted numerator (after purification)
357-
const FloatCalc r00 = negPure00 * d00;
358-
const FloatCalc r01 = negPure01 * d01;
359-
const FloatCalc r10 = negPure10 * d10;
360-
const FloatCalc r11 = negPure11 * d11;
361-
362-
// purification means summing any direction gives us zero
363-
EBM_ASSERT(std::abs(r00 + r01) < 0.001);
364-
EBM_ASSERT(std::abs(r01 + r11) < 0.001);
365-
EBM_ASSERT(std::abs(r11 + r10) < 0.001);
366-
EBM_ASSERT(std::abs(r10 + r00) < 0.001);
367-
368-
// if all of these added together are zero, then the parent partial gain should also
369-
// be zero, which means we can avoid calculating the parent partial gain.
370-
EBM_ASSERT(std::abs(r00 + r01 + r10 + r11) < 0.001);
371-
372-
EBM_ASSERT(std::abs(g00 - CalcPartialGain(r00, d00)) < 0.001);
373-
EBM_ASSERT(std::abs(g01 - CalcPartialGain(r01, d01)) < 0.001);
374-
EBM_ASSERT(std::abs(g10 - CalcPartialGain(r10, d10)) < 0.001);
375-
EBM_ASSERT(std::abs(g11 - CalcPartialGain(r11, d11)) < 0.001);
376-
#endif // NDEBUG
335+
const FloatCalc g00 = CalcPartialGainFromUpdate(grad00, hess00, negPure00);
336+
const FloatCalc g01 = CalcPartialGainFromUpdate(grad01, hess01, negPure01);
337+
const FloatCalc g10 = CalcPartialGainFromUpdate(grad10, hess10, negPure10);
338+
const FloatCalc g11 = CalcPartialGainFromUpdate(grad11, hess11, negPure11);
339+
377340
gain += g00;
378341
gain += g01;
379342
gain += g10;

shared/libebm/ebm_stats.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "logging.h" // EBM_ASSERT
99
#include "unzoned.h" // INLINE_ALWAYS, LIKELY, UNLIKELY
1010

11-
#include "ebm_internal.hpp"
11+
#include "ebm_internal.hpp" // FloatCalc
1212

1313
namespace DEFINED_ZONE_NAME {
1414
#ifndef DEFINED_ZONE_NAME
@@ -77,15 +77,15 @@ INLINE_ALWAYS static FloatCalc CalcNegUpdate(const FloatCalc sumGradient, const
7777
return UNLIKELY(sumHessian < k_hessianMin) ? FloatCalc{0} : sumGradient / sumHessian;
7878
}
7979

80-
INLINE_ALWAYS static FloatCalc CalcPartialGainFromUpdate(const FloatCalc sumHessian, const FloatCalc negUpdate) {
80+
INLINE_ALWAYS static FloatCalc CalcPartialGainFromUpdate(
81+
const FloatCalc sumGradient, const FloatCalc sumHessian, const FloatCalc negUpdate) {
8182
// a loss function with negative hessians would be unstable
8283
EBM_ASSERT(std::isnan(sumHessian) || FloatCalc{0} <= sumHessian);
8384

84-
EBM_ASSERT(FloatCalc{0} < k_hessianMin);
85-
const FloatCalc partialGain =
86-
UNLIKELY(sumHessian < k_hessianMin) ? FloatCalc{0} : negUpdate * negUpdate * sumHessian;
85+
// do not consider k_hessianMin here since the update should be zero if sumHessian is below k_hessianMin
86+
87+
const FloatCalc partialGain = negUpdate * (sumGradient * FloatCalc{2} - negUpdate * sumHessian);
8788

88-
EBM_ASSERT(std::isnan(negUpdate) || std::isnan(sumHessian) || FloatCalc{0} <= partialGain);
8989
return partialGain;
9090
}
9191

@@ -107,6 +107,10 @@ INLINE_ALWAYS static FloatCalc CalcPartialGain(const FloatCalc sumGradient, cons
107107
// in another bin, and then the two bins are added together. That would lead to NaN even without NaN samples.
108108

109109
EBM_ASSERT(std::isnan(sumGradient) || std::isnan(sumHessian) || FloatCalc{0} <= partialGain);
110+
111+
EBM_ASSERT(IsApproxEqual(
112+
partialGain, CalcPartialGainFromUpdate(sumGradient, sumHessian, CalcNegUpdate(sumGradient, sumHessian))));
113+
110114
return partialGain;
111115
}
112116

0 commit comments

Comments
 (0)