Skip to content

Commit bb2af2d

Browse files
committed
Merge branch 'main' of github.com:broadinstitute/seqr-loading-pipelines
2 parents d86baba + 6416d7b commit bb2af2d

14 files changed

+680
-228
lines changed

v03_pipeline/lib/annotations/sv.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ruff: noqa: N806
12
from typing import Any
23

34
import hail as hl
@@ -186,13 +187,28 @@ def gnomad_svs(
186187
)[ht['info.GNOMAD_V4.1_TRUTH_VID']]
187188

188189

189-
def gt_stats(ht: hl.Table, **_: Any) -> hl.Expression:
190+
def gt_stats(ht: hl.Table, callset_ht: hl.Table, **_: Any) -> hl.Expression:
    """Build the gt_stats struct by adding callset counts to any existing ones.

    Args:
        ht: Table being annotated. Per the pipeline's own note, this is the
            annotations table union-ed with the new variants and subsetted to
            the variants present in the callset.
        callset_ht: Table looked up by ``ht.key``; its ``gt_stats`` field
            supplies ``AC``, ``AN`` and ``homozygote_count``.
        **_: Ignored keyword arguments.

    Returns:
        A struct expression with fields ``AC``, ``AN``, ``AF``, ``Hom`` and
        ``Het`` (in that order).
    """
    # gt_stats is missing on rows for new variants, and the field is not
    # present at all when the annotations table does not exist yet. Both
    # situations contribute 0 to the running totals.
    has_existing_stats = hasattr(ht, 'gt_stats')

    def _existing_count(field: str):
        if not has_existing_stats:
            return 0
        return hl.or_else(ht.gt_stats[field], 0)

    callset_stats = callset_ht[ht.key].gt_stats
    callset_alt_ac = callset_stats.AC[1]
    callset_alt_hom = callset_stats.homozygote_count[1]

    ac = callset_alt_ac + _existing_count('AC')
    an = callset_stats.AN + _existing_count('AN')
    hom = callset_alt_hom + _existing_count('Hom')
    # Heterozygotes in the callset: alt allele count minus the two alleles
    # contributed by each homozygote.
    het = (callset_alt_ac - (callset_alt_hom * 2)) + _existing_count('Het')
    return hl.struct(
        AC=ac,
        AN=an,
        AF=hl.float32(ac / an),
        Hom=hom,
        Het=het,
    )
197213

198214

v03_pipeline/lib/annotations/sv_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import hail as hl
44

5+
from v03_pipeline.lib.annotations import sv
56
from v03_pipeline.lib.annotations.fields import get_fields
67
from v03_pipeline.lib.model import DatasetType
78

@@ -169,3 +170,74 @@ def test_sv_export_annotations(self) -> None:
169170
),
170171
],
171172
)
173+
174+
def test_allele_count_annotations(self) -> None:
    """sv.gt_stats adds callset counts onto the counts already on the table."""
    # Table being annotated: variant 0 already carries stats, variant 1 has
    # a missing gt_stats struct.
    existing_stats_type = hl.tstruct(
        AC=hl.tint32,
        AN=hl.tint32,
        AF=hl.tfloat32,
        Hom=hl.tint32,
        Het=hl.tint32,
    )
    existing_rows = [
        {
            'variant_id': 0,
            'gt_stats': hl.Struct(
                AC=4,
                AN=8,
                AF=hl.float32(0.5),
                Hom=1,
                Het=2,
            ),
        },
        {'variant_id': 1, 'gt_stats': None},
    ]
    ht = hl.Table.parallelize(
        existing_rows,
        hl.tstruct(variant_id=hl.tint32, gt_stats=existing_stats_type),
        key='variant_id',
    )

    # Callset: overlaps the table on variant 0 only (variant 2 is not in ht,
    # variant 1 is not in the callset).
    callset_stats_type = hl.tstruct(
        AC=hl.tarray(hl.tint32),
        AN=hl.tint32,
        homozygote_count=hl.tarray(hl.tint32),
    )
    callset_rows = [
        {
            'variant_id': 0,
            'gt_stats': hl.Struct(
                AC=[0, 3],
                AN=6,
                homozygote_count=[0, 1],
            ),
        },
        {
            'variant_id': 2,
            'gt_stats': hl.Struct(
                AC=[0, 2],
                AN=6,
                homozygote_count=[0, 1],
            ),
        },
    ]
    callset_ht = hl.Table.parallelize(
        callset_rows,
        hl.tstruct(variant_id=hl.tint32, gt_stats=callset_stats_type),
        key='variant_id',
    )

    ht = ht.select(gt_stats=sv.gt_stats(ht, callset_ht))

    expected_rows = [
        hl.Struct(
            variant_id=0,
            gt_stats=hl.Struct(AC=7, AN=14, AF=0.5, Hom=2, Het=3),
        ),
        hl.Struct(
            variant_id=1,
            gt_stats=hl.Struct(AC=None, AN=None, AF=None, Hom=None, Het=None),
        ),
    ]
    self.assertCountEqual(ht.collect(), expected_rows)

