Skip to content

Commit 598e01b

Browse files
committed
Merge remote-tracking branch 'origin/dev' into dev
2 parents 894e468 + 7d81dc9 commit 598e01b

15 files changed

+184
-102
lines changed

v03_pipeline/lib/annotations/mito_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def test_allele_count_annotations(self) -> None:
6666
),
6767
key='id',
6868
globals=hl.Struct(
69-
project_guids=['project_1', 'project_2'],
69+
project_sample_types=[('project_1', 'WES'), ('project_2', 'WES')],
7070
project_families={'project_1': ['a'], 'project_2': []},
7171
),
7272
)

v03_pipeline/lib/misc/lookup.py

Lines changed: 37 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
import hail as hl
22

3-
from v03_pipeline.lib.model import DatasetType
3+
from v03_pipeline.lib.model import DatasetType, SampleType
44

55

66
def compute_callset_lookup_ht(
77
dataset_type: DatasetType,
88
mt: hl.MatrixTable,
99
project_guid: str,
10+
sample_type: SampleType,
1011
) -> hl.Table:
1112
sample_id_to_family_guid = hl.dict(
1213
{
@@ -38,7 +39,7 @@ def compute_callset_lookup_ht(
3839
],
3940
),
4041
).rows()
41-
ht = globalize_ids(ht, project_guid)
42+
ht = globalize_ids(ht, project_guid, sample_type)
4243
return ht.annotate(
4344
project_stats=[
4445
# Set a family to missing if all values are 0
@@ -52,13 +53,14 @@ def compute_callset_lookup_ht(
5253
)
5354

5455

55-
def globalize_ids(ht: hl.Table, project_guid: str) -> hl.Table:
56+
def globalize_ids(ht: hl.Table, project_guid: str, sample_type: SampleType) -> hl.Table:
5657
row = ht.take(1)[0] if ht.count() > 0 else None
5758
has_project_stats = row and len(row.project_stats) > 0
59+
project_key = (project_guid, sample_type.value)
5860
ht = ht.annotate_globals(
59-
project_guids=[project_guid],
61+
project_sample_types=[project_key],
6062
project_families=(
61-
{project_guid: [fs.family_guid for fs in ps] for ps in row.project_stats}
63+
{project_key: [fs.family_guid for fs in ps] for ps in row.project_stats}
6264
if has_project_stats
6365
else hl.empty_dict(hl.tstr, hl.tarray(hl.tstr))
6466
),
@@ -73,14 +75,16 @@ def globalize_ids(ht: hl.Table, project_guid: str) -> hl.Table:
7375
def remove_family_guids(
7476
ht: hl.Table,
7577
project_guid: str,
78+
sample_type: SampleType,
7679
family_guids: hl.SetExpression,
7780
) -> hl.Table:
78-
if project_guid not in hl.eval(ht.globals.project_families):
81+
project_key = (project_guid, sample_type.value)
82+
if project_key not in hl.eval(ht.globals.project_families):
7983
return ht
80-
project_i = ht.project_guids.index(project_guid)
84+
project_i = ht.project_sample_types.index(project_key)
8185
family_indexes_to_keep = hl.eval(
8286
hl.array(
83-
hl.enumerate(ht.globals.project_families[project_guid])
87+
hl.enumerate(ht.globals.project_families[project_key])
8488
.filter(lambda item: ~family_guids.contains(item[1]))
8589
.map(lambda item: item[0]),
8690
),
@@ -112,11 +116,11 @@ def remove_family_guids(
112116
ht.project_families.items().map(
113117
lambda item: (
114118
hl.if_else(
115-
item[0] != project_guid,
119+
item[0] != project_key,
116120
item,
117121
(
118122
item[0],
119-
ht.project_families[project_guid].filter(
123+
ht.project_families[project_key].filter(
120124
lambda family_guid: ~family_guids.contains(family_guid),
121125
),
122126
),
@@ -130,13 +134,15 @@ def remove_family_guids(
130134
def remove_project(
131135
ht: hl.Table,
132136
project_guid: str,
137+
sample_type: SampleType,
133138
) -> hl.Table:
134-
existing_project_guids = hl.eval(ht.globals.project_guids)
135-
if project_guid not in existing_project_guids:
139+
existing_projects = hl.eval(ht.globals.project_sample_types)
140+
project_key = (project_guid, sample_type.value)
141+
if project_key not in existing_projects:
136142
return ht
137143
project_indexes_to_keep = hl.eval(
138-
hl.enumerate(existing_project_guids)
139-
.filter(lambda item: item[1] != project_guid)
144+
hl.enumerate(existing_projects)
145+
.filter(lambda item: item[1] != project_key)
140146
.map(lambda item: item[0]),
141147
)
142148
ht = ht.annotate(
@@ -149,11 +155,11 @@ def remove_project(
149155
)
150156
ht = ht.filter(hl.any(ht.project_stats.map(hl.is_defined)))
151157
return ht.annotate_globals(
152-
project_guids=ht.project_guids.filter(
153-
lambda p: p != project_guid,
158+
project_sample_types=ht.project_sample_types.filter(
159+
lambda p: p != project_key,
154160
),
155161
project_families=hl.dict(
156-
ht.project_families.items().filter(lambda item: item[0] != project_guid),
162+
ht.project_families.items().filter(lambda item: item[0] != project_key),
157163
),
158164
)
159165

@@ -163,8 +169,8 @@ def join_lookup_hts(
163169
callset_ht: hl.Table,
164170
) -> hl.Table:
165171
ht = ht.join(callset_ht, 'outer')
166-
project_guid = ht.project_guids_1[0]
167-
ht_project_i = ht.project_guids.index(project_guid)
172+
project_key = ht.project_sample_types_1[0]
173+
ht_project_i = ht.project_sample_types.index(project_key)
168174
ht = ht.select(
169175
# We have 6 unique cases here.
170176
# 1) The project has not been loaded before, the row is missing
@@ -183,14 +189,14 @@ def join_lookup_hts(
183189
hl.case()
184190
.when(
185191
(hl.is_missing(ht_project_i) & hl.is_missing(ht.project_stats)),
186-
ht.project_guids.map(
192+
ht.project_sample_types.map(
187193
lambda _: hl.missing(ht.project_stats.dtype.element_type),
188194
).extend(ht.project_stats_1),
189195
)
190196
.when(
191197
(hl.is_missing(ht_project_i) & hl.is_missing(ht.project_stats_1)),
192198
ht.project_stats.extend(
193-
ht.project_guids_1.map(
199+
ht.project_sample_types_1.map(
194200
lambda _: hl.missing(ht.project_stats_1.dtype.element_type),
195201
),
196202
),
@@ -201,7 +207,7 @@ def join_lookup_hts(
201207
)
202208
.when(
203209
hl.is_missing(ht.project_stats),
204-
hl.enumerate(ht.project_guids).starmap(
210+
hl.enumerate(ht.project_sample_types).starmap(
205211
# Add a missing project_stats value for every loaded project,
206212
# then add a missing value for every family for "this project"
207213
# and extend the new families on the right.
@@ -230,7 +236,7 @@ def join_lookup_hts(
230236
i != ht_project_i,
231237
ps,
232238
ps.extend(
233-
ht.project_families_1[project_guid].map(
239+
ht.project_families_1[project_key].map(
234240
lambda _: hl.missing(
235241
ht.project_stats.dtype.element_type.element_type,
236242
),
@@ -256,26 +262,26 @@ def join_lookup_hts(
256262
),
257263
)
258264
# NB: double reference these because the source ht has changed :/
259-
project_guid = ht.project_guids_1[0]
260-
ht_project_i = ht.project_guids.index(project_guid)
265+
project_key = ht.project_sample_types_1[0]
266+
ht_project_i = ht.project_sample_types.index(project_key)
261267
return ht.transmute_globals(
262-
project_guids=hl.if_else(
268+
project_sample_types=hl.if_else(
263269
hl.is_missing(ht_project_i),
264-
ht.project_guids.extend(ht.project_guids_1),
265-
ht.project_guids,
270+
ht.project_sample_types.extend(ht.project_sample_types_1),
271+
ht.project_sample_types,
266272
),
267273
project_families=hl.if_else(
268274
hl.is_missing(ht_project_i),
269275
hl.dict(ht.project_families.items().extend(ht.project_families_1.items())),
270276
hl.dict(
271277
ht.project_families.items().map(
272278
lambda item: hl.if_else(
273-
item[0] != project_guid,
279+
item[0] != project_key,
274280
item,
275281
(
276282
item[0],
277-
ht.project_families[project_guid].extend(
278-
ht.project_families_1[project_guid],
283+
ht.project_families[project_key].extend(
284+
ht.project_families_1[project_key],
279285
),
280286
),
281287
),

0 commit comments

Comments
 (0)