@@ -230,23 +230,7 @@ template<bool bHessian, size_t cCompilerScores> class PartitionTwoDimensionalInt
230230 hess10 = w10;
231231 hess11 = w11;
232232 }
233- if (hess00 < hessianMin) {
234- goto next;
235- }
236- if (hess01 < hessianMin) {
237- goto next;
238- }
239- if (hess10 < hessianMin) {
240- goto next;
241- }
242- if (hess11 < hessianMin) {
243- goto next;
244- }
245233
246- const FloatCalc d00 = hess00;
247- const FloatCalc d01 = hess01;
248- const FloatCalc d10 = hess10;
249- const FloatCalc d11 = hess11;
250234 if (CalcInteractionFlags_Purify & flags) {
251235 // purified gain
252236
@@ -326,54 +310,33 @@ template<bool bHessian, size_t cCompilerScores> class PartitionTwoDimensionalInt
326310 // (1 + weight00 / weight01 + weight00 / weight10 + weight00 / weight11)
327311 // The other pure effects can be derived the same way.
328312
329- if (FloatCalc{0 } != d00 && FloatCalc{0 } != d01 && FloatCalc{0 } != d10 && FloatCalc{0 } != d11) {
313+ // if any of the weights are zero then the purified gain will be zero
314+ if (FloatCalc{0 } != w00 && FloatCalc{0 } != w01 && FloatCalc{0 } != w10 && FloatCalc{0 } != w11) {
330315
331316 // TODO: instead of checking the denominators for zero above, can we do it earlier?
332317 // If we're using hessians then we'd need it here, but we aren't using them yet
333318
334- // calculate what the full updates would be for non-purified:
335- // u = update (non-purified)
336- const FloatCalc negUpdate00 = grad00 / hess00;
337- const FloatCalc negUpdate01 = grad01 / hess01;
338- const FloatCalc negUpdate10 = grad10 / hess10;
339- const FloatCalc negUpdate11 = grad11 / hess11;
319+ // Calculate the unpurified updates. Purification is invariant to the sign,
320+ // so we can purify the negative updates and get the same result.
321+ const FloatCalc negUpdate00 = CalcNegUpdate ( grad00, hess00) ;
322+ const FloatCalc negUpdate01 = CalcNegUpdate ( grad01, hess01) ;
323+ const FloatCalc negUpdate10 = CalcNegUpdate ( grad10, hess10) ;
324+ const FloatCalc negUpdate11 = CalcNegUpdate ( grad11, hess11) ;
340325
341326 // common part of equations (positive for 00 & 11 equations, negative for 01 and 10)
342327 const FloatCalc common = negUpdate00 - negUpdate01 - negUpdate10 + negUpdate11;
343328
344- // p = purified NEGATIVE update.
345- const FloatCalc negPure00 = common / (FloatCalc{1 } + d00 / d01 + d00 / d10 + d00 / d11);
346- const FloatCalc negPure01 = common / (FloatCalc{-1 } - d01 / d00 - d01 / d10 - d01 / d11);
347- const FloatCalc negPure10 = common / (FloatCalc{-1 } - d10 / d00 - d10 / d01 - d10 / d11);
348- const FloatCalc negPure11 = common / (FloatCalc{1 } + d11 / d00 + d11 / d01 + d11 / d10);
329+ const FloatCalc negPure00 = common / (FloatCalc{1 } + w00 / w01 + w00 / w10 + w00 / w11);
330+ const FloatCalc negPure01 = common / (FloatCalc{-1 } - w01 / w00 - w01 / w10 - w01 / w11);
331+ const FloatCalc negPure10 = common / (FloatCalc{-1 } - w10 / w00 - w10 / w01 - w10 / w11);
332+ const FloatCalc negPure11 = common / (FloatCalc{1 } + w11 / w00 + w11 / w01 + w11 / w10);
349333
350334 // g = partial gain
351- const FloatCalc g00 = CalcPartialGainFromUpdate (hess00, negPure00);
352- const FloatCalc g01 = CalcPartialGainFromUpdate (hess01, negPure01);
353- const FloatCalc g10 = CalcPartialGainFromUpdate (hess10, negPure10);
354- const FloatCalc g11 = CalcPartialGainFromUpdate (hess11, negPure11);
355- #ifndef NDEBUG
356- // r = reconsituted numerator (after purification)
357- const FloatCalc r00 = negPure00 * d00;
358- const FloatCalc r01 = negPure01 * d01;
359- const FloatCalc r10 = negPure10 * d10;
360- const FloatCalc r11 = negPure11 * d11;
361-
362- // purification means summing any direction gives us zero
363- EBM_ASSERT (std::abs (r00 + r01) < 0.001 );
364- EBM_ASSERT (std::abs (r01 + r11) < 0.001 );
365- EBM_ASSERT (std::abs (r11 + r10) < 0.001 );
366- EBM_ASSERT (std::abs (r10 + r00) < 0.001 );
367-
368- // if all of these added together are zero, then the parent partial gain should also
369- // be zero, which means we can avoid calculating the parent partial gain.
370- EBM_ASSERT (std::abs (r00 + r01 + r10 + r11) < 0.001 );
371-
372- EBM_ASSERT (std::abs (g00 - CalcPartialGain (r00, d00)) < 0.001 );
373- EBM_ASSERT (std::abs (g01 - CalcPartialGain (r01, d01)) < 0.001 );
374- EBM_ASSERT (std::abs (g10 - CalcPartialGain (r10, d10)) < 0.001 );
375- EBM_ASSERT (std::abs (g11 - CalcPartialGain (r11, d11)) < 0.001 );
376- #endif // NDEBUG
335+ const FloatCalc g00 = CalcPartialGainFromUpdate (grad00, hess00, negPure00);
336+ const FloatCalc g01 = CalcPartialGainFromUpdate (grad01, hess01, negPure01);
337+ const FloatCalc g10 = CalcPartialGainFromUpdate (grad10, hess10, negPure10);
338+ const FloatCalc g11 = CalcPartialGainFromUpdate (grad11, hess11, negPure11);
339+
377340 gain += g00;
378341 gain += g01;
379342 gain += g10;
0 commit comments