@@ -45,7 +45,39 @@ def negloglik(self, p: float, r: jax.numpy.ndarray, k: jax.numpy.ndarray, data:
4545 mask : jax .numpy .ndarray ):
4646 return - self .fun (p , r , k , data , w , mask ).sum ()
4747
48- def fit (self , data : np .ndarray , params : dict , compute_var = True , sandwich = True ):
48+ def update_mask (self , data , r , k , w ):
49+ mask = self .mask
50+ n = max (len (r ), len (data ))
51+ m = len (mask )
52+ if n > m :
53+ mask = np .zeros (n , dtype = bool )
54+ self .mask = mask
55+ m = n
56+ mask [:n ] = False
57+ mask [n :] = True
58+ c = self .allowed_const
59+ v = max (0 , m - len (data )); data = np .pad (data , (0 , v ), constant_values = c );
60+ v = max (0 , m - len (w )); w = np .pad (w , (0 , v ), constant_values = c );
61+ v = max (0 , m - len (r )); r = np .pad (r , (0 , v ), constant_values = c )
62+ v = max (0 , m - len (k )); k = np .pad (k , (0 , v ), constant_values = c )
63+ return data , r , k , w
64+
65+ def sample (self , n , alt , p , r , k ):
66+ n = n .astype (int )
67+ alt = np .repeat (alt , n )
68+ r = np .repeat (r , n )
69+ k = np .repeat (k , n )
70+ left = self .left
71+ if self .dist == 'BetaNB' :
72+ res = dists .LeftTruncatedBetaNB .sample (p , k , r , left , size = (len (r ), ))
73+ else :
74+ res = dists .LeftTruncatedNB .sample (r , p , left , size = (len (r ), ))
75+ res = np .stack ([res , alt ]).T
76+ res , c = np .unique (res , axis = 0 , return_counts = True )
77+ res = np .append (res , c .reshape (- 1 , 1 ), axis = 1 )
78+ return res
79+
80+ def fit (self , data : np .ndarray , params : dict , compute_var = True , sandwich = True , n_bootstrap = 100 ):
4981 name = self .model_name
5082 data , fixed , w = data .T
5183 n = len (data )
@@ -66,33 +98,38 @@ def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True):
6698 k = ps ['mu_k' ] + ps ['b_k' ] * np .log (fixed )
6799 else :
68100 k = np .zeros (len (r ))
101+ data , r , k , w = self .update_mask (data , r , k , w )
69102 mask = self .mask
70- m = len (mask )
71- if n > m :
72- mask = np .zeros (data .shape [0 ], dtype = bool )
73- self .mask = mask
74- mask [:n ] = False
75- mask [n :] = True
76- v = max (0 , m - n )
77- c = self .allowed_const
78- data = np .pad (data , (0 , v ), constant_values = c ); w = np .pad (w , (0 , v ), constant_values = c ); r = np .pad (r , (0 , v ), constant_values = c )
79- k = np .pad (k , (0 , v ), constant_values = c )
80103 f = partial (self .negloglik , r = r , k = k , data = data , w = w , mask = mask )
81104 grad_w = partial (self .grad , r = r , k = k , data = data , w = w , mask = mask )
82105 fim = partial (self .fim , r = r , k = k , data = data , w = w , mask = mask )
83- r = minimize_scalar (f , bounds = (0.0 , 1.0 ), method = 'bounded' )
84- r .x = float (r .x )
85- lf = float (f (r .x ))
106+ res = minimize_scalar (f , bounds = (0.0 , 1.0 ), method = 'bounded' )
107+ x = res .x
108+ bs = list ()
109+ for i in range (n_bootstrap ):
110+ r_ , k_ , w_ = r [:n ], k [:n ], w [:n ]
111+ data , _ , w_ = self .sample (w_ , fixed [:n ], x , r_ , k_ ).T
112+ data , r_ , k_ , w_ = self .update_mask (data , r_ , k_ , w_ )
113+ mask = self .mask
114+ f = partial (self .negloglik , r = r_ , k = k_ , data = data , w = w_ , mask = mask )
115+ res_ = minimize_scalar (f , bounds = (0.0 , 1.0 ), method = 'bounded' )
116+ if res_ .success :
117+ bs .append (res_ .x )
118+
119+ res .x = float (res .x )
120+ correction = res .x - np .mean (bs ) if n_bootstrap else 0.0
121+ res .x = res .x + correction
122+ lf = float (f (res .x ))
86123 if compute_var :
87124 if sandwich :
88- g = grad_w (r .x )
125+ g = grad_w (res .x )
89126 v = (g ** 2 / w ).sum ()
90- fim = fim (r .x )
127+ fim = fim (res .x )
91128 s = - 1 if fim < - 1e-9 else 1
92- return r , s * float (1 / fim ** 2 * v )
129+ return res , s * float (1 / fim ** 2 * v )
93130 else :
94- return r , float (1 / fim (r .x ))
95- return r , lf
131+ return res , float (1 / fim (res .x ))
132+ return res , lf
96133
97134
98135
@@ -221,7 +258,7 @@ def transform_p(p, var):
221258def wald_test (counts : Tuple [tuple , np .ndarray , np .ndarray , np .ndarray ],
222259 inst_params : dict , params : dict , skip_failures = False , max_sz = None , bad = 1.0 ,
223260 contrasts : Tuple [float , float , float ] = (1 , - 1 , 0 ), logit_transform = False ,
224- param_mode = 'window' , robust_se = True ):
261+ param_mode = 'window' , robust_se = True , n_bootstrap = 0 ):
225262 if not hasattr (wald_test , '_cache' ):
226263 wald_test ._cache = dict ()
227264 snv , counts_a , counts_b , counts = counts
@@ -237,9 +274,9 @@ def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
237274 if allele == 'alt' :
238275 counts_a = counts_a [:, (1 , 0 , 2 )]; counts_b = counts_b [:, (1 , 0 , 2 )]; counts = counts [:, (1 , 0 , 2 )]
239276 try :
240- a_r , a_var = model .fit (counts_a , params [allele ], sandwich = robust_se )
277+ a_r , a_var = model .fit (counts_a , params [allele ], sandwich = robust_se , n_bootstrap = n_bootstrap )
241278 a_p = a_r .x
242- b_r , b_var = model .fit (counts_b , params [allele ], sandwich = robust_se )
279+ b_r , b_var = model .fit (counts_b , params [allele ], sandwich = robust_se , n_bootstrap = n_bootstrap )
243280 b_p = b_r .x
244281 correct = (a_var >= 0 ) & (b_var >= 0 ) & np .isfinite (a_var ) & np .isfinite (b_var )
245282 if logit_transform :
@@ -267,7 +304,7 @@ def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
267304def differential_test (name : str , group_a : List [str ], group_b : List [str ], mode = 'wald' , min_samples = 2 , min_cover = 0 ,
268305 max_cover = np .inf , skip_failures = True , group_test = True , alpha = 0.05 , max_cover_group_test = None ,
269306 filter_chr = None , filter_id = None , contrasts = (1 , - 1 , 0 ), subname = None , param_mode = 'window' ,
270- logit_transform = False , robust_se = True , n_jobs = - 1 ):
307+ logit_transform = False , robust_se = True , n_bootstrap = 0 , n_jobs = - 1 ):
271308 if max_cover is None :
272309 max_cover = np .inf
273310 if min_cover is None :
@@ -319,7 +356,7 @@ def differential_test(name: str, group_a: List[str], group_b: List[str], mode='w
319356 'alt_pval' , 'alt_p_a' , 'alt_p_b' , 'alt_std_a' , 'alt_std_b' ]
320357 test_fun = partial (wald_test , inst_params = inst_params , params = params , skip_failures = False , bad = bad ,
321358 contrasts = contrasts , param_mode = param_mode , logit_transform = logit_transform ,
322- robust_se = robust_se )
359+ robust_se = robust_se , n_bootstrap = n_bootstrap )
323360 if group_test :
324361 counts_a = _counts_a [bad ]; counts_b = _counts_b [bad ]; counts = _counts [bad ]
325362 counts_a = counts_a [counts_a [:, 0 ] + counts_a [:, 1 ] < max_cover_group_test ]
0 commit comments