v03_pipeline/lib/misc/callsets.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import functools
22

33
import hail as hl
4+
import hailtop.fs as hfs
45

56
from v03_pipeline.lib.model import DatasetType, ReferenceGenome
6-
from v03_pipeline.lib.paths import remapped_and_subsetted_callset_path
7+
from v03_pipeline.lib.paths import (
8+
remapped_and_subsetted_callset_path,
9+
variant_annotations_table_path,
10+
)
711

812

913
def get_callset_ht(
@@ -24,14 +28,39 @@ def get_callset_ht(
2428
for project_guid in project_guids
2529
]
2630
callset_ht = functools.reduce(
27-
(lambda ht1, ht2: ht1.union(ht2, unify=True)),
31+
(lambda ht1, ht2: ht1.union(ht2)),
2832
callset_hts,
2933
)
3034
return callset_ht.distinct()
3135

3236

37+
def get_callset_mt(
    reference_genome: ReferenceGenome,
    dataset_type: DatasetType,
    callset_path: str,
    project_guids: list[str],
):
    """Read each project's remapped-and-subsetted callset and merge them by row.

    One matrix table is read per entry of ``project_guids`` (from the path
    given by ``remapped_and_subsetted_callset_path``); the tables are then
    combined pairwise with ``union_rows`` and de-duplicated with
    ``distinct_by_row``.
    """

    def _read_project_mt(project_guid):
        project_path = remapped_and_subsetted_callset_path(
            reference_genome,
            dataset_type,
            callset_path,
            project_guid,
        )
        return hl.read_matrix_table(project_path)

    def _union_rows(left_mt, right_mt):
        return left_mt.union_rows(right_mt)

    project_mts = [_read_project_mt(project_guid) for project_guid in project_guids]
    merged_mt = functools.reduce(_union_rows, project_mts)
    return merged_mt.distinct_by_row()
59+
60+
3361
def get_additional_row_fields(
3462
mt: hl.MatrixTable,
63+
reference_genome: ReferenceGenome,
3564
dataset_type: DatasetType,
3665
skip_check_sex_and_relatedness: bool,
3766
):
@@ -50,4 +79,26 @@ def get_additional_row_fields(
5079
if hasattr(mt, 'info') and hasattr(mt.info, 'CALIBRATION_SENSITIVITY')
5180
else {}
5281
),
82+
**(
83+
{'info.SEQR_INTERNAL_TRUTH_VID': hl.tstr}
84+
if dataset_type.re_key_by_seqr_internal_truth_vid
85+
and hfs.exists(
86+
variant_annotations_table_path(
87+
reference_genome,
88+
dataset_type,
89+
),
90+
)
91+
and hl.eval(
92+
hl.len(
93+
hl.read_table(
94+
variant_annotations_table_path(
95+
reference_genome,
96+
dataset_type,
97+
),
98+
).globals.updates,
99+
)
100+
> 0,
101+
)
102+
else {}
103+
),
53104
}

