Skip to content

Commit 2bef2ff

Browse files
authored
Validation refactoring (#904)
* make run id shared
* A fix
* another test
* A few more
* Few more
* Unnecessary
* ruf
* run id
* string
* unused
* missed one
* Fix it correctly
* missed one
* Last one!
* Validation refactoring
* Fix it
* case this function better
* missing comma
* change arg name
* Moving
* PR comments
1 parent e019eb4 commit 2bef2ff

File tree

6 files changed

+104
-56
lines changed

6 files changed

+104
-56
lines changed

v03_pipeline/lib/misc/callsets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_callset_ht(
3030
return callset_ht.distinct()
3131

3232

33-
def additional_row_fields(
33+
def get_additional_row_fields(
3434
mt: hl.MatrixTable,
3535
dataset_type: DatasetType,
3636
skip_check_sex_and_relatedness: bool,

v03_pipeline/lib/misc/validation.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
1+
from typing import Any
2+
13
import hail as hl
24

3-
from v03_pipeline.lib.model import DatasetType, ReferenceGenome, SampleType, Sex
5+
from v03_pipeline.lib.model import (
6+
CachedReferenceDatasetQuery,
7+
DatasetType,
8+
Env,
9+
ReferenceGenome,
10+
SampleType,
11+
Sex,
12+
)
13+
from v03_pipeline.lib.paths import (
14+
cached_reference_dataset_query_path,
15+
sex_check_table_path,
16+
)
417

518
AMBIGUOUS_THRESHOLD_PERC: float = 0.01 # Fraction of samples identified as "ambiguous_sex" above which an error will be thrown.
619
MIN_ROWS_PER_CONTIG = 100
@@ -11,9 +24,40 @@ class SeqrValidationError(Exception):
1124
pass
1225

1326

27+
def get_validation_dependencies(
    dataset_type: DatasetType,
    reference_genome: ReferenceGenome,
    callset_path: str,
    skip_check_sex_and_relatedness: bool,
    **_: Any,
) -> dict[str, hl.Table]:
    """Read the Hail tables consumed by the callset validation functions.

    The returned dict always contains ``'coding_and_noncoding_variants_ht'``.
    It additionally contains ``'sex_check_ht'`` only when
    ``Env.CHECK_SEX_AND_RELATEDNESS`` is truthy, the dataset type supports
    sex/relatedness checking, and the run has not asked to skip that check.

    Any extra keyword arguments are accepted and ignored, so a caller may
    splat a larger keyword set into this function.
    """
    coding_and_noncoding_path = cached_reference_dataset_query_path(
        reference_genome,
        dataset_type,
        CachedReferenceDatasetQuery.GNOMAD_CODING_AND_NONCODING_VARIANTS,
    )
    validation_tables = {
        'coding_and_noncoding_variants_ht': hl.read_table(
            coding_and_noncoding_path,
        ),
    }

    # Same three-way gate (and evaluation order) used elsewhere in the pipeline
    # to decide whether the sex check table is expected to exist for this run.
    sex_check_enabled = (
        Env.CHECK_SEX_AND_RELATEDNESS
        and dataset_type.check_sex_and_relatedness
        and not skip_check_sex_and_relatedness
    )
    if not sex_check_enabled:
        return validation_tables

    sex_check_path = sex_check_table_path(
        reference_genome,
        dataset_type,
        callset_path,
    )
    validation_tables['sex_check_ht'] = hl.read_table(sex_check_path)
    return validation_tables
55+
56+
1457
def validate_allele_type(
1558
mt: hl.MatrixTable,
1659
dataset_type: DatasetType,
60+
**_: Any,
1761
) -> None:
1862
ht = mt.rows()
1963
ht = ht.filter(
@@ -31,6 +75,7 @@ def validate_allele_type(
3175

3276
def validate_no_duplicate_variants(
3377
mt: hl.MatrixTable,
78+
**_: Any,
3479
) -> None:
3580
ht = mt.rows()
3681
ht = ht.group_by(*ht.key).aggregate(n=hl.agg.count())
@@ -44,6 +89,7 @@ def validate_expected_contig_frequency(
4489
mt: hl.MatrixTable,
4590
reference_genome: ReferenceGenome,
4691
min_rows_per_contig: int = MIN_ROWS_PER_CONTIG,
92+
**_: Any,
4793
) -> None:
4894
rows_per_contig = mt.aggregate_rows(hl.agg.counter(mt.locus.contig))
4995
missing_contigs = (
@@ -69,6 +115,7 @@ def validate_imported_field_types(
69115
mt: hl.MatrixTable,
70116
dataset_type: DatasetType,
71117
additional_row_fields: dict[str, hl.expr.types.HailType | set],
118+
**_: Any,
72119
) -> None:
73120
def _validate_field(
74121
mt_schema: hl.StructExpression,
@@ -104,8 +151,12 @@ def _validate_field(
104151

105152
def validate_imputed_sex_ploidy(
106153
mt: hl.MatrixTable,
107-
sex_check_ht: hl.Table,
154+
# NB: sex_check_ht is not passed (and so defaults to None) when sex checking is disabled for the run
155+
sex_check_ht: hl.Table | None = None,
156+
**_: Any,
108157
) -> None:
158+
if not sex_check_ht:
159+
return
109160
mt = mt.select_cols(
110161
discrepant=(
111162
(
@@ -132,6 +183,7 @@ def validate_sample_type(
132183
reference_genome: ReferenceGenome,
133184
sample_type: SampleType,
134185
sample_type_match_threshold: float = SAMPLE_TYPE_MATCH_THRESHOLD,
186+
**_: Any,
135187
) -> None:
136188
coding_variants_ht = coding_and_noncoding_variants_ht.filter(
137189
coding_and_noncoding_variants_ht.coding,

v03_pipeline/lib/misc/validation_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from unittest.mock import Mock, patch
23

34
import hail as hl
45

@@ -80,7 +81,9 @@ def test_validate_allele_type(self) -> None:
8081
DatasetType.SNV_INDEL,
8182
)
8283

83-
def test_validate_imputed_sex_ploidy(self) -> None:
84+
@patch('v03_pipeline.lib.misc.validation.Env')
85+
def test_validate_imputed_sex_ploidy(self, mock_env: Mock) -> None:
86+
mock_env.CHECK_SEX_AND_RELATEDNESS = True
8487
sex_check_ht = hl.read_table(TEST_SEX_CHECK_1)
8588
mt = hl.MatrixTable.from_parts(
8689
rows={

v03_pipeline/lib/tasks/validate_callset.py

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@
22
import luigi
33
import luigi.util
44

5-
from v03_pipeline.lib.misc.callsets import additional_row_fields
65
from v03_pipeline.lib.misc.validation import (
6+
get_validation_dependencies,
77
validate_allele_type,
88
validate_expected_contig_frequency,
9-
validate_imported_field_types,
109
validate_imputed_sex_ploidy,
1110
validate_no_duplicate_variants,
1211
validate_sample_type,
1312
)
1413
from v03_pipeline.lib.model import CachedReferenceDatasetQuery
1514
from v03_pipeline.lib.model.environment import Env
1615
from v03_pipeline.lib.paths import (
17-
cached_reference_dataset_query_path,
1816
imported_callset_path,
19-
sex_check_table_path,
2017
)
2118
from v03_pipeline.lib.tasks.base.base_loading_run_params import BaseLoadingRunParams
2219
from v03_pipeline.lib.tasks.base.base_update import BaseUpdateTask
@@ -63,8 +60,8 @@ def requires(self) -> list[luigi.Task]:
6360
]
6461
if (
6562
Env.CHECK_SEX_AND_RELATEDNESS
66-
and not self.skip_check_sex_and_relatedness
6763
and self.dataset_type.check_sex_and_relatedness
64+
and not self.skip_check_sex_and_relatedness
6865
):
6966
requirements = [
7067
*requirements,
@@ -83,17 +80,6 @@ def update_table(self, mt: hl.MatrixTable) -> hl.MatrixTable:
8380
self.callset_path,
8481
),
8582
)
86-
# This validation isn't override-able. If a field is the wrong
87-
# type, the pipeline will likely hard-fail downstream.
88-
validate_imported_field_types(
89-
mt,
90-
self.dataset_type,
91-
additional_row_fields(
92-
mt,
93-
self.dataset_type,
94-
self.skip_check_sex_and_relatedness,
95-
),
96-
)
9783
if self.dataset_type.can_run_validation:
9884
# Rather than throwing an error, we silently remove invalid contigs.
9985
# This happens fairly often for AnVIL requests.
@@ -104,38 +90,34 @@ def update_table(self, mt: hl.MatrixTable) -> hl.MatrixTable:
10490
)
10591

10692
if not self.skip_validation and self.dataset_type.can_run_validation:
107-
validate_allele_type(mt, self.dataset_type)
108-
validate_no_duplicate_variants(mt)
109-
validate_expected_contig_frequency(mt, self.reference_genome)
110-
coding_and_noncoding_ht = hl.read_table(
111-
cached_reference_dataset_query_path(
112-
self.reference_genome,
113-
self.dataset_type,
114-
CachedReferenceDatasetQuery.GNOMAD_CODING_AND_NONCODING_VARIANTS,
115-
),
93+
validation_dependencies = get_validation_dependencies(
94+
**self.param_kwargs,
95+
)
96+
validate_allele_type(
97+
mt,
98+
**self.param_kwargs,
99+
**validation_dependencies,
100+
)
101+
validate_no_duplicate_variants(
102+
mt,
103+
**self.param_kwargs,
104+
**validation_dependencies,
105+
)
106+
validate_expected_contig_frequency(
107+
mt,
108+
**self.param_kwargs,
109+
**validation_dependencies,
116110
)
117111
validate_sample_type(
118112
mt,
119-
coding_and_noncoding_ht,
120-
self.reference_genome,
121-
self.sample_type,
113+
**self.param_kwargs,
114+
**validation_dependencies,
115+
)
116+
validate_imputed_sex_ploidy(
117+
mt,
118+
**self.param_kwargs,
119+
**validation_dependencies,
122120
)
123-
if (
124-
Env.CHECK_SEX_AND_RELATEDNESS
125-
and not self.skip_check_sex_and_relatedness
126-
and self.dataset_type.check_sex_and_relatedness
127-
):
128-
sex_check_ht = hl.read_table(
129-
sex_check_table_path(
130-
self.reference_genome,
131-
self.dataset_type,
132-
self.callset_path,
133-
),
134-
)
135-
validate_imputed_sex_ploidy(
136-
mt,
137-
sex_check_ht,
138-
)
139121
return mt.select_globals(
140122
callset_path=self.callset_path,
141123
validated_sample_type=self.sample_type.value,

v03_pipeline/lib/tasks/write_imported_callset.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import luigi
33
import luigi.util
44

5-
from v03_pipeline.lib.misc.callsets import additional_row_fields
5+
from v03_pipeline.lib.misc.callsets import get_additional_row_fields
66
from v03_pipeline.lib.misc.io import (
77
import_callset,
88
import_vcf,
99
select_relevant_fields,
1010
split_multi_hts,
1111
)
12+
from v03_pipeline.lib.misc.validation import (
13+
validate_imported_field_types,
14+
)
1215
from v03_pipeline.lib.misc.vets import annotate_vets
1316
from v03_pipeline.lib.model.environment import Env
1417
from v03_pipeline.lib.paths import (
@@ -79,14 +82,22 @@ def create_table(self) -> hl.MatrixTable:
7982
)
8083
filters_ht = import_vcf(filters_path, self.reference_genome).rows()
8184
mt = mt.annotate_rows(filters=filters_ht[mt.row_key].filters)
85+
additional_row_fields = get_additional_row_fields(
86+
mt,
87+
self.dataset_type,
88+
self.skip_check_sex_and_relatedness,
89+
)
8290
mt = select_relevant_fields(
8391
mt,
8492
self.dataset_type,
85-
additional_row_fields(
86-
mt,
87-
self.dataset_type,
88-
self.skip_check_sex_and_relatedness,
89-
),
93+
additional_row_fields,
94+
)
95+
# This validation isn't override-able by the skip option.
96+
# If a field is the wrong type, the pipeline will likely hard-fail downstream.
97+
validate_imported_field_types(
98+
mt,
99+
self.dataset_type,
100+
additional_row_fields,
90101
)
91102
if self.dataset_type.has_multi_allelic_variants:
92103
mt = split_multi_hts(mt)

v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def requires(self) -> list[luigi.Task]:
6262
]
6363
if (
6464
Env.CHECK_SEX_AND_RELATEDNESS
65-
and not self.skip_check_sex_and_relatedness
6665
and self.dataset_type.check_sex_and_relatedness
66+
and not self.skip_check_sex_and_relatedness
6767
):
6868
requirements = [
6969
*requirements,
@@ -98,8 +98,8 @@ def create_table(self) -> hl.MatrixTable:
9898
families_failed_sex_check = {}
9999
if (
100100
Env.CHECK_SEX_AND_RELATEDNESS
101-
and not self.skip_check_sex_and_relatedness
102101
and self.dataset_type.check_sex_and_relatedness
102+
and not self.skip_check_sex_and_relatedness
103103
):
104104
relatedness_check_ht = hl.read_table(self.input()[2].path)
105105
sex_check_ht = hl.read_table(self.input()[3].path)

0 commit comments

Comments
 (0)