Skip to content

Commit b3e83a9

Browse files
author
Alexander März
committed
Update stabilize_derivative
1 parent 923e11e commit b3e83a9

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

lightgbmlss/distributions/mixture_distribution_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,13 @@ def stabilize_derivative(self, input_der: torch.Tensor, type: str = "MAD") -> to
596596
div = torch.where(div > torch.tensor(10000.0), torch.tensor(10000.0), div)
597597
stab_der = input_der / div
598598

599+
if type == "None":
600+
stab_der = torch.nan_to_num(input_der,
601+
nan=float(torch.nanmean(input_der)),
602+
posinf=float(torch.nanmean(input_der)),
603+
neginf=float(torch.nanmean(input_der))
604+
)
605+
599606
return stab_der
600607

601608
def dist_select(self,

0 commit comments

Comments
 (0)