Problem with median_survival_time_ in custom fitter #1213
-
| Hi, I am very impressed how easy it was to define custom fitters in lifelines! I am particularly interested in extreme value distributions, specifically EV-I min, EV-I max, and the generalized extreme value distribution. Unfortunately, I encountered a couple of issues with the GEV custom fitter I defined. Convergence was spotty, but I found that judicious choice of […] Thank you,

import autograd.numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import *
from lifelines.fitters import ParametricUnivariateFitter

T1 = np.array([1.1667, 1.1667, 1.1667, 1.1667, 1.1667, 1.1667, 1.1667, 1.1833,
       1.1833, 1.1833, 1.1833, 1.1833, 1.2   , 1.2   , 1.2   , 1.2   ,
       1.2   , 1.2   , 1.2   , 1.2   , 1.2   , 1.2167, 1.2167, 1.2167,
       1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2167,
       1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2333, 1.2333, 1.2333,
       1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333,
       1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333,
       1.2333, 1.2333, 1.25  , 1.25  , 1.25  , 1.25  , 1.25  , 1.25  ,
       1.25  , 1.25  , 1.25  , 1.25  , 1.2667, 1.2667, 1.2667, 1.2667,
       1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667,
       1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667,
       1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667,
       1.2667, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
       1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
       1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
       1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
       1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   ,
       1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   ,
       1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   ,
       1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   ,
       1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   , 1.3   ,
       1.3   , 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
       1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
       1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
       1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
       1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3333, 1.3333,
       1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333,
       1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333,
       1.3333, 1.3333, 1.3333, 1.35  , 1.35  , 1.35  , 1.35  , 1.35  ,
       1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  ,
       1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  ,
       1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  ,
       1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  , 1.35  ,
       1.35  , 1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3667,
       1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3833, 1.3833,
       1.3833, 1.3833, 1.3833, 1.3833, 1.4   , 1.4167, 1.4167, 1.4333,
       1.4833, 1.4833])

T2 = np.array([1.6   , 1.6   , 1.6167, 1.6333, 1.6333, 1.6333, 1.6333, 1.6333,
       1.6333, 1.6333, 1.65  , 1.65  , 1.65  , 1.65  , 1.65  , 1.65  ,
       1.65  , 1.65  , 1.65  , 1.65  , 1.65  , 1.65  , 1.65  , 1.65  ,
       1.65  , 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667,
       1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667,
       1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833,
       1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833,
       1.7   , 1.7   , 1.7   , 1.7   , 1.7   , 1.7   , 1.7   , 1.7   ,
       1.7   , 1.7   , 1.7   , 1.7   , 1.7   , 1.7   , 1.7   , 1.7   ,
       1.7   , 1.7   , 1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167,
       1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167,
       1.7167, 1.7333, 1.7333, 1.7333, 1.7333, 1.7333, 1.75  , 1.75  ,
       1.75  , 1.75  , 1.75  , 1.75  , 1.75  , 1.75  , 1.75  , 1.75  ,
       1.75  , 1.75  , 1.75  , 1.75  , 1.75  , 1.75  , 1.7667, 1.7667,
       1.7667, 1.7667, 1.7667, 1.7667, 1.7667, 1.7667, 1.7667, 1.7667,
       1.7667, 1.7667, 1.7667, 1.7667, 1.7833, 1.7833, 1.7833, 1.7833,
       1.7833, 1.7833, 1.7833, 1.7833, 1.7833, 1.7833, 1.8   , 1.8   ,
       1.8   , 1.8   , 1.8   , 1.8   , 1.8   , 1.8   , 1.8   , 1.8   ,
       1.8   , 1.8   , 1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167,
       1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167,
       1.8167, 1.8167, 1.8333, 1.8333, 1.8333, 1.8333, 1.8333, 1.8333,
       1.8333, 1.85  , 1.85  , 1.85  , 1.85  , 1.85  , 1.85  , 1.8667,
       1.8667, 1.8667, 1.8833, 1.8833, 1.8833, 1.8833, 1.9   , 1.9167,
       1.9167, 1.9333, 1.95  , 1.9833, 2.0333, 2.0667, 2.1   ])class GEVDistFitter(ParametricUnivariateFitter):
    _fitted_parameter_names = ["J_", "mu_", "k_"]
    _bounds = [(0, None), (None, None), (None, None)]
    def _cumulative_hazard(self, params, times):
        J_, mu_, k_ = params
        z = J_*(times - mu_)
        return -np.log1p(-np.exp(-np.power(1+k_*z,-1/k_)))class EV1maxDistFitter(ParametricUnivariateFitter):
    _fitted_parameter_names = ["J_", "mu_"]
    #_bounds = [(0, None), (0, None), (0, T.min()-0.001)]
    def _cumulative_hazard(self, params, times):
        J_, mu_ = params
        z = J_*(times - mu_)
        return -np.log1p(-np.exp(-np.exp(-z)))class EV1minDistFitter(ParametricUnivariateFitter):
    _fitted_parameter_names = ["J_", "mu_"]
    #_bounds = [(0, None), (0, None), (0, T.min()-0.001)]
    def _cumulative_hazard(self, params, times):
        J_, mu_ = params
        z = J_*(times - mu_)
        return np.exp(z)kmf1    = KaplanMeierFitter().fit(T1,label='KM')
# Fit the three custom extreme-value fitters to the first data set. The GEV
# fit is given an explicit starting point; the other two use the default.
gev1    = GEVDistFitter().fit(T1,initial_point=np.array([10,1.3,-0.1]),label='GEV')
ev1max1 = EV1maxDistFitter().fit(T1,label='EV1_max')
ev1min1 = EV1minDistFitter().fit(T1,label='EV1_min')

# Same fits (plus the Kaplan-Meier reference) for the second data set.
kmf2    = KaplanMeierFitter().fit(T2,label='KM')
gev2    = GEVDistFitter().fit(T2,initial_point=np.array([10,1.3,-0.1]),label='GEV')
ev1max2 = EV1maxDistFitter().fit(T2,label='EV1_max')
ev1min2 = EV1minDistFitter().fit(T2,label='EV1_min')

# 3 x 2 grid: one row per parametric model, one column per data set.
fig, axes = plt.subplots(3, 2,sharey=True, figsize = (8,10))
for ii in range(3):
    # Draw the Kaplan-Meier estimate in every panel as a common reference.
    kmf1.plot_survival_function(ax=axes[ii,0])
    axes[ii,0].set_xlim(0,3)
    kmf2.plot_survival_function(ax=axes[ii,1])
    axes[ii,1].set_xlim(0,3)

# Overlay each parametric fit on its own row.
gev1.plot_survival_function(ax=axes[0,0])
ev1max1.plot_survival_function(ax=axes[1,0])
ev1min1.plot_survival_function(ax=axes[2,0])
gev2.plot_survival_function(ax=axes[0,1])
ev1max2.plot_survival_function(ax=axes[1,1])
ev1min2.plot_survival_function(ax=axes[2,1])

# Row 0: median reported by lifelines (median_survival_time_).
# Row 1: closed-form median obtained by solving S(t) = 0.5 for each model's
#        cumulative hazard, using the fitted parameters.
tau1 = np.array(
    [[gev1.median_survival_time_,
      ev1max1.median_survival_time_,
      ev1min1.median_survival_time_],
     [(np.power(np.log(2),-gev1.k_)-1)/(gev1.J_*gev1.k_)+gev1.mu_,
      ev1max1.mu_-1/ev1max1.J_*np.log(np.log(2)),
      ev1min1.mu_+1/ev1min1.J_*np.log(np.log(2))]])
tau2 = np.array(
    [[gev2.median_survival_time_,
      ev1max2.median_survival_time_,
      ev1min2.median_survival_time_],
     [(np.power(np.log(2),-gev2.k_)-1)/(gev2.J_*gev2.k_)+gev2.mu_,
      ev1max2.mu_-1/ev1max2.J_*np.log(np.log(2)),
      ev1min2.mu_+1/ev1min2.J_*np.log(np.log(2))]])

pd.DataFrame(tau1,columns=['GEV1','EV1max1','EV1min1'],index=['lifelines','analytical'])
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }
 
 pd.DataFrame(tau2,columns=['GEV2','EV1max2','EV1min2'],index=['lifelines','analytical'])
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }
 
 | 
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
| Hi @DerkJoester, thanks for the detailed issue. I'll have to dig in further as to what might be happening, but for your case: if you know the analytical formula, you can add it to the class, see example here. Note that  | 
Beta Was this translation helpful? Give feedback.

Hi @DerkJoester, thanks for the detailed issue. I'll have to dig in further as to what might be happening, but for your case: if you know the analytical formula, you can add it to the class, see example here. Note that `percentile` is called by `median_survival_time_`, and if the former is not available, it's numerically computed - which is where I think the problem you are seeing lies.