v03_pipeline/lib/misc/io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def import_callset(
181181
elif 'mt' in callset_path:
182182
mt = hl.read_matrix_table(callset_path)
183183
if dataset_type == DatasetType.SV:
184-
mt = mt.annotate_rows(variant_id=mt.rsid)
184+
mt = mt.annotate_rows(
185+
variant_id=mt.rsid,
186+
)
185187
return mt.key_rows_by(*dataset_type.table_key_type(reference_genome).fields)
186188

187189

v03_pipeline/lib/misc/sv.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,112 @@
1+
import itertools
2+
import math
3+
14
import hail as hl
25

36
from v03_pipeline.lib.annotations import sv
47
from v03_pipeline.lib.misc.pedigree import Family
58
from v03_pipeline.lib.model import ReferenceGenome, Sex
69

10+
# Added to the end-locus position difference when a new variant's end locus is
# on a different contig than the existing variant's, so that such candidates
# rank behind same-contig ones when picking the closest match.
WRONG_CHROM_PENALTY = 1e9
11+
12+
13+
def _get_grouped_new_callset_variants(
14+
mt: hl.MatrixTable,
15+
duplicate_internal_variant_ids: set[str],
16+
) -> itertools.groupby:
17+
mt = mt.select_rows(
18+
'info.SEQR_INTERNAL_TRUTH_VID',
19+
end_locus=sv.end_locus(mt),
20+
)
21+
return itertools.groupby(
22+
sorted(
23+
mt.filter_rows(
24+
duplicate_internal_variant_ids.contains(
25+
mt['info.SEQR_INTERNAL_TRUTH_VID'],
26+
),
27+
)
28+
.rows()
29+
.collect(),
30+
key=lambda x: x['info.SEQR_INTERNAL_TRUTH_VID'],
31+
),
32+
lambda x: x['info.SEQR_INTERNAL_TRUTH_VID'],
33+
)
34+
35+
36+
def deduplicate_merged_sv_concordance_calls(
37+
mt: hl.MatrixTable,
38+
annotations_ht: hl.Table,
39+
) -> hl.MatrixTable:
40+
# First find the seqr internal variant ids that are duplicated in the new callset.
41+
duplicate_internal_variant_ids = hl.set(
42+
{
43+
k
44+
for k, v in mt.aggregate_rows(
45+
hl.agg.counter(mt['info.SEQR_INTERNAL_TRUTH_VID']),
46+
).items()
47+
if v > 1
48+
}
49+
or hl.empty_set(hl.tstr),
50+
)
51+
52+
# Then, collect into memory the necessary existing variants & the new variants
53+
annotations_ht = annotations_ht.select('end_locus')
54+
existing_variants = {
55+
v.variant_id: v
56+
for v in (
57+
annotations_ht.filter(
58+
duplicate_internal_variant_ids.contains(annotations_ht.variant_id),
59+
).collect()
60+
)
61+
}
62+
grouped_new_variants = _get_grouped_new_callset_variants(
63+
mt,
64+
duplicate_internal_variant_ids,
65+
)
66+
67+
# Then, iterate over new variants and exclude all but the best match
68+
new_variant_ids_to_exclude = set()
69+
for existing_variant_id, new_variants in grouped_new_variants:
70+
existing_variant = existing_variants[existing_variant_id]
71+
closest_variant_id, min_distance = None, math.inf
72+
73+
# First pass to find the closest variant
74+
new_variants_it1, new_variants_it2 = itertools.tee(new_variants, 2)
75+
for new_variant in new_variants_it1:
76+
distance = math.fabs(
77+
new_variant.end_locus.position
78+
- existing_variant.end_locus.position
79+
+ (
80+
WRONG_CHROM_PENALTY
81+
if (
82+
new_variant.end_locus.contig
83+
!= existing_variant.end_locus.contig
84+
)
85+
else 0
86+
),
87+
)
88+
if distance < min_distance:
89+
min_distance = distance
90+
closest_variant_id = new_variant.variant_id
91+
92+
# Second pass to exclude all but the closest.
93+
for new_variant in new_variants_it2:
94+
if new_variant.variant_id != closest_variant_id:
95+
new_variant_ids_to_exclude.add(new_variant.variant_id)
96+
97+
# Finally, remove SEQR_INTERNAL_TRUTH_VID from those variants.
98+
return mt.annotate_rows(
99+
**{
100+
'info.SEQR_INTERNAL_TRUTH_VID': hl.if_else(
101+
hl.set(new_variant_ids_to_exclude or hl.empty_set(hl.tstr)).contains(
102+
mt.variant_id,
103+
),
104+
hl.missing(hl.tstr),
105+
mt['info.SEQR_INTERNAL_TRUTH_VID'],
106+
),
107+
},
108+
)
109+
7110

8111
def overwrite_male_non_par_calls(
9112
mt: hl.MatrixTable,

0 commit comments

Comments
 (0)