Skip to content

Commit f3d8ac7

Browse files
authored
add unit tests for remap/subset sample ids (#654)
* add unit tests for remap sample ids
* revert counter change
* types
1 parent d9bd660 commit f3d8ac7

File tree

2 files changed

+169
-1
lines changed

2 files changed

+169
-1
lines changed

v03_pipeline/lib/misc/sample_ids.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def remap_sample_ids(
2020
ignore_missing_samples_when_remapping: bool,
2121
) -> hl.MatrixTable:
2222
mt = vcf_remap(mt)
23+
2324
collected_remap = project_remap_ht.collect()
2425
s_dups = [k for k, v in Counter([r.s for r in collected_remap]).items() if v > 1]
2526
seqr_dups = [
@@ -46,7 +47,7 @@ def remap_sample_ids(
4647
raise MatrixTableSampleSetError(message, missing_samples)
4748

4849
mt = mt.annotate_cols(**project_remap_ht[mt.s])
49-
remap_expr = hl.cond(hl.is_missing(mt.seqr_id), mt.s, mt.seqr_id)
50+
remap_expr = hl.if_else(hl.is_missing(mt.seqr_id), mt.s, mt.seqr_id)
5051
mt = mt.annotate_cols(seqr_id=remap_expr, vcf_id=mt.s)
5152
mt = mt.key_cols_by(s=mt.seqr_id)
5253
print(f'Remapped {remap_count} sample ids...')
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import unittest
2+
3+
import hail as hl
4+
5+
from v03_pipeline.lib.misc.sample_ids import (
6+
MatrixTableSampleSetError,
7+
remap_sample_ids,
8+
subset_samples,
9+
)
10+
11+
CALLSET_MT = hl.MatrixTable.from_parts(
12+
rows={'variants': [1, 2]},
13+
cols={'s': ['HG00731', 'HG00732', 'HG00733']},
14+
entries={
15+
'HL': [
16+
[0.0, hl.missing(hl.tfloat), 0.3],
17+
[0.1, 0.2, 0.3],
18+
],
19+
},
20+
).key_cols_by('s')
21+
22+
23+
class SampleLookupTest(unittest.TestCase):
24+
def test_remap_2_sample_ids(self) -> None:
25+
# remap 2 of 3 samples in callset
26+
project_remap_ht = hl.Table.parallelize(
27+
[
28+
{'s': 'HG00731', 'seqr_id': 'HG00731_1'},
29+
{'s': 'HG00732', 'seqr_id': 'HG00732_1'},
30+
],
31+
hl.tstruct(
32+
s=hl.tstr,
33+
seqr_id=hl.tstr,
34+
),
35+
key='s',
36+
)
37+
38+
remapped_mt = remap_sample_ids(
39+
CALLSET_MT,
40+
project_remap_ht,
41+
ignore_missing_samples_when_remapping=True,
42+
)
43+
44+
self.assertEqual(remapped_mt.cols().count(), 3)
45+
self.assertEqual(
46+
remapped_mt.cols().collect(),
47+
[
48+
hl.Struct(
49+
col_idx=0,
50+
s='HG00731_1',
51+
seqr_id='HG00731_1',
52+
vcf_id='HG00731',
53+
),
54+
hl.Struct(
55+
col_idx=1,
56+
s='HG00732_1',
57+
seqr_id='HG00732_1',
58+
vcf_id='HG00732',
59+
),
60+
hl.Struct(col_idx=2, s='HG00733', seqr_id='HG00733', vcf_id='HG00733'),
61+
],
62+
)
63+
64+
def test_remap_sample_ids_remap_has_duplicate(self) -> None:
65+
# remap file has 2 rows for HG00732
66+
project_remap_ht = hl.Table.parallelize(
67+
[
68+
{'s': 'HG00731', 'seqr_id': 'HG00731_1'},
69+
{'s': 'HG00732', 'seqr_id': 'HG00732_1'},
70+
{'s': 'HG00732', 'seqr_id': 'HG00732_1'}, # duplicate
71+
],
72+
hl.tstruct(
73+
s=hl.tstr,
74+
seqr_id=hl.tstr,
75+
),
76+
key='s',
77+
)
78+
79+
with self.assertRaises(ValueError):
80+
remap_sample_ids(
81+
CALLSET_MT,
82+
project_remap_ht,
83+
ignore_missing_samples_when_remapping=True,
84+
)
85+
86+
def test_remap_sample_ids_remap_has_missing_samples(self) -> None:
87+
# remap file has 4 rows, but only 3 samples in callset
88+
project_remap_ht = hl.Table.parallelize(
89+
[
90+
{'s': 'HG00731', 'seqr_id': 'HG00731_1'},
91+
{'s': 'HG00732', 'seqr_id': 'HG00732_1'},
92+
{'s': 'HG00733', 'seqr_id': 'HG00733_1'},
93+
{'s': 'HG00734', 'seqr_id': 'HG00734_1'}, # missing in callset
94+
],
95+
hl.tstruct(
96+
s=hl.tstr,
97+
seqr_id=hl.tstr,
98+
),
99+
key='s',
100+
)
101+
102+
with self.assertRaises(MatrixTableSampleSetError):
103+
remap_sample_ids(
104+
CALLSET_MT,
105+
project_remap_ht,
106+
ignore_missing_samples_when_remapping=False,
107+
)
108+
109+
def test_subset_samples(self):
110+
# subset 2 of 3 samples in callset
111+
sample_subset_ht = hl.Table.parallelize(
112+
[
113+
{'s': 'HG00731'},
114+
{'s': 'HG00732'},
115+
],
116+
hl.tstruct(s=hl.tstr),
117+
key='s',
118+
)
119+
120+
subset_mt = subset_samples(
121+
CALLSET_MT,
122+
sample_subset_ht,
123+
ignore_missing_samples_when_subsetting=True,
124+
)
125+
126+
self.assertEqual(subset_mt.cols().count(), 2)
127+
self.assertEqual(
128+
subset_mt.cols().collect(),
129+
[
130+
hl.Struct(col_idx=0, s='HG00731'),
131+
hl.Struct(col_idx=1, s='HG00732'),
132+
],
133+
)
134+
135+
def test_subset_samples_zero_samples(self):
136+
# subset 0 of 3 samples in callset
137+
sample_subset_ht = hl.Table.parallelize(
138+
[],
139+
hl.tstruct(s=hl.tstr),
140+
key='s',
141+
)
142+
143+
with self.assertRaises(MatrixTableSampleSetError):
144+
subset_samples(
145+
CALLSET_MT,
146+
sample_subset_ht,
147+
ignore_missing_samples_when_subsetting=True,
148+
)
149+
150+
def test_subset_samples_missing_samples(self):
151+
# subset 2 of 3 samples in callset, but 1 is missing
152+
sample_subset_ht = hl.Table.parallelize(
153+
[
154+
{'s': 'HG00731'},
155+
{'s': 'HG00732'},
156+
{'s': 'HG00734'}, # missing in callset
157+
],
158+
hl.tstruct(s=hl.tstr),
159+
key='s',
160+
)
161+
162+
with self.assertRaises(MatrixTableSampleSetError):
163+
subset_samples(
164+
CALLSET_MT,
165+
sample_subset_ht,
166+
ignore_missing_samples_when_subsetting=False,
167+
)

0 commit comments

Comments
 (0)