Skip to content

Commit ddf867a

Browse files
authored
feat: add remap and family loading failures as validation exceptions … (#1005)
* feat: add remap and family loading failures as validation exceptions rather than runtime errors
* move on
* Update write_remapped_and_subsetted_callset_test.py
* ruff
1 parent 2e8dbcf commit ddf867a

8 files changed

+299
-112
lines changed

v03_pipeline/lib/misc/sample_ids.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,16 @@
33
import hail as hl
44

55
from v03_pipeline.lib.logger import get_logger
6+
from v03_pipeline.lib.misc.validation import SeqrValidationError
67

78
logger = get_logger(__name__)
89

910

10-
class MatrixTableSampleSetError(Exception):
11-
def __init__(self, message, missing_samples):
12-
super().__init__(message)
13-
self.missing_samples = missing_samples
14-
15-
16-
def vcf_remap(mt: hl.MatrixTable) -> hl.MatrixTable:
17-
# TODO: add logic from Mike to remap vcf samples delivered from Broad WGS
18-
return mt
19-
20-
2111
def remap_sample_ids(
2212
mt: hl.MatrixTable,
2313
project_remap_ht: hl.Table,
2414
ignore_missing_samples_when_remapping: bool,
2515
) -> hl.MatrixTable:
26-
mt = vcf_remap(mt)
27-
2816
collected_remap = project_remap_ht.collect()
2917
s_dups = [k for k, v in Counter([r.s for r in collected_remap]).items() if v > 1]
3018
seqr_dups = [
@@ -33,7 +21,7 @@ def remap_sample_ids(
3321

3422
if len(s_dups) > 0 or len(seqr_dups) > 0:
3523
msg = f'Duplicate s or seqr_id entries in remap file were found. Duplicate s:{s_dups}. Duplicate seqr_id:{seqr_dups}.'
36-
raise ValueError(msg)
24+
raise SeqrValidationError(msg)
3725

3826
missing_samples = project_remap_ht.anti_join(mt.cols()).collect()
3927
remap_count = len(collected_remap)
@@ -48,7 +36,7 @@ def remap_sample_ids(
4836
if ignore_missing_samples_when_remapping:
4937
logger.info(message)
5038
else:
51-
raise MatrixTableSampleSetError(message, missing_samples)
39+
raise SeqrValidationError(message)
5240

5341
mt = mt.annotate_cols(**project_remap_ht[mt.s])
5442
remap_expr = hl.if_else(hl.is_missing(mt.seqr_id), mt.s, mt.seqr_id)
@@ -67,7 +55,7 @@ def subset_samples(
6755
anti_join_ht_count = anti_join_ht.count()
6856
if subset_count == 0:
6957
message = '0 sample ids found the subset HT, something is probably wrong.'
70-
raise MatrixTableSampleSetError(message, [])
58+
raise SeqrValidationError(message)
7159

7260
if anti_join_ht_count != 0:
7361
missing_samples = anti_join_ht.s.collect()
@@ -77,7 +65,7 @@ def subset_samples(
7765
f"IDs that aren't in the callset: {missing_samples}\n"
7866
f'All callset sample IDs:{mt.s.collect()}'
7967
)
80-
raise MatrixTableSampleSetError(message, missing_samples)
68+
raise SeqrValidationError(message)
8169
logger.info(f'Subsetted to {subset_count} sample ids')
8270
mt = mt.semi_join_cols(sample_subset_ht)
8371
return mt.filter_rows(hl.agg.any(hl.is_defined(mt.GT)))

v03_pipeline/lib/misc/sample_ids_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import hail as hl
44

55
from v03_pipeline.lib.misc.sample_ids import (
6-
MatrixTableSampleSetError,
76
remap_sample_ids,
87
subset_samples,
98
)
9+
from v03_pipeline.lib.misc.validation import SeqrValidationError
1010

1111
CALLSET_MT = hl.MatrixTable.from_parts(
1212
rows={'variants': [1, 2]},
@@ -76,7 +76,7 @@ def test_remap_sample_ids_remap_has_duplicate(self) -> None:
7676
key='s',
7777
)
7878

79-
with self.assertRaises(ValueError):
79+
with self.assertRaises(SeqrValidationError):
8080
remap_sample_ids(
8181
CALLSET_MT,
8282
project_remap_ht,
@@ -99,7 +99,7 @@ def test_remap_sample_ids_remap_has_missing_samples(self) -> None:
9999
key='s',
100100
)
101101

102-
with self.assertRaises(MatrixTableSampleSetError):
102+
with self.assertRaises(SeqrValidationError):
103103
remap_sample_ids(
104104
CALLSET_MT,
105105
project_remap_ht,
@@ -114,7 +114,7 @@ def test_subset_samples_zero_samples(self):
114114
key='s',
115115
)
116116

117-
with self.assertRaises(MatrixTableSampleSetError):
117+
with self.assertRaises(SeqrValidationError):
118118
subset_samples(
119119
CALLSET_MT,
120120
sample_subset_ht,
@@ -132,7 +132,7 @@ def test_subset_samples_missing_samples(self):
132132
key='s',
133133
)
134134

135-
with self.assertRaises(MatrixTableSampleSetError):
135+
with self.assertRaises(SeqrValidationError):
136136
subset_samples(
137137
CALLSET_MT,
138138
sample_subset_ht,

v03_pipeline/lib/tasks/validate_callset_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,6 @@ def test_validate_callset_multiple_exceptions(
8383
'Missing the following expected contigs:chr10, chr11, chr12, chr13, chr14, chr15, chr16, chr17, chr18, chr19, chr2, chr20, chr21, chr22, chr3, chr4, chr5, chr6, chr7, chr8, chr9, chrX',
8484
'Sample type validation error: dataset sample-type is specified as WES but appears to be WGS because it contains many common non-coding variants',
8585
],
86+
'failed_family_samples': {},
8687
},
8788
)

v03_pipeline/lib/tasks/write_imported_callset.py

Lines changed: 49 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
split_multi_hts,
1111
)
1212
from v03_pipeline.lib.misc.validation import (
13-
SeqrValidationError,
1413
validate_imported_field_types,
1514
)
1615
from v03_pipeline.lib.misc.vets import annotate_vets
@@ -24,7 +23,7 @@
2423
from v03_pipeline.lib.tasks.files import CallsetTask, GCSorLocalTarget
2524
from v03_pipeline.lib.tasks.write_tdr_metrics_files import WriteTDRMetricsFilesTask
2625
from v03_pipeline.lib.tasks.write_validation_errors_for_run import (
27-
WriteValidationErrorsForRunTask,
26+
with_persisted_validation_errors,
2827
)
2928

3029

@@ -77,65 +76,56 @@ def requires(self) -> list[luigi.Task]:
7776
CallsetTask(self.callset_path),
7877
]
7978

79+
@with_persisted_validation_errors
8080
def create_table(self) -> hl.MatrixTable:
81-
try:
82-
# NB: throws SeqrValidationError
83-
mt = import_callset(
84-
self.callset_path,
85-
self.reference_genome,
86-
self.dataset_type,
87-
)
88-
filters_path = None
89-
if (
90-
FeatureFlag.EXPECT_WES_FILTERS
91-
and not self.skip_expect_filters
92-
and self.dataset_type.expect_filters(
93-
self.sample_type,
94-
)
95-
):
96-
filters_path = valid_filters_path(
97-
self.dataset_type,
98-
self.sample_type,
99-
self.callset_path,
100-
)
101-
filters_ht = import_vcf(filters_path, self.reference_genome).rows()
102-
mt = mt.annotate_rows(filters=filters_ht[mt.row_key].filters)
103-
additional_row_fields = get_additional_row_fields(
104-
mt,
105-
self.dataset_type,
106-
self.skip_check_sex_and_relatedness,
81+
# NB: throws SeqrValidationError
82+
mt = import_callset(
83+
self.callset_path,
84+
self.reference_genome,
85+
self.dataset_type,
86+
)
87+
filters_path = None
88+
if (
89+
FeatureFlag.EXPECT_WES_FILTERS
90+
and not self.skip_expect_filters
91+
and self.dataset_type.expect_filters(
92+
self.sample_type,
10793
)
108-
# NB: throws SeqrValidationError
109-
mt = select_relevant_fields(
110-
mt,
94+
):
95+
filters_path = valid_filters_path(
11196
self.dataset_type,
112-
additional_row_fields,
97+
self.sample_type,
98+
self.callset_path,
11399
)
114-
# This validation isn't override-able by the skip option.
115-
# If a field is the wrong type, the pipeline will likely hard-fail downstream.
100+
filters_ht = import_vcf(filters_path, self.reference_genome).rows()
101+
mt = mt.annotate_rows(filters=filters_ht[mt.row_key].filters)
102+
additional_row_fields = get_additional_row_fields(
103+
mt,
104+
self.dataset_type,
105+
self.skip_check_sex_and_relatedness,
106+
)
107+
# NB: throws SeqrValidationError
108+
mt = select_relevant_fields(
109+
mt,
110+
self.dataset_type,
111+
additional_row_fields,
112+
)
113+
# This validation isn't override-able by the skip option.
114+
# If a field is the wrong type, the pipeline will likely hard-fail downstream.
115+
# NB: throws SeqrValidationError
116+
validate_imported_field_types(
117+
mt,
118+
self.dataset_type,
119+
additional_row_fields,
120+
)
121+
if self.dataset_type.has_multi_allelic_variants:
116122
# NB: throws SeqrValidationError
117-
validate_imported_field_types(
118-
mt,
119-
self.dataset_type,
120-
additional_row_fields,
121-
)
122-
if self.dataset_type.has_multi_allelic_variants:
123-
# NB: throws SeqrValidationError
124-
mt = split_multi_hts(mt, self.skip_validation)
125-
# Special handling of variant-level filter annotation for VETs filters.
126-
# The annotations are present on the sample-level FT field but are
127-
# expected upstream on "filters".
128-
mt = annotate_vets(mt)
129-
return mt.select_globals(
130-
callset_path=self.callset_path,
131-
filters_path=filters_path or hl.missing(hl.tstr),
132-
)
133-
except SeqrValidationError as e:
134-
write_validation_errors_for_run_task = self.clone(
135-
WriteValidationErrorsForRunTask,
136-
error_messages=[str(e)],
137-
)
138-
write_validation_errors_for_run_task.run()
139-
raise SeqrValidationError(
140-
write_validation_errors_for_run_task.to_single_error_message(),
141-
) from e
123+
mt = split_multi_hts(mt, self.skip_validation)
124+
# Special handling of variant-level filter annotation for VETs filters.
125+
# The annotations are present on the sample-level FT field but are
126+
# expected upstream on "filters".
127+
mt = annotate_vets(mt)
128+
return mt.select_globals(
129+
callset_path=self.callset_path,
130+
filters_path=filters_path or hl.missing(hl.tstr),
131+
)

v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import luigi
33
import luigi.util
44

5-
from v03_pipeline.lib.logger import get_logger
65
from v03_pipeline.lib.misc.family_loading_failures import (
76
get_families_failed_missing_samples,
87
get_families_failed_relatedness_check,
@@ -16,6 +15,7 @@
1615
)
1716
from v03_pipeline.lib.misc.pedigree import parse_pedigree_ht_to_families
1817
from v03_pipeline.lib.misc.sample_ids import remap_sample_ids, subset_samples
18+
from v03_pipeline.lib.misc.validation import SeqrValidationError
1919
from v03_pipeline.lib.model.feature_flag import FeatureFlag
2020
from v03_pipeline.lib.paths import (
2121
relatedness_check_table_path,
@@ -29,8 +29,19 @@
2929
WriteRelatednessCheckTsvTask,
3030
)
3131
from v03_pipeline.lib.tasks.write_sex_check_table import WriteSexCheckTableTask
32+
from v03_pipeline.lib.tasks.write_validation_errors_for_run import (
33+
with_persisted_validation_errors,
34+
)
35+
3236

33-
logger = get_logger(__name__)
37+
def format_failures(failed_families):
38+
return {
39+
f.family_guid: {
40+
'samples': sorted(f.samples.keys()),
41+
'reasons': reasons,
42+
}
43+
for f, reasons in failed_families.items()
44+
}
3445

3546

3647
@luigi.util.inherits(BaseLoadingRunParams)
@@ -73,6 +84,7 @@ def requires(self) -> list[luigi.Task]:
7384
]
7485
return requirements
7586

87+
@with_persisted_validation_errors
7688
def create_table(self) -> hl.MatrixTable:
7789
callset_mt = hl.read_matrix_table(self.input()[0].path)
7890
pedigree_ht = import_pedigree(self.input()[1].path)
@@ -130,16 +142,21 @@ def create_table(self) -> hl.MatrixTable:
130142
- families_failed_sex_check.keys()
131143
)
132144
if not len(loadable_families):
133-
msg = (
134-
f'families_failed_missing_samples: {families_failed_missing_samples}\n'
135-
f'families_failed_relatedness_check: {families_failed_relatedness_check}\n'
136-
f'families_failed_sex_check: {families_failed_sex_check}'
137-
)
138-
logger.info(
145+
msg = 'All families failed validation checks'
146+
raise SeqrValidationError(
139147
msg,
148+
{
149+
'failed_family_samples': {
150+
'missing_samples': format_failures(
151+
families_failed_missing_samples,
152+
),
153+
'relatedness_check': format_failures(
154+
families_failed_relatedness_check,
155+
),
156+
'sex_check': format_failures(families_failed_sex_check),
157+
},
158+
},
140159
)
141-
msg = 'All families failed checks'
142-
raise RuntimeError(msg)
143160

144161
mt = subset_samples(
145162
callset_mt,
@@ -172,33 +189,15 @@ def create_table(self) -> hl.MatrixTable:
172189
),
173190
failed_family_samples=hl.Struct(
174191
missing_samples=(
175-
{
176-
f.family_guid: {
177-
'samples': sorted(f.samples.keys()),
178-
'reasons': reasons,
179-
}
180-
for f, reasons in families_failed_missing_samples.items()
181-
}
192+
format_failures(families_failed_missing_samples)
182193
or hl.empty_dict(hl.tstr, hl.tdict(hl.tstr, hl.tarray(hl.tstr)))
183194
),
184195
relatedness_check=(
185-
{
186-
f.family_guid: {
187-
'samples': sorted(f.samples.keys()),
188-
'reasons': reasons,
189-
}
190-
for f, reasons in families_failed_relatedness_check.items()
191-
}
196+
format_failures(families_failed_relatedness_check)
192197
or hl.empty_dict(hl.tstr, hl.tdict(hl.tstr, hl.tarray(hl.tstr)))
193198
),
194199
sex_check=(
195-
{
196-
f.family_guid: {
197-
'samples': sorted(f.samples.keys()),
198-
'reasons': reasons,
199-
}
200-
for f, reasons in families_failed_sex_check.items()
201-
}
200+
format_failures(families_failed_sex_check)
202201
or hl.empty_dict(hl.tstr, hl.tdict(hl.tstr, hl.tarray(hl.tstr)))
203202
),
204203
),

0 commit comments

Comments
 (0)