Skip to content

Commit 04476bd

Browse files
committed
Add parameteric bootstrap bias correction for Wald test
1 parent fd71975 commit 04476bd

File tree

3 files changed

+66
-27
lines changed

3 files changed

+66
-27
lines changed

mixalime/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '2.2.9'
1+
__version__ = '2.2.10'
22
import importlib
33

44
__min_reqs__ = [

mixalime/diff.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,39 @@ def negloglik(self, p: float, r: jax.numpy.ndarray, k: jax.numpy.ndarray, data:
4545
mask: jax.numpy.ndarray):
4646
return -self.fun(p, r, k, data, w, mask).sum()
4747

48-
def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True):
48+
def update_mask(self, data, r, k, w):
49+
mask = self.mask
50+
n = max(len(r), len(data))
51+
m = len(mask)
52+
if n > m:
53+
mask = np.zeros(n, dtype=bool)
54+
self.mask = mask
55+
m = n
56+
mask[:n] = False
57+
mask[n:] = True
58+
c = self.allowed_const
59+
v = max(0, m - len(data)); data = np.pad(data, (0, v), constant_values=c);
60+
v = max(0, m - len(w)); w = np.pad(w, (0, v), constant_values=c);
61+
v = max(0, m - len(r)); r = np.pad(r, (0, v), constant_values=c)
62+
v = max(0, m - len(k)); k = np.pad(k, (0, v), constant_values=c)
63+
return data, r, k, w
64+
65+
def sample(self, n, alt, p, r, k):
66+
n = n.astype(int)
67+
alt = np.repeat(alt, n)
68+
r = np.repeat(r, n)
69+
k = np.repeat(k, n)
70+
left = self.left
71+
if self.dist == 'BetaNB':
72+
res = dists.LeftTruncatedBetaNB.sample(p, k, r, left, size=(len(r), ))
73+
else:
74+
res = dists.LeftTruncatedNB.sample(r, p, left, size=(len(r), ))
75+
res = np.stack([res, alt]).T
76+
res, c = np.unique(res, axis=0, return_counts=True)
77+
res = np.append(res, c.reshape(-1, 1), axis=1)
78+
return res
79+
80+
def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True, n_bootstrap=100):
4981
name = self.model_name
5082
data, fixed, w = data.T
5183
n = len(data)
@@ -66,33 +98,38 @@ def fit(self, data: np.ndarray, params: dict, compute_var=True, sandwich=True):
6698
k = ps['mu_k'] + ps['b_k'] * np.log(fixed)
6799
else:
68100
k = np.zeros(len(r))
101+
data, r, k, w = self.update_mask(data, r, k, w)
69102
mask = self.mask
70-
m = len(mask)
71-
if n > m:
72-
mask = np.zeros(data.shape[0], dtype=bool)
73-
self.mask = mask
74-
mask[:n] = False
75-
mask[n:] = True
76-
v = max(0, m - n)
77-
c = self.allowed_const
78-
data = np.pad(data, (0, v), constant_values=c); w = np.pad(w, (0, v), constant_values=c); r = np.pad(r, (0, v), constant_values=c)
79-
k = np.pad(k, (0, v), constant_values=c)
80103
f = partial(self.negloglik, r=r, k=k, data=data, w=w, mask=mask)
81104
grad_w = partial(self.grad, r=r, k=k, data=data, w=w, mask=mask)
82105
fim = partial(self.fim, r=r, k=k, data=data, w=w, mask=mask)
83-
r = minimize_scalar(f, bounds=(0.0, 1.0), method='bounded')
84-
r.x = float(r.x)
85-
lf = float(f(r.x))
106+
res = minimize_scalar(f, bounds=(0.0, 1.0), method='bounded')
107+
x = res.x
108+
bs = list()
109+
for i in range(n_bootstrap):
110+
r_, k_, w_ = r[:n], k[:n], w[:n]
111+
data, _, w_ = self.sample(w_, fixed[:n], x, r_, k_).T
112+
data, r_, k_, w_ = self.update_mask(data, r_, k_, w_)
113+
mask = self.mask
114+
f = partial(self.negloglik, r=r_, k=k_, data=data, w=w_, mask=mask)
115+
res_ = minimize_scalar(f, bounds=(0.0, 1.0), method='bounded')
116+
if res_.success:
117+
bs.append(res_.x)
118+
119+
res.x = float(res.x)
120+
correction = res.x - np.mean(bs) if n_bootstrap else 0.0
121+
res.x = res.x + correction
122+
lf = float(f(res.x))
86123
if compute_var:
87124
if sandwich:
88-
g = grad_w(r.x)
125+
g = grad_w(res.x)
89126
v = (g ** 2 / w).sum()
90-
fim = fim(r.x)
127+
fim = fim(res.x)
91128
s = -1 if fim < -1e-9 else 1
92-
return r, s * float(1 / fim ** 2 * v)
129+
return res, s * float(1 / fim ** 2 * v)
93130
else:
94-
return r, float(1 / fim(r.x))
95-
return r, lf
131+
return res, float(1 / fim(res.x))
132+
return res, lf
96133

97134

98135

@@ -221,7 +258,7 @@ def transform_p(p, var):
221258
def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
222259
inst_params: dict, params: dict, skip_failures=False, max_sz=None, bad=1.0,
223260
contrasts: Tuple[float, float, float] = (1, -1, 0), logit_transform=False,
224-
param_mode='window', robust_se=True):
261+
param_mode='window', robust_se=True, n_bootstrap=0):
225262
if not hasattr(wald_test, '_cache'):
226263
wald_test._cache = dict()
227264
snv, counts_a, counts_b, counts = counts
@@ -237,9 +274,9 @@ def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
237274
if allele == 'alt':
238275
counts_a = counts_a[:, (1, 0, 2)]; counts_b = counts_b[:, (1, 0, 2)]; counts = counts[:, (1, 0, 2)]
239276
try:
240-
a_r, a_var = model.fit(counts_a, params[allele], sandwich=robust_se)
277+
a_r, a_var = model.fit(counts_a, params[allele], sandwich=robust_se, n_bootstrap=n_bootstrap)
241278
a_p = a_r.x
242-
b_r, b_var = model.fit(counts_b, params[allele], sandwich=robust_se)
279+
b_r, b_var = model.fit(counts_b, params[allele], sandwich=robust_se, n_bootstrap=n_bootstrap)
243280
b_p = b_r.x
244281
correct = (a_var >= 0) & (b_var >= 0) & np.isfinite(a_var) & np.isfinite(b_var)
245282
if logit_transform:
@@ -267,7 +304,7 @@ def wald_test(counts: Tuple[tuple, np.ndarray, np.ndarray, np.ndarray],
267304
def differential_test(name: str, group_a: List[str], group_b: List[str], mode='wald', min_samples=2, min_cover=0,
268305
max_cover=np.inf, skip_failures=True, group_test=True, alpha=0.05, max_cover_group_test=None,
269306
filter_chr=None, filter_id=None, contrasts=(1, -1, 0), subname=None, param_mode='window',
270-
logit_transform=False, robust_se=True, n_jobs=-1):
307+
logit_transform=False, robust_se=True, n_bootstrap=0, n_jobs=-1):
271308
if max_cover is None:
272309
max_cover = np.inf
273310
if min_cover is None:
@@ -319,7 +356,7 @@ def differential_test(name: str, group_a: List[str], group_b: List[str], mode='w
319356
'alt_pval', 'alt_p_a', 'alt_p_b', 'alt_std_a', 'alt_std_b']
320357
test_fun = partial(wald_test, inst_params=inst_params, params=params, skip_failures=False, bad=bad,
321358
contrasts=contrasts, param_mode=param_mode, logit_transform=logit_transform,
322-
robust_se=robust_se)
359+
robust_se=robust_se, n_bootstrap=n_bootstrap)
323360
if group_test:
324361
counts_a = _counts_a[bad]; counts_b = _counts_b[bad]; counts = _counts[bad]
325362
counts_a = counts_a[counts_a[:, 0] + counts_a[:, 1] < max_cover_group_test]

mixalime/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,8 @@ def _difftest(name: str = Argument(..., help='Project name.'),
489489
' reps/samples.'),
490490
robust_se: bool = Option(False, help='Use robust standard errors (Huber-White Sandwich correction). Applicable only if '
491491
'[cyan]--mode[/cyan]=[yellow]wald[/yellow].'),
492+
n_bootstrap: int = Option(0, help='Boostrap iterations used in stochastic bias correction. Applicable only if [cyan]--mode[/cyan]=='
493+
'[yellow]wald[/yellow].'),
492494
logit_transform: bool = Option(False, help='Apply logit transform to [bold]p[/bold] and its variance with Delta method. Applicable '
493495
'only if [cyan]--mode[/cyan]=[yellow]wald[/yellow].'),
494496
group_test: bool = Option(False, help='Whole groups will be tested against each other first. Note that this will take'
@@ -532,7 +534,7 @@ def _difftest(name: str = Argument(..., help='Project name.'),
532534
max_cover=max_cover, group_test=group_test, subname=subname, filter_id=filter_id,
533535
max_cover_group_test=max_cover_group_test, filter_chr=filter_chr, alpha=alpha, n_jobs=n_jobs,
534536
param_mode='window' if param_window else 'line', logit_transform=logit_transform,
535-
robust_se=robust_se, contrasts=contrasts)[subname]
537+
robust_se=robust_se, contrasts=contrasts, n_bootstrap=n_bootstrap)[subname]
536538
if pretty:
537539
p.stop()
538540
if group_test:
@@ -566,7 +568,7 @@ def _difftest(name: str = Argument(..., help='Project name.'),
566568
mode=mode, subname=subname, group_test=group_test, max_cover=max_cover, filter_id=filter_id,
567569
filter_chr=filter_chr, max_cover_group_test=max_cover_group_test, n_jobs=n_jobs,
568570
param_window=param_window, logit_transform=logit_transform, robust_se=robust_se,
569-
contrasts=contrasts, expected_result=expected_res)
571+
contrasts=contrasts, n_bootstrap=n_bootstrap, expected_result=expected_res)
570572
dt = time() - t0
571573
if pretty:
572574
rprint(f'[green][bold]✔️[/bold] Done![/green]\t time: {dt:.2f} s.')

0 commit comments

Comments
 (0)