Skip to content

Commit c9a0ce8

Browse files
committed
difftest wald now treats p values close to 1.0 correctly
1 parent 4ca8d94 commit c9a0ce8

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

mixalime/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '2.2.11'
1+
__version__ = '2.2.13'
22
import importlib
33

44
__min_reqs__ = [

mixalime/diff.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from functools import partial
1010
from copy import deepcopy
1111
from typing import List, Tuple
12-
from scipy.optimize import minimize_scalar, minimize
12+
from scipy.optimize import minimize_scalar
1313
import pandas as pd
1414
import numpy as np
1515
import jax
@@ -76,6 +76,17 @@ def sample(self, n, alt, p, r, k):
7676
res, c = np.unique(res, axis=0, return_counts=True)
7777
res = np.append(res, c.reshape(-1, 1), axis=1)
7878
return res
79+
80+
def minimize_scalar(self, f, xatol=1e-8, steps=10):
81+
if steps:
82+
ps = list(np.linspace(0.0001, 0.9999, steps))
83+
i = np.argmin(list(map(f, ps))) + 1
84+
ps = [0.0] + ps + [1.0]
85+
b = ps[i - 1], ps[i + 1]
86+
else:
87+
b = (0.0, 1.0)
88+
return minimize_scalar(f, bounds=b, method='bounded', options={'xatol': xatol})
89+
7990

8091
def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True, n_bootstrap=100):
8192
name = self.model_name
@@ -103,7 +114,7 @@ def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True, n
103114
f = partial(self.negloglik, r=r, k=k, data=data, w=w, mask=mask)
104115
grad_w = partial(self.grad, r=r, k=k, data=data, w=w, mask=mask)
105116
fim = partial(self.fim, r=r, k=k, data=data, w=w, mask=mask)
106-
res = minimize_scalar(f, bounds=(0.0, 1.0), method='bounded')
117+
res = self.minimize_scalar(f)
107118
x = res.x
108119
bs = list()
109120
for i in range(n_bootstrap):
@@ -112,7 +123,7 @@ def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True, n
112123
data, r_, k_, w_ = self.update_mask(data, r_, k_, w_)
113124
mask = self.mask
114125
f = partial(self.negloglik, r=r_, k=k_, data=data, w=w_, mask=mask)
115-
res_ = minimize_scalar(f, bounds=(0.0, 1.0), method='bounded')
126+
res_ = self.minimize_scalar(f)
116127
if res_.success:
117128
bs.append(res_.x)
118129

@@ -274,11 +285,17 @@ def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
274285
if allele == 'alt':
275286
counts_a = counts_a[:, (1, 0, 2)]; counts_b = counts_b[:, (1, 0, 2)]; counts = counts[:, (1, 0, 2)]
276287
try:
288+
global globalk
289+
if snv == ('chr12', 14882146):
290+
globalk = True
277291
a_r, a_var = model.fit(counts_a, params[allele], sandwich=robust_se, n_bootstrap=n_bootstrap)
292+
a_var = np.clip(a_var, 0.0, np.inf)
278293
a_p = a_r.x
279294
b_r, b_var = model.fit(counts_b, params[allele], sandwich=robust_se, n_bootstrap=n_bootstrap)
295+
b_var = np.clip(b_var, 0.0, np.inf)
280296
b_p = b_r.x
281-
correct = (a_var >= 0) & (b_var >= 0) & np.isfinite(a_var) & np.isfinite(b_var)
297+
correct = (a_var >= 0) & (b_var >= 0) & (a_var + b_var > 0) & np.isfinite(a_var) & np.isfinite(b_var)
298+
globalk = False
282299
if logit_transform:
283300
a_p, a_var = transform_p(a_p, a_var)
284301
b_p, b_var = transform_p(b_p, b_var)

0 commit comments

Comments
 (0)