99from functools import partial
1010from copy import deepcopy
1111from typing import List , Tuple
12- from scipy .optimize import minimize_scalar , minimize
12+ from scipy .optimize import minimize_scalar
1313import pandas as pd
1414import numpy as np
1515import jax
@@ -76,6 +76,17 @@ def sample(self, n, alt, p, r, k):
7676 res , c = np .unique (res , axis = 0 , return_counts = True )
7777 res = np .append (res , c .reshape (- 1 , 1 ), axis = 1 )
7878 return res
79+
80+ def minimize_scalar (self , f , xatol = 1e-8 , steps = 10 ):
81+ if steps :
82+ ps = list (np .linspace (0.0001 , 0.9999 , steps ))
83+ i = np .argmin (list (map (f , ps ))) + 1
84+ ps = [0.0 ] + ps + [1.0 ]
85+ b = ps [i - 1 ], ps [i + 1 ]
86+ else :
87+ b = (0.0 , 1.0 )
88+ return minimize_scalar (f , bounds = b , method = 'bounded' , options = {'xatol' : xatol })
89+
7990
8091 def fit (self , data : np .ndarray , params : dict , compute_var = True , sandwich = True , n_bootstrap = 100 ):
8192 name = self .model_name
@@ -103,7 +114,7 @@ def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True, n
103114 f = partial (self .negloglik , r = r , k = k , data = data , w = w , mask = mask )
104115 grad_w = partial (self .grad , r = r , k = k , data = data , w = w , mask = mask )
105116 fim = partial (self .fim , r = r , k = k , data = data , w = w , mask = mask )
106- res = minimize_scalar (f , bounds = ( 0.0 , 1.0 ), method = 'bounded' )
117+ res = self . minimize_scalar (f )
107118 x = res .x
108119 bs = list ()
109120 for i in range (n_bootstrap ):
@@ -112,7 +123,7 @@ def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True, n
112123 data , r_ , k_ , w_ = self .update_mask (data , r_ , k_ , w_ )
113124 mask = self .mask
114125 f = partial (self .negloglik , r = r_ , k = k_ , data = data , w = w_ , mask = mask )
115- res_ = minimize_scalar (f , bounds = ( 0.0 , 1.0 ), method = 'bounded' )
126+ res_ = self . minimize_scalar (f )
116127 if res_ .success :
117128 bs .append (res_ .x )
118129
@@ -274,11 +285,17 @@ def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
274285 if allele == 'alt' :
275286 counts_a = counts_a [:, (1 , 0 , 2 )]; counts_b = counts_b [:, (1 , 0 , 2 )]; counts = counts [:, (1 , 0 , 2 )]
276287 try :
288+ global globalk
289+ if snv == ('chr12' , 14882146 ):
290+ globalk = True
277291 a_r , a_var = model .fit (counts_a , params [allele ], sandwich = robust_se , n_bootstrap = n_bootstrap )
292+ a_var = np .clip (a_var , 0.0 , np .inf )
278293 a_p = a_r .x
279294 b_r , b_var = model .fit (counts_b , params [allele ], sandwich = robust_se , n_bootstrap = n_bootstrap )
295+ b_var = np .clip (b_var , 0.0 , np .inf )
280296 b_p = b_r .x
281- correct = (a_var >= 0 ) & (b_var >= 0 ) & np .isfinite (a_var ) & np .isfinite (b_var )
297+ correct = (a_var >= 0 ) & (b_var >= 0 ) & (a_var + b_var > 0 ) & np .isfinite (a_var ) & np .isfinite (b_var )
298+ globalk = False
282299 if logit_transform :
283300 a_p , a_var = transform_p (a_p , a_var )
284301 b_p , b_var = transform_p (b_p , b_var )
0 commit comments