diff --git a/requirements.txt b/requirements.txt
index 86e9f058b..96f6fb725 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --resolver=backtracking requirements.in
+#    pip-compile requirements.in
 #
 aiodns==2.0.0
     # via hail
diff --git a/v03_pipeline/bin/pipeline_worker.py b/v03_pipeline/bin/pipeline_worker.py
index 7586fe8f6..7be97a9b0 100755
--- a/v03_pipeline/bin/pipeline_worker.py
+++ b/v03_pipeline/bin/pipeline_worker.py
@@ -7,7 +7,7 @@
 from v03_pipeline.api.model import LoadingPipelineRequest
 from v03_pipeline.lib.logger import get_logger
-from v03_pipeline.lib.model import Env
+from v03_pipeline.lib.model import FeatureFlag
 from v03_pipeline.lib.paths import (
     loading_pipeline_queue_path,
     project_pedigree_path,
@@ -54,7 +54,7 @@ def main():
                 'run_id': run_id,
                 **{k: v for k, v in lpr.model_dump().items() if k != 'projects_to_run'},
             }
-            if Env.SHOULD_TRIGGER_HAIL_BACKEND_RELOAD:
+            if FeatureFlag.SHOULD_TRIGGER_HAIL_BACKEND_RELOAD:
                 tasks = [
                     TriggerHailBackendReload(**loading_run_task_params),
                 ]
diff --git a/v03_pipeline/lib/misc/family_loading_failures.py b/v03_pipeline/lib/misc/family_loading_failures.py
index e3b1b59db..438b1706e 100644
--- a/v03_pipeline/lib/misc/family_loading_failures.py
+++ b/v03_pipeline/lib/misc/family_loading_failures.py
@@ -16,7 +16,8 @@ def passes_relatedness_check(
     relatedness_check_lookup: dict[tuple[str, str], list],
     sample_id: str,
     other_id: str,
-    relation: Relation,
+    expected_relation: Relation,
+    additional_allowed_relation: Relation | None,
 ) -> tuple[bool, str | None]:
     # No relationship to check, return true
     if other_id is None:
@@ -24,85 +25,62 @@
     coefficients = relatedness_check_lookup.get(
         (min(sample_id, other_id), max(sample_id, other_id)),
     )
-    if not coefficients or not np.allclose(
-        coefficients,
-        relation.coefficients,
-        atol=RELATEDNESS_TOLERANCE,
+    if not coefficients or not any(
+        np.allclose(
+            coefficients,
+            relation.coefficients,
+            atol=RELATEDNESS_TOLERANCE,
+        )
+        for relation in (
+            [expected_relation, additional_allowed_relation]
+            if additional_allowed_relation
+            else [expected_relation]
+        )
     ):
         return (
             False,
-            f'Sample {sample_id} has expected relation "{relation.value}" to {other_id} but has coefficients {coefficients or []}',
+            f'Sample {sample_id} has expected relation "{expected_relation.value}" to {other_id} but has coefficients {coefficients or []}',
         )
     return True, None
 
 
-def all_relatedness_checks(  # noqa: C901
+def all_relatedness_checks(
     relatedness_check_lookup: dict[tuple[str, str], list],
+    family: Family,
     sample: Sample,
 ) -> list[str]:
     failure_reasons = []
-    for parent_id in [sample.mother, sample.father]:
-        success, reason = passes_relatedness_check(
-            relatedness_check_lookup,
-            sample.sample_id,
-            parent_id,
-            Relation.PARENT,
-        )
-        if not success:
-            failure_reasons.append(reason)
-
-    for grandparent_id in [
-        sample.maternal_grandmother,
-        sample.maternal_grandfather,
-        sample.paternal_grandmother,
-        sample.paternal_grandfather,
+    for relationship_set, relation, additional_allowed_relation in [
+        ([sample.mother, sample.father], Relation.PARENT_CHILD, None),
+        (
+            [
+                sample.maternal_grandmother,
+                sample.maternal_grandfather,
+                sample.paternal_grandmother,
+                sample.paternal_grandfather,
+            ],
+            Relation.GRANDPARENT_GRANDCHILD,
+            None,
+        ),
+        (sample.siblings, Relation.SIBLING, None),
+        (sample.half_siblings, Relation.HALF_SIBLING, Relation.SIBLING),
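+        # NB: a "half sibling" parsed from the pedigree may actually be a
+        # full sibling, so the SIBLING coefficients are also accepted for
+        # HALF_SIBLING rows via additional_allowed_relation.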
+        (sample.aunt_nephews, Relation.AUNT_NEPHEW, None),
     ]:
-        success, reason = passes_relatedness_check(
-            relatedness_check_lookup,
-            sample.sample_id,
-            grandparent_id,
-            Relation.GRANDPARENT,
-        )
-        if not success:
-            failure_reasons.append(reason)
-
-    for sibling_id in sample.siblings:
-        success, reason = passes_relatedness_check(
-            relatedness_check_lookup,
-            sample.sample_id,
-            sibling_id,
-            Relation.SIBLING,
-        )
-        if not success:
-            failure_reasons.append(reason)
-
-    for half_sibling_id in sample.half_siblings:
-        # NB: A "half sibling" parsed from the pedigree may actually be a sibling, so we allow those
-        # through as well.
-        success1, _ = passes_relatedness_check(
-            relatedness_check_lookup,
-            sample.sample_id,
-            half_sibling_id,
-            Relation.SIBLING,
-        )
-        success2, reason = passes_relatedness_check(
-            relatedness_check_lookup,
-            sample.sample_id,
-            half_sibling_id,
-            Relation.HALF_SIBLING,
-        )
-        if not success1 and not success2:
-            failure_reasons.append(reason)
-
-    for aunt_nephew_id in sample.aunt_nephews:
-        success, reason = passes_relatedness_check(
-            relatedness_check_lookup,
-            sample.sample_id,
-            aunt_nephew_id,
-            Relation.AUNT_NEPHEW,
-        )
-        if not success:
-            failure_reasons.append(reason)
+        for other_id in relationship_set:
+            # Handle case where relation is identified in the
+            # pedigree as a "dummy" but is not included in
+            # the list of samples to load.
+            if other_id not in family.samples:
+                continue
+            success, reason = passes_relatedness_check(
+                relatedness_check_lookup,
+                sample.sample_id,
+                other_id,
+                relation,
+                additional_allowed_relation,
+            )
+            if not success:
+                failure_reasons.append(reason)
     return failure_reasons
@@ -162,6 +140,7 @@ def get_families_failed_relatedness_check(
     for sample in family.samples.values():
         failure_reasons = all_relatedness_checks(
             relatedness_check_lookup,
+            family,
             sample,
         )
         if failure_reasons:
diff --git a/v03_pipeline/lib/misc/family_loading_failures_test.py b/v03_pipeline/lib/misc/family_loading_failures_test.py
index bcd783f5f..c9b1858b2 100644
--- a/v03_pipeline/lib/misc/family_loading_failures_test.py
+++ b/v03_pipeline/lib/misc/family_loading_failures_test.py
@@ -9,7 +9,7 @@
     get_families_failed_sex_check,
 )
 from v03_pipeline.lib.misc.io import import_pedigree
-from v03_pipeline.lib.misc.pedigree import Sample, parse_pedigree_ht_to_families
+from v03_pipeline.lib.misc.pedigree import Family, Sample, parse_pedigree_ht_to_families
 from v03_pipeline.lib.model import Sex
 
 TEST_PEDIGREE_6 = 'v03_pipeline/var/test/pedigrees/test_pedigree_6.tsv'
@@ -104,7 +104,21 @@ def test_all_relatedness_checks(self):
             paternal_grandfather='sample_3',
             half_siblings=['sample_4'],
         )
-        failure_reasons = all_relatedness_checks(relatedness_check_lookup, sample)
+        family = Family(
+            family_guid='family_1a',
+            samples={
+                'sample_1': sample,
+                'sample_2': Sample(sex=Sex.MALE, sample_id='sample_2'),
+                'sample_3': Sample(sex=Sex.MALE, sample_id='sample_3'),
+                'sample_4': Sample(sex=Sex.MALE, sample_id='sample_4'),
+                'sample_5': Sample(sex=Sex.MALE, sample_id='sample_5'),
+            },
+        )
+        failure_reasons = all_relatedness_checks(
+            relatedness_check_lookup,
+            family,
+            sample,
+        )
         self.assertListEqual(failure_reasons, [])
 
         # Defined grandparent missing in relatedness table
@@ -117,12 +131,13 @@
         )
         failure_reasons = all_relatedness_checks(
             relatedness_check_lookup,
+            family,
             sample,
         )
         self.assertListEqual(
             failure_reasons,
             [
-                'Sample sample_1 has expected relation "grandparent" to sample_5 but has coefficients []',
+                'Sample sample_1 has expected relation "grandparent_grandchild" to sample_5 but has coefficients []',
             ],
         )
@@ -140,6 +155,7 @@
         )
         failure_reasons = all_relatedness_checks(
             relatedness_check_lookup,
+            family,
             sample,
         )
         self.assertListEqual(
@@ -167,16 +183,42 @@
         )
         failure_reasons = all_relatedness_checks(
             relatedness_check_lookup,
+            family,
             sample,
         )
         self.assertListEqual(
             failure_reasons,
             [
-                'Sample sample_1 has expected relation "parent" to sample_2 but has coefficients [0.5, 0.5, 0.5, 0.5]',
+                'Sample sample_1 has expected relation "parent_child" to sample_2 but has coefficients [0.5, 0.5, 0.5, 0.5]',
                 'Sample sample_1 has expected relation "sibling" to sample_4 but has coefficients [0.5, 0.5, 0, 0.25]',
             ],
         )
 
+        # Some samples will include relationships with
+        # samples that are not expected to be included
+        # in the callset.  These should not trigger relatedness
+        # failures.
+        sample = Sample(
+            sex=Sex.FEMALE,
+            sample_id='sample_1',
+            mother='sample_2',
+        )
+        family = Family(
+            family_guid='family_1a',
+            samples={
+                'sample_1': sample,
+            },
+        )
+        failure_reasons = all_relatedness_checks(
+            {},
+            family,
+            sample,
+        )
+        self.assertListEqual(
+            failure_reasons,
+            [],
+        )
+
     def test_get_families_failed_sex_check(self):
         sex_check_ht = hl.Table.parallelize(
             [
diff --git a/v03_pipeline/lib/misc/pedigree.py b/v03_pipeline/lib/misc/pedigree.py
index ee2d86521..2e17f27f0 100644
--- a/v03_pipeline/lib/misc/pedigree.py
+++ b/v03_pipeline/lib/misc/pedigree.py
@@ -8,8 +8,8 @@
 
 
 class Relation(Enum):
-    PARENT = 'parent'
-    GRANDPARENT = 'grandparent'
+    PARENT_CHILD = 'parent_child'
+    GRANDPARENT_GRANDCHILD = 'grandparent_grandchild'
     SIBLING = 'sibling'
     HALF_SIBLING = 'half_sibling'
     AUNT_NEPHEW = 'aunt_nephew'
@@ -17,8 +17,8 @@ class Relation(Enum):
     @property
     def coefficients(self):
         return {
-            Relation.PARENT: [0, 1, 0, 0.5],
-            Relation.GRANDPARENT: [0.5, 0.5, 0, 0.25],
+            Relation.PARENT_CHILD: [0, 1, 0, 0.5],
+            Relation.GRANDPARENT_GRANDCHILD: [0.5, 0.5, 0, 0.25],
             Relation.SIBLING: [0.25, 0.5, 0.25, 0.5],
             Relation.HALF_SIBLING: [0.5, 0.5, 0, 0.25],
             Relation.AUNT_NEPHEW: [0.5, 0.5, 0, 0.25],
diff --git a/v03_pipeline/lib/misc/sample_ids.py b/v03_pipeline/lib/misc/sample_ids.py
index 77f173ffc..7a1c2bc0b 100644
--- a/v03_pipeline/lib/misc/sample_ids.py
+++ b/v03_pipeline/lib/misc/sample_ids.py
@@ -3,28 +3,16 @@
 import hail as hl
 
 from v03_pipeline.lib.logger import get_logger
+from v03_pipeline.lib.misc.validation import SeqrValidationError
 
 logger = get_logger(__name__)
 
 
-class MatrixTableSampleSetError(Exception):
-    def __init__(self, message, missing_samples):
-        super().__init__(message)
-        self.missing_samples = missing_samples
-
-
-def vcf_remap(mt: hl.MatrixTable) -> hl.MatrixTable:
-    # TODO: add logic from Mike to remap vcf samples delivered from Broad WGS
-    return mt
-
-
 def remap_sample_ids(
     mt: hl.MatrixTable,
     project_remap_ht: hl.Table,
     ignore_missing_samples_when_remapping: bool,
 ) -> hl.MatrixTable:
-    mt = vcf_remap(mt)
-
     collected_remap = project_remap_ht.collect()
     s_dups = [k for k, v in Counter([r.s for r in collected_remap]).items() if v > 1]
     seqr_dups = [
@@ -33,7 +21,7 @@
 
     if len(s_dups) > 0 or len(seqr_dups) > 0:
         msg = f'Duplicate s or seqr_id entries in remap file were found. Duplicate s:{s_dups}. Duplicate seqr_id:{seqr_dups}.'
-        raise ValueError(msg)
+        raise SeqrValidationError(msg)
 
     missing_samples = project_remap_ht.anti_join(mt.cols()).collect()
     remap_count = len(collected_remap)
@@ -48,7 +36,7 @@
         if ignore_missing_samples_when_remapping:
             logger.info(message)
         else:
-            raise MatrixTableSampleSetError(message, missing_samples)
+            raise SeqrValidationError(message)
 
     mt = mt.annotate_cols(**project_remap_ht[mt.s])
     remap_expr = hl.if_else(hl.is_missing(mt.seqr_id), mt.s, mt.seqr_id)
@@ -67,7 +55,7 @@
     anti_join_ht_count = anti_join_ht.count()
     if subset_count == 0:
         message = '0 sample ids found the subset HT, something is probably wrong.'
-        raise MatrixTableSampleSetError(message, [])
+        raise SeqrValidationError(message)
 
     if anti_join_ht_count != 0:
         missing_samples = anti_join_ht.s.collect()
@@ -77,7 +65,7 @@
             f"IDs that aren't in the callset: {missing_samples}\n"
             f'All callset sample IDs:{mt.s.collect()}'
         )
-        raise MatrixTableSampleSetError(message, missing_samples)
+        raise SeqrValidationError(message)
     logger.info(f'Subsetted to {subset_count} sample ids')
     mt = mt.semi_join_cols(sample_subset_ht)
     return mt.filter_rows(hl.agg.any(hl.is_defined(mt.GT)))
diff --git a/v03_pipeline/lib/misc/sample_ids_test.py b/v03_pipeline/lib/misc/sample_ids_test.py
index db264a1f9..770f5451e 100644
--- a/v03_pipeline/lib/misc/sample_ids_test.py
+++ b/v03_pipeline/lib/misc/sample_ids_test.py
@@ -3,10 +3,10 @@
 import hail as hl
 
 from v03_pipeline.lib.misc.sample_ids import (
-    MatrixTableSampleSetError,
     remap_sample_ids,
     subset_samples,
 )
+from v03_pipeline.lib.misc.validation import SeqrValidationError
 
 CALLSET_MT = hl.MatrixTable.from_parts(
     rows={'variants': [1, 2]},
@@ -76,7 +76,7 @@ def test_remap_sample_ids_remap_has_duplicate(self) -> None:
             key='s',
         )
 
-        with self.assertRaises(ValueError):
+        with self.assertRaises(SeqrValidationError):
             remap_sample_ids(
                 CALLSET_MT,
                 project_remap_ht,
@@ -99,7 +99,7 @@ def test_remap_sample_ids_remap_has_missing_samples(self) -> None:
             key='s',
         )
 
-        with self.assertRaises(MatrixTableSampleSetError):
+        with self.assertRaises(SeqrValidationError):
             remap_sample_ids(
                 CALLSET_MT,
                 project_remap_ht,
@@ -114,7 +114,7 @@ def test_subset_samples_zero_samples(self):
             key='s',
         )
 
-        with self.assertRaises(MatrixTableSampleSetError):
+        with self.assertRaises(SeqrValidationError):
             subset_samples(
                 CALLSET_MT,
                 sample_subset_ht,
@@ -132,7 +132,7 @@
             key='s',
         )
 
-        with self.assertRaises(MatrixTableSampleSetError):
+        with self.assertRaises(SeqrValidationError):
             subset_samples(
                 CALLSET_MT,
                 sample_subset_ht,
diff --git a/v03_pipeline/lib/model/__init__.py b/v03_pipeline/lib/model/__init__.py
index bb4325d47..45847f4f5 100644
--- a/v03_pipeline/lib/model/__init__.py
+++ b/v03_pipeline/lib/model/__init__.py
@@ -7,11 +7,13 @@
     Sex,
 )
 from v03_pipeline.lib.model.environment import Env
+from v03_pipeline.lib.model.feature_flag import FeatureFlag
 
 __all__ = [
     'AccessControl',
     'DatasetType',
     'Env',
+    'FeatureFlag',
     'Sex',
     'PipelineVersion',
     'ReferenceGenome',
diff --git a/v03_pipeline/lib/model/environment.py b/v03_pipeline/lib/model/environment.py
index e307b0226..42ae4c658 100644
--- a/v03_pipeline/lib/model/environment.py
+++ b/v03_pipeline/lib/model/environment.py
@@ -51,30 +51,12 @@
 GCLOUD_REGION = os.environ.get('GCLOUD_REGION')
 PIPELINE_RUNNER_APP_VERSION = os.environ.get('PIPELINE_RUNNER_APP_VERSION', 'latest')
 
-# Feature Flags
-ACCESS_PRIVATE_REFERENCE_DATASETS = (
-    os.environ.get('ACCESS_PRIVATE_REFERENCE_DATASETS') == '1'
-)
-CHECK_SEX_AND_RELATEDNESS = os.environ.get('CHECK_SEX_AND_RELATEDNESS') == '1'
-EXPECT_TDR_METRICS = os.environ.get('EXPECT_TDR_METRICS') == '1'
-EXPECT_WES_FILTERS = os.environ.get('EXPECT_WES_FILTERS') == '1'
-INCLUDE_PIPELINE_VERSION_IN_PREFIX = (
-    os.environ.get('INCLUDE_PIPELINE_VERSION_IN_PREFIX') == '1'
-)
-SHOULD_TRIGGER_HAIL_BACKEND_RELOAD = (
-    os.environ.get('SHOULD_TRIGGER_HAIL_BACKEND_RELOAD') == '1'
-)
-
 
 @dataclass
 class Env:
-    ACCESS_PRIVATE_REFERENCE_DATASETS: bool = ACCESS_PRIVATE_REFERENCE_DATASETS
-    CHECK_SEX_AND_RELATEDNESS: bool = CHECK_SEX_AND_RELATEDNESS
     CLINGEN_ALLELE_REGISTRY_LOGIN: str | None = CLINGEN_ALLELE_REGISTRY_LOGIN
     CLINGEN_ALLELE_REGISTRY_PASSWORD: str | None = CLINGEN_ALLELE_REGISTRY_PASSWORD
     DEPLOYMENT_TYPE: Literal['dev', 'prod'] = DEPLOYMENT_TYPE
-    EXPECT_TDR_METRICS: bool = EXPECT_TDR_METRICS
-    EXPECT_WES_FILTERS: bool = EXPECT_WES_FILTERS
     GCLOUD_DATAPROC_SECONDARY_WORKERS: str = GCLOUD_DATAPROC_SECONDARY_WORKERS
     GCLOUD_PROJECT: str | None = GCLOUD_PROJECT
     GCLOUD_ZONE: str | None = GCLOUD_ZONE
@@ -85,10 +67,8 @@ class Env:
     HAIL_BACKEND_SERVICE_PORT: int = HAIL_BACKEND_SERVICE_PORT
     HAIL_TMP_DIR: str = HAIL_TMP_DIR
     HAIL_SEARCH_DATA_DIR: str = HAIL_SEARCH_DATA_DIR
-    INCLUDE_PIPELINE_VERSION_IN_PREFIX: bool = INCLUDE_PIPELINE_VERSION_IN_PREFIX
     LOADING_DATASETS_DIR: str = LOADING_DATASETS_DIR
     PIPELINE_RUNNER_APP_VERSION: str = PIPELINE_RUNNER_APP_VERSION
     PRIVATE_REFERENCE_DATASETS_DIR: str = PRIVATE_REFERENCE_DATASETS_DIR
     REFERENCE_DATASETS_DIR: str = REFERENCE_DATASETS_DIR
-    SHOULD_TRIGGER_HAIL_BACKEND_RELOAD: bool = SHOULD_TRIGGER_HAIL_BACKEND_RELOAD
     VEP_REFERENCE_DATASETS_DIR: str = VEP_REFERENCE_DATASETS_DIR
diff --git a/v03_pipeline/lib/model/feature_flag.py b/v03_pipeline/lib/model/feature_flag.py
new file mode 100644
index 000000000..0b57e8643
--- /dev/null
+++ b/v03_pipeline/lib/model/feature_flag.py
@@ -0,0 +1,26 @@
+import os
+from dataclasses import dataclass
+
+# Feature Flags
+ACCESS_PRIVATE_REFERENCE_DATASETS = (
+    os.environ.get('ACCESS_PRIVATE_REFERENCE_DATASETS') == '1'
+)
+CHECK_SEX_AND_RELATEDNESS = os.environ.get('CHECK_SEX_AND_RELATEDNESS') == '1'
+EXPECT_TDR_METRICS = os.environ.get('EXPECT_TDR_METRICS') == '1'
+EXPECT_WES_FILTERS = os.environ.get('EXPECT_WES_FILTERS') == '1'
+INCLUDE_PIPELINE_VERSION_IN_PREFIX = (
+    os.environ.get('INCLUDE_PIPELINE_VERSION_IN_PREFIX') == '1'
+)
+SHOULD_TRIGGER_HAIL_BACKEND_RELOAD = (
+    os.environ.get('SHOULD_TRIGGER_HAIL_BACKEND_RELOAD') == '1'
+)
+
+
+@dataclass
+class FeatureFlag:
+    ACCESS_PRIVATE_REFERENCE_DATASETS: bool = ACCESS_PRIVATE_REFERENCE_DATASETS
+    CHECK_SEX_AND_RELATEDNESS: bool = CHECK_SEX_AND_RELATEDNESS
+    EXPECT_TDR_METRICS: bool = EXPECT_TDR_METRICS
+    EXPECT_WES_FILTERS: bool = EXPECT_WES_FILTERS
+    INCLUDE_PIPELINE_VERSION_IN_PREFIX: bool = INCLUDE_PIPELINE_VERSION_IN_PREFIX
+    SHOULD_TRIGGER_HAIL_BACKEND_RELOAD: bool = SHOULD_TRIGGER_HAIL_BACKEND_RELOAD
diff --git a/v03_pipeline/lib/paths.py b/v03_pipeline/lib/paths.py
index 5a54880fd..2942b1dae 100644
--- a/v03_pipeline/lib/paths.py
+++ b/v03_pipeline/lib/paths.py
@@ -6,6 +6,7 @@
     AccessControl,
     DatasetType,
     Env,
+    FeatureFlag,
     PipelineVersion,
     ReferenceGenome,
     SampleType,
@@ -21,7 +22,7 @@ def _pipeline_prefix(
     reference_genome: ReferenceGenome,
     dataset_type: DatasetType,
 ) -> str:
-    if Env.INCLUDE_PIPELINE_VERSION_IN_PREFIX:
+    if FeatureFlag.INCLUDE_PIPELINE_VERSION_IN_PREFIX:
         return os.path.join(
             root,
             PipelineVersion.V3_1.value,
@@ -45,7 +46,7 @@ def _v03_reference_data_prefix(
         if access_control == AccessControl.PRIVATE
         else Env.REFERENCE_DATASETS_DIR
     )
-    if Env.INCLUDE_PIPELINE_VERSION_IN_PREFIX:
+    if FeatureFlag.INCLUDE_PIPELINE_VERSION_IN_PREFIX:
         return os.path.join(
             root,
             PipelineVersion.V03.value,
@@ -68,7 +69,7 @@ def _v03_reference_dataset_prefix(
         if access_control == AccessControl.PRIVATE
         else Env.REFERENCE_DATASETS_DIR
     )
-    if Env.INCLUDE_PIPELINE_VERSION_IN_PREFIX:
+    if FeatureFlag.INCLUDE_PIPELINE_VERSION_IN_PREFIX:
         return os.path.join(
             root,
             PipelineVersion.V3_1.value,
@@ -287,7 +288,7 @@ def valid_filters_path(
     callset_path: str,
 ) -> str | None:
     if (
-        not Env.EXPECT_WES_FILTERS
+        not FeatureFlag.EXPECT_WES_FILTERS
         or not dataset_type.expect_filters(sample_type)
         or 'part_one_outputs' not in callset_path
     ):
diff --git a/v03_pipeline/lib/paths_test.py b/v03_pipeline/lib/paths_test.py
index 8c5f097f4..eac76232f 100644
--- a/v03_pipeline/lib/paths_test.py
+++ b/v03_pipeline/lib/paths_test.py
@@ -36,7 +36,9 @@ def test_family_table_path(self) -> None:
             ),
             '/var/seqr/seqr-hail-search-data/v3.1/GRCh37/SNV_INDEL/families/WES/franklin.ht',
         )
-        with patch('v03_pipeline.lib.paths.Env') as mock_env:
+        with patch('v03_pipeline.lib.paths.Env') as mock_env, patch(
+            'v03_pipeline.lib.paths.FeatureFlag',
+        ) as mock_ff:
             mock_env.HAIL_SEARCH_DATA_DIR = 'gs://seqr-datasets/'
             self.assertEqual(
                 family_table_path(
@@ -47,7 +49,7 @@
                 ),
                 'gs://seqr-datasets/v3.1/GRCh37/SNV_INDEL/families/WES/franklin.ht',
             )
-            mock_env.INCLUDE_PIPELINE_VERSION_IN_PREFIX = False
+            mock_ff.INCLUDE_PIPELINE_VERSION_IN_PREFIX = False
             self.assertEqual(
                 family_table_path(
                     ReferenceGenome.GRCh37,
@@ -67,8 +69,8 @@ def test_valid_filters_path(self) -> None:
             ),
             None,
         )
-        with patch('v03_pipeline.lib.paths.Env') as mock_env:
-            mock_env.EXPECT_WES_FILTERS = True
+        with patch('v03_pipeline.lib.paths.FeatureFlag') as mock_ff:
+            mock_ff.EXPECT_WES_FILTERS = True
            self.assertEqual(
                 valid_filters_path(
                     DatasetType.SNV_INDEL,
diff --git a/v03_pipeline/lib/reference_datasets/reference_dataset.py b/v03_pipeline/lib/reference_datasets/reference_dataset.py
index 47480e723..445e191aa 100644
--- a/v03_pipeline/lib/reference_datasets/reference_dataset.py
+++ b/v03_pipeline/lib/reference_datasets/reference_dataset.py
@@ -10,7 +10,12 @@
     validate_allele_type,
     validate_no_duplicate_variants,
 )
-from v03_pipeline.lib.model import AccessControl, DatasetType, Env, ReferenceGenome
+from v03_pipeline.lib.model import (
+    AccessControl,
+    DatasetType,
+    FeatureFlag,
+    ReferenceGenome,
+)
 from v03_pipeline.lib.reference_datasets import clinvar, dbnsfp
 from v03_pipeline.lib.reference_datasets.misc import (
     compress_floats,
@@ -41,7 +46,7 @@
         for dataset, config in CONFIG.items()
         if dataset_type in config.get(reference_genome, {}).get(DATASET_TYPES, [])
     ]
-    if not Env.ACCESS_PRIVATE_REFERENCE_DATASETS:
+    if not FeatureFlag.ACCESS_PRIVATE_REFERENCE_DATASETS:
         return {
             dataset
             for dataset in reference_datasets
diff --git a/v03_pipeline/lib/tasks/dataproc/base_run_job_on_dataproc.py b/v03_pipeline/lib/tasks/dataproc/base_run_job_on_dataproc.py
new file mode 100644
index 000000000..1094d651e
--- /dev/null
+++ b/v03_pipeline/lib/tasks/dataproc/base_run_job_on_dataproc.py
@@ -0,0 +1,106 @@
+import time
+
+import google.api_core.exceptions
+import luigi
+from google.cloud import dataproc_v1 as dataproc
+
+from v03_pipeline.lib.logger import get_logger
+from v03_pipeline.lib.model import Env
+from v03_pipeline.lib.tasks.base.base_loading_pipeline_params import (
+    BaseLoadingPipelineParams,
+)
+from v03_pipeline.lib.tasks.dataproc.create_dataproc_cluster import (
+    CreateDataprocClusterTask,
+)
+from v03_pipeline.lib.tasks.dataproc.misc import get_cluster_name, to_kebab_str_args
+
+DONE_STATE = 'DONE'
+ERROR_STATE = 'ERROR'
+SEQR_PIPELINE_RUNNER_BUILD = f'gs://seqr-pipeline-runner-builds/{Env.DEPLOYMENT_TYPE}/{Env.PIPELINE_RUNNER_APP_VERSION}'
+TIMEOUT_S = 172800  # 2 days
+
+logger = get_logger(__name__)
+
+
+@luigi.util.inherits(BaseLoadingPipelineParams)
+class BaseRunJobOnDataprocTask(luigi.Task):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.client = dataproc.JobControllerClient(
+            client_options={
+                'api_endpoint': f'{Env.GCLOUD_REGION}-dataproc.googleapis.com:443',
+            },
+        )
+
+    @property
+    def task_name(self):
+        return self.get_task_family().split('.')[-1]
+
+    @property
+    def job_id(self):
+        return f'{self.task_name}-{self.run_id}'
+
+    def requires(self) -> list[luigi.Task]:
+        return [self.clone(CreateDataprocClusterTask)]
+
+    def complete(self) -> bool:
+        if not self.dataset_type.requires_dataproc:
+            msg = f'{self.dataset_type} should not require a dataproc job'
+            raise RuntimeError(msg)
+        try:
+            job = self.client.get_job(
+                request={
+                    'project_id': Env.GCLOUD_PROJECT,
+                    'region': Env.GCLOUD_REGION,
+                    'job_id': self.job_id,
+                },
+            )
+        except google.api_core.exceptions.NotFound:
+            return False
+        if job.status.state == ERROR_STATE:
+            msg = f'Job {self.task_name}-{self.run_id} entered ERROR state'
+            logger.error(msg)
+            logger.error(job.status.details)
+        return job.status.state == DONE_STATE
+
+    def run(self):
+        operation = self.client.submit_job_as_operation(
+            request={
+                'project_id': Env.GCLOUD_PROJECT,
+                'region': Env.GCLOUD_REGION,
+                'job': {
+                    'reference': {
+                        'job_id': self.job_id,
+                    },
+                    'placement': {
+                        'cluster_name': get_cluster_name(
+                            self.reference_genome,
+                            self.run_id,
+                        ),
+                    },
+                    'pyspark_job': {
+                        'main_python_file_uri': f'{SEQR_PIPELINE_RUNNER_BUILD}/bin/run_task.py',
+                        'args': [
+                            self.task_name,
+                            '--local-scheduler',
+                            *to_kebab_str_args(self),
+                        ],
+                        'python_file_uris': [
+                            f'{SEQR_PIPELINE_RUNNER_BUILD}/pyscripts.zip',
+                        ],
+                    },
+                },
+            },
+        )
+        wait_s = 0
+        while wait_s < TIMEOUT_S:
+            if operation.done():
+                operation.result()  # Will throw on failure!
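+                # NB: result() re-raises the job's terminal error, so a failed
+                # Dataproc job fails this task immediately rather than waiting
+                # out the two-day timeout below.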
+                msg = f'Finished {self.job_id}'
+                logger.info(msg)
+                break
+            logger.info(
+                f'Waiting for job completion {self.job_id}',
+            )
+            time.sleep(3)
+            wait_s += 3
diff --git a/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster.py b/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster.py
index f7d0bc21d..6fce572a0 100644
--- a/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster.py
+++ b/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster.py
@@ -1,5 +1,6 @@
 import time
 
+import google.api_core.exceptions
 import hail as hl
 import luigi
 from google.cloud import dataproc_v1 as dataproc
@@ -7,17 +8,19 @@
 
 from v03_pipeline.lib.logger import get_logger
 from v03_pipeline.lib.misc.gcp import get_service_account_credentials
-from v03_pipeline.lib.model import Env, ReferenceGenome
+from v03_pipeline.lib.model import Env, FeatureFlag, ReferenceGenome
 from v03_pipeline.lib.tasks.base.base_loading_pipeline_params import (
     BaseLoadingPipelineParams,
 )
+from v03_pipeline.lib.tasks.dataproc.misc import get_cluster_name
 
-CLUSTER_NAME_PREFIX = 'pipeline-runner'
 DEBIAN_IMAGE = '2.1.33-debian11'
+ERROR_STATE = 'ERROR'
 HAIL_VERSION = hl.version().split('-')[0]
 INSTANCE_TYPE = 'n1-highmem-8'
 PKGS = '|'.join(pip_freeze.freeze())
-SUCCESS_STATE = 'RUNNING'
+RUNNING_STATE = 'RUNNING'
+TIMEOUT_S = 900
 
 logger = get_logger(__name__)
 
@@ -26,7 +29,7 @@ def get_cluster_config(reference_genome: ReferenceGenome, run_id: str):
     service_account_credentials = get_service_account_credentials()
     return {
         'project_id': Env.GCLOUD_PROJECT,
-        'cluster_name': f'{CLUSTER_NAME_PREFIX}-{reference_genome.value.lower()}-{run_id}',
+        'cluster_name': get_cluster_name(reference_genome, run_id),
         # Schema found at https://cloud.google.com/dataproc/docs/reference/rest/v1/ClusterConfig
         'config': {
             'gce_cluster_config': {
@@ -88,18 +91,21 @@ def get_cluster_config(reference_genome: ReferenceGenome, run_id: str):
                 'spark:spark.executorEnv.HAIL_WORKER_OFF_HEAP_MEMORY_PER_CORE_MB': '6323',
                 'spark:spark.speculation': 'true',
                 'spark-env:ACCESS_PRIVATE_REFERENCE_DATASETS': '1'
-                if Env.ACCESS_PRIVATE_REFERENCE_DATASETS
+                if FeatureFlag.ACCESS_PRIVATE_REFERENCE_DATASETS
                 else '0',
                 'spark-env:CHECK_SEX_AND_RELATEDNESS': '1'
-                if Env.CHECK_SEX_AND_RELATEDNESS
+                if FeatureFlag.CHECK_SEX_AND_RELATEDNESS
+                else '0',
+                'spark-env:EXPECT_TDR_METRICS': '1'
+                if FeatureFlag.EXPECT_TDR_METRICS
                 else '0',
                 'spark-env:EXPECT_WES_FILTERS': '1'
-                if Env.EXPECT_WES_FILTERS
+                if FeatureFlag.EXPECT_WES_FILTERS
                 else '0',
                 'spark-env:HAIL_SEARCH_DATA_DIR': Env.HAIL_SEARCH_DATA_DIR,
                 'spark-env:HAIL_TMP_DIR': Env.HAIL_TMP_DIR,
                 'spark-env:INCLUDE_PIPELINE_VERSION_IN_PREFIX': '1'
-                if Env.INCLUDE_PIPELINE_VERSION_IN_PREFIX
+                if FeatureFlag.INCLUDE_PIPELINE_VERSION_IN_PREFIX
                 else '0',
                 'spark-env:LOADING_DATASETS_DIR': Env.LOADING_DATASETS_DIR,
                 'spark-env:PRIVATE_REFERENCE_DATASETS_DIR': Env.PRIVATE_REFERENCE_DATASETS_DIR,
@@ -133,27 +139,32 @@ def __init__(self, *args, **kwargs):
         # https://cloud.google.com/dataproc/docs/tutorials/python-library-example
         self.client = dataproc.ClusterControllerClient(
             client_options={
-                'api_endpoint': f'{Env.GCLOUD_REGION}-dataproc.googleapis.com:443'.format(
-                    Env.GCLOUD_REGION,
-                ),
+                'api_endpoint': f'{Env.GCLOUD_REGION}-dataproc.googleapis.com:443',
             },
         )
 
     def complete(self) -> bool:
         if not self.dataset_type.requires_dataproc:
-            return True
+            msg = f'{self.dataset_type} should not require a dataproc cluster'
+            raise RuntimeError(msg)
         try:
-            client = self.client.get_cluster(
+            cluster = self.client.get_cluster(
                 request={
                     'project_id': Env.GCLOUD_PROJECT,
                     'region': Env.GCLOUD_REGION,
-                    'cluster_name': f'{CLUSTER_NAME_PREFIX}-{self.reference_genome.value.lower()}',
+                    'cluster_name': get_cluster_name(
+                        self.reference_genome,
+                        self.run_id,
+                    ),
                 },
             )
-        except Exception:  # noqa: BLE001
+        except google.api_core.exceptions.NotFound:
             return False
-        else:
-            return client.status.state == SUCCESS_STATE
+        if cluster.status.state == ERROR_STATE:
+            msg = f'Cluster {cluster.cluster_name} entered ERROR state'
+            logger.error(msg)
+        # This will return False when the cluster is "CREATING"
+        return cluster.status.state == RUNNING_STATE
 
     def run(self):
         operation = self.client.create_cluster(
@@ -163,7 +174,8 @@
             'cluster': get_cluster_config(self.reference_genome, self.run_id),
             },
         )
-        while True:
+        wait_s = 0
+        while wait_s < TIMEOUT_S:
             if operation.done():
                 result = operation.result()  # Will throw on failure!
                 msg = f'Created cluster {result.cluster_name} with cluster uuid: {result.cluster_uuid}'
@@ -171,3 +183,4 @@
                 break
             logger.info('Waiting for cluster spinup')
             time.sleep(3)
+            wait_s += 3
diff --git a/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster_test.py b/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster_test.py
index c7f4f2958..e3fa8f3fa 100644
--- a/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster_test.py
+++ b/v03_pipeline/lib/tasks/dataproc/create_dataproc_cluster_test.py
@@ -28,15 +28,12 @@ def test_dataset_type_unsupported(
         mock_cluster_controller: Mock,
         _: Mock,
     ) -> None:
-        worker = luigi.worker.Worker()
         task = CreateDataprocClusterTask(
             reference_genome=ReferenceGenome.GRCh38,
             dataset_type=DatasetType.MITO,
             run_id='1',
         )
-        worker.add(task)
-        worker.run()
-        self.assertTrue(task.complete())
+        self.assertRaises(RuntimeError, task.complete)
 
     def test_spinup_cluster_already_exists_failed(
         self,
@@ -45,7 +42,8 @@
     ) -> None:
         mock_client = mock_cluster_controller.return_value
         mock_client.get_cluster.return_value = SimpleNamespace(
-            status=SimpleNamespace(state='FAILED'),
+            status=SimpleNamespace(state='ERROR'),
+            cluster_name='abc',
         )
         mock_client.create_cluster.side_effect = (
             google.api_core.exceptions.AlreadyExists('cluster exists')
@@ -122,7 +120,7 @@ def test_spinup_cluster_doesnt_exist_success(
         operation = mock_client.create_cluster.return_value
         operation.done.side_effect = [False, True]
         operation.result.return_value = SimpleNamespace(
-            cluster_name='dataproc-cluster-1',
+            cluster_name='dataproc-cluster-5',
             cluster_uuid='12345',
         )
         worker = luigi.worker.Worker()
@@ -136,6 +134,6 @@
         mock_logger.info.assert_has_calls(
             [
                 call('Waiting for cluster spinup'),
-                call('Created cluster dataproc-cluster-1 with cluster uuid: 12345'),
+                call('Created cluster dataproc-cluster-5 with cluster uuid: 12345'),
             ],
         )
diff --git a/v03_pipeline/lib/tasks/dataproc/misc.py b/v03_pipeline/lib/tasks/dataproc/misc.py
new file mode 100644
index 000000000..672f86dcc
--- /dev/null
+++ b/v03_pipeline/lib/tasks/dataproc/misc.py
@@ -0,0 +1,21 @@
+import re
+
+import luigi
+
+from v03_pipeline.lib.model import ReferenceGenome
+
+CLUSTER_NAME_PREFIX = 'pipeline-runner'
+
+
+def get_cluster_name(reference_genome: ReferenceGenome, run_id: str):
+    return f'{CLUSTER_NAME_PREFIX}-{reference_genome.value.lower()}-{run_id}'
+
+
+def snake_to_kebab_arg(snake_string: str) -> str:
+    return '--' + re.sub(r'\_', '-', snake_string).lower()
+
+
+def to_kebab_str_args(task: luigi.Task):
+    return [
+        e
+        for k, v in task.to_str_params().items()
+        for e in (snake_to_kebab_arg(k), v)
+    ]
diff --git a/v03_pipeline/lib/tasks/dataproc/misc_test.py b/v03_pipeline/lib/tasks/dataproc/misc_test.py
new file mode 100644
index 000000000..335cacbf7
--- /dev/null
+++ b/v03_pipeline/lib/tasks/dataproc/misc_test.py
@@ -0,0 +1,58 @@
+import unittest
+from unittest.mock import Mock, patch
+
+from v03_pipeline.lib.model import DatasetType, ReferenceGenome, SampleType
+from v03_pipeline.lib.tasks.dataproc.misc import to_kebab_str_args
+from v03_pipeline.lib.tasks.dataproc.write_success_file_on_dataproc import (
+    WriteSuccessFileOnDataprocTask,
+)
+
+
+@patch(
+    'v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc.dataproc.JobControllerClient',
+)
+class MiscTest(unittest.TestCase):
+    def test_to_kebab_str_args(self, _: Mock):
+        t = WriteSuccessFileOnDataprocTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            sample_type=SampleType.WGS,
+            callset_path='test_callset',
+            project_guids=['R0113_test_project'],
+            project_remap_paths=['test_remap'],
+            project_pedigree_paths=['test_pedigree'],
+            run_id='a_misc_run',
+        )
+        self.assertListEqual(
+            to_kebab_str_args(t),
+            [
+                '--reference-genome',
+                'GRCh38',
+                '--dataset-type',
+                'SNV_INDEL',
+                '--run-id',
+                'a_misc_run',
+                '--sample-type',
+                'WGS',
+                '--callset-path',
+                'test_callset',
+                '--project-guids',
+                '["R0113_test_project"]',
+                '--project-remap-paths',
+                '["test_remap"]',
+                '--project-pedigree-paths',
+                '["test_pedigree"]',
+                '--ignore-missing-samples-when-remapping',
+                'False',
+                '--skip-check-sex-and-relatedness',
+                'False',
+                '--skip-expect-filters',
+                'False',
+                '--skip-expect-tdr-metrics',
+                'False',
+                '--skip-validation',
+                'False',
+                '--is-new-gcnv-joint-call',
+                'False',
+            ],
+        )
diff --git a/v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc.py b/v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc.py
new file mode 100644
index 000000000..0c48c344f
--- /dev/null
+++ b/v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc.py
@@ -0,0 +1,22 @@
+import luigi
+
+from v03_pipeline.lib.paths import pipeline_run_success_file_path
+from v03_pipeline.lib.tasks.base.base_loading_run_params import (
+    BaseLoadingRunParams,
+)
+from v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc import (
+    BaseRunJobOnDataprocTask,
+)
+from v03_pipeline.lib.tasks.files import GCSorLocalTarget
+
+
+@luigi.util.inherits(BaseLoadingRunParams)
+class WriteSuccessFileOnDataprocTask(BaseRunJobOnDataprocTask):
+    def output(self) -> luigi.Target:
+        return GCSorLocalTarget(
+            pipeline_run_success_file_path(
+                self.reference_genome,
+                self.dataset_type,
+                self.run_id,
+            ),
+        )
diff --git a/v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc_test.py b/v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc_test.py
new file mode 100644
index 000000000..62cb4d678
--- /dev/null
+++ b/v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc_test.py
@@ -0,0 +1,156 @@
+import unittest
+from types import SimpleNamespace
+from unittest.mock import Mock, call, patch
+
+import google.api_core.exceptions
+import luigi
+
+from v03_pipeline.lib.model import DatasetType, ReferenceGenome, SampleType
+from v03_pipeline.lib.tasks.dataproc.write_success_file_on_dataproc import (
+    WriteSuccessFileOnDataprocTask,
+)
+from v03_pipeline.lib.test.mock_complete_task import MockCompleteTask
+
+
+@patch(
+    'v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc.CreateDataprocClusterTask',
+)
+@patch(
+    'v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc.dataproc.JobControllerClient',
+)
+class WriteSuccessFileOnDataprocTaskTest(unittest.TestCase):
+    @patch('v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc.logger')
+    def test_job_already_exists_failed(
+        self,
+        mock_logger: Mock,
+        mock_job_controller_client: Mock,
+        mock_create_dataproc_cluster: Mock,
+    ) -> None:
+        mock_create_dataproc_cluster.return_value = MockCompleteTask()
+        mock_client = mock_job_controller_client.return_value
+        mock_client.get_job.return_value = SimpleNamespace(
+            status=SimpleNamespace(
+                state='ERROR',
+                details='Google Cloud Dataproc Agent reports job failure. If logs are available, they can be found at...',
+            ),
+        )
+        mock_client.submit_job_as_operation.side_effect = (
+            google.api_core.exceptions.AlreadyExists('job exists')
+        )
+        worker = luigi.worker.Worker()
+        task = WriteSuccessFileOnDataprocTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            sample_type=SampleType.WGS,
+            callset_path='test_callset',
+            project_guids=['R0113_test_project'],
+            project_remap_paths=['test_remap'],
+            project_pedigree_paths=['test_pedigree'],
+            run_id='manual__2024-04-03',
+        )
+        worker.add(task)
+        worker.run()
+        self.assertFalse(task.complete())
+        mock_logger.error.assert_has_calls(
+            [
+                call(
+                    'Job WriteSuccessFileOnDataprocTask-manual__2024-04-03 entered ERROR state',
+                ),
+            ],
+        )
+
+    def test_job_already_exists_success(
+        self,
+        mock_job_controller_client: Mock,
+        mock_create_dataproc_cluster: Mock,
+    ) -> None:
+        mock_create_dataproc_cluster.return_value = MockCompleteTask()
+        mock_client = mock_job_controller_client.return_value
+        mock_client.get_job.return_value = SimpleNamespace(
+            status=SimpleNamespace(state='DONE'),
+        )
+        worker = luigi.worker.Worker()
+        task = WriteSuccessFileOnDataprocTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            sample_type=SampleType.WGS,
+            callset_path='test_callset',
+            project_guids=['R0113_test_project'],
+            project_remap_paths=['test_remap'],
+            project_pedigree_paths=['test_pedigree'],
+            run_id='manual__2024-04-04',
+        )
+        worker.add(task)
+        worker.run()
+        self.assertTrue(task.complete())
+
+    @patch('v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc.logger')
+    def test_job_failed(
+        self,
+        mock_logger: Mock,
+        mock_job_controller_client: Mock,
+        mock_create_dataproc_cluster: Mock,
+    ) -> None:
+        mock_create_dataproc_cluster.return_value = MockCompleteTask()
+        mock_client = mock_job_controller_client.return_value
+        mock_client.get_job.side_effect = google.api_core.exceptions.NotFound(
+            'job not found',
+        )
+        operation = mock_client.submit_job_as_operation.return_value
+        operation.done.side_effect = [False, True]
+        operation.result.side_effect = Exception(
+            'FailedPrecondition: 400 Job failed with message',
+        )
+        worker = luigi.worker.Worker()
+        task = WriteSuccessFileOnDataprocTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            sample_type=SampleType.WGS,
+            callset_path='test_callset',
+            project_guids=['R0113_test_project'],
+            project_remap_paths=['test_remap'],
+            project_pedigree_paths=['test_pedigree'],
+            run_id='manual__2024-04-05',
+        )
+        worker.add(task)
+        worker.run()
+        self.assertFalse(task.complete())
+        mock_logger.info.assert_has_calls(
+            [
+                call(
+                    'Waiting for job completion WriteSuccessFileOnDataprocTask-manual__2024-04-05',
+                ),
+            ],
+        )
+
+    def test_job_success(
+        self,
+        mock_job_controller_client: Mock,
+        mock_create_dataproc_cluster: Mock,
+    ) -> None:
+        mock_create_dataproc_cluster.return_value = MockCompleteTask()
+        mock_client = mock_job_controller_client.return_value
+        mock_client.get_job.side_effect = [
+            google.api_core.exceptions.NotFound(
+                'job not found',
+            ),
+            SimpleNamespace(
+                status=SimpleNamespace(state='DONE'),
+            ),
+        ]
+        operation = mock_client.submit_job_as_operation.return_value
+        operation.done.side_effect = [False, True]
+        worker = luigi.worker.Worker()
+        task = WriteSuccessFileOnDataprocTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            sample_type=SampleType.WGS,
+            callset_path='test_callset',
+            project_guids=['R0113_test_project'],
+            project_remap_paths=['test_remap'],
+            project_pedigree_paths=['test_pedigree'],
+            run_id='manual__2024-04-06',
+        )
+        worker.add(task)
+        worker.run()
+        self.assertTrue(task.complete())
diff --git a/v03_pipeline/lib/tasks/update_variant_annotations_table_with_new_samples_test.py b/v03_pipeline/lib/tasks/update_variant_annotations_table_with_new_samples_test.py
index d29b58ff6..47511c695 100644
--- a/v03_pipeline/lib/tasks/update_variant_annotations_table_with_new_samples_test.py
+++ b/v03_pipeline/lib/tasks/update_variant_annotations_table_with_new_samples_test.py
@@ -712,7 +712,7 @@ def test_update_vat_grch37(
     @patch(
         'v03_pipeline.lib.tasks.write_new_variants_table.UpdateVariantAnnotationsTableWithUpdatedReferenceDataset',
     )
-    @patch('v03_pipeline.lib.reference_datasets.reference_dataset.Env')
+    @patch('v03_pipeline.lib.reference_datasets.reference_dataset.FeatureFlag')
     @patch('v03_pipeline.lib.vep.hl.vep')
     @patch(
         'v03_pipeline.lib.tasks.write_new_variants_table.load_gencode_ensembl_to_refseq_id',
@@ -721,7 +721,7 @@ def test_update_vat_without_accessing_private_datasets(
         self,
         mock_load_gencode_ensembl_to_refseq_id: Mock,
         mock_vep: Mock,
-        mock_rd_env: Mock,
+        mock_rd_ff: Mock,
         mock_update_vat_with_rd_task: Mock,
         mock_register_alleles: Mock,
     ) -> None:
@@ -740,7 +740,7 @@
                 ReferenceDataset.hgmd,
             ),
         )
-        mock_rd_env.ACCESS_PRIVATE_REFERENCE_DATASETS = False
+        mock_rd_ff.ACCESS_PRIVATE_REFERENCE_DATASETS = False
         mock_vep.side_effect = lambda ht, **_: ht.annotate(vep=MOCK_38_VEP_DATA)
         mock_register_alleles.side_effect = None
diff --git a/v03_pipeline/lib/tasks/validate_callset.py b/v03_pipeline/lib/tasks/validate_callset.py
index e5601875b..3d0d979db 100644
--- a/v03_pipeline/lib/tasks/validate_callset.py
+++ b/v03_pipeline/lib/tasks/validate_callset.py
@@ -10,7 +10,7 @@
     validate_no_duplicate_variants,
     validate_sample_type,
 )
-from v03_pipeline.lib.model.environment import Env
+from v03_pipeline.lib.model.feature_flag import FeatureFlag
 from v03_pipeline.lib.paths import (
     imported_callset_path,
     sex_check_table_path,
@@ -41,7 +41,7 @@ def get_validation_dependencies(self) -> dict[str, hl.Table]:
             ),
         )
         if (
-            Env.CHECK_SEX_AND_RELATEDNESS
+            FeatureFlag.CHECK_SEX_AND_RELATEDNESS
             and self.dataset_type.check_sex_and_relatedness
             and not self.skip_check_sex_and_relatedness
         ):
@@ -86,7 +86,7 @@ def requires(self) -> list[luigi.Task]:
             ),
         ]
         if (
-            Env.CHECK_SEX_AND_RELATEDNESS
+            FeatureFlag.CHECK_SEX_AND_RELATEDNESS
             and self.dataset_type.check_sex_and_relatedness
             and not self.skip_check_sex_and_relatedness
         ):
diff --git a/v03_pipeline/lib/tasks/validate_callset_test.py b/v03_pipeline/lib/tasks/validate_callset_test.py
index 8f3638376..a6d2b0377 100644
--- a/v03_pipeline/lib/tasks/validate_callset_test.py
+++ b/v03_pipeline/lib/tasks/validate_callset_test.py
@@ -83,5 +83,6 @@ def test_validate_callset_multiple_exceptions(
                     'Missing the following expected contigs:chr10, chr11, chr12, chr13, chr14, chr15, chr16, chr17, chr18, chr19, chr2, chr20, chr21, chr22, chr3, chr4, chr5, chr6, chr7, chr8, chr9, chrX',
                     'Sample type validation error: dataset sample-type is specified as WES but appears to be WGS because it contains many common non-coding variants',
                 ],
+                'failed_family_samples': {},
             },
         )
diff --git a/v03_pipeline/lib/tasks/write_imported_callset.py b/v03_pipeline/lib/tasks/write_imported_callset.py
index 9caa71e58..c544b4666 100644
--- a/v03_pipeline/lib/tasks/write_imported_callset.py
+++ b/v03_pipeline/lib/tasks/write_imported_callset.py
@@ -10,11 +10,10 @@
     split_multi_hts,
 )
 from v03_pipeline.lib.misc.validation import (
-    SeqrValidationError,
     validate_imported_field_types,
 )
 from v03_pipeline.lib.misc.vets import annotate_vets
-from v03_pipeline.lib.model.environment import Env
+from v03_pipeline.lib.model.feature_flag import FeatureFlag
 from v03_pipeline.lib.paths import (
     imported_callset_path,
     valid_filters_path,
@@ -24,7 +23,7 @@
 from v03_pipeline.lib.tasks.files import CallsetTask, GCSorLocalTarget
 from v03_pipeline.lib.tasks.write_tdr_metrics_files import WriteTDRMetricsFilesTask
 from v03_pipeline.lib.tasks.write_validation_errors_for_run import (
-    WriteValidationErrorsForRunTask,
+    with_persisted_validation_errors,
 )
 
 
@@ -45,7 +44,7 @@ def output(self) -> luigi.Target:
     def requires(self) -> list[luigi.Task]:
         requirements = []
         if (
-            Env.EXPECT_WES_FILTERS
+            FeatureFlag.EXPECT_WES_FILTERS
             and not self.skip_expect_filters
             and self.dataset_type.expect_filters(
                 self.sample_type,
@@ -62,7 +61,7 @@
                 ),
             ]
         if (
-            Env.EXPECT_TDR_METRICS
+            FeatureFlag.EXPECT_TDR_METRICS
             and not self.skip_expect_tdr_metrics
             and self.dataset_type.expect_tdr_metrics(
                 self.reference_genome,
@@ -77,65 +76,56 @@
             CallsetTask(self.callset_path),
         ]
 
+    @with_persisted_validation_errors
     def create_table(self) -> hl.MatrixTable:
-        try:
-            # NB: throws SeqrValidationError
-            mt = import_callset(
-                self.callset_path,
-                self.reference_genome,
-                self.dataset_type,
-            )
-            filters_path = None
-            if (
-                Env.EXPECT_WES_FILTERS
-                and not self.skip_expect_filters
-                and self.dataset_type.expect_filters(
-                    self.sample_type,
-                )
-            ):
-                filters_path = valid_filters_path(
-                    self.dataset_type,
-                    self.sample_type,
-                    self.callset_path,
-                )
-                filters_ht = import_vcf(filters_path, self.reference_genome).rows()
-                mt = mt.annotate_rows(filters=filters_ht[mt.row_key].filters)
-            additional_row_fields = get_additional_row_fields(
-                mt,
-                self.dataset_type,
-                self.skip_check_sex_and_relatedness,
+        # NB: throws SeqrValidationError
+        mt = import_callset(
+            self.callset_path,
+            self.reference_genome,
+            self.dataset_type,
+        )
+        filters_path = None
+        if (
+            FeatureFlag.EXPECT_WES_FILTERS
+            and not self.skip_expect_filters
+            and self.dataset_type.expect_filters(
+                self.sample_type,
             )
-            # NB: throws SeqrValidationError
-            mt = select_relevant_fields(
-                mt,
+        ):
+            filters_path = valid_filters_path(
                 self.dataset_type,
-                additional_row_fields,
+                self.sample_type,
+                self.callset_path,
             )
-            # This validation isn't override-able by the skip option.
-            # If a field is the wrong type, the pipeline will likely hard-fail downstream.
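+            # Read variant-level filters from the expected WES filters VCF and
+            # join them onto the callset rows below.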
+            filters_ht = import_vcf(filters_path, self.reference_genome).rows()
+            mt = mt.annotate_rows(filters=filters_ht[mt.row_key].filters)
+        additional_row_fields = get_additional_row_fields(
+            mt,
+            self.dataset_type,
+            self.skip_check_sex_and_relatedness,
+        )
+        # NB: throws SeqrValidationError
+        mt = select_relevant_fields(
+            mt,
+            self.dataset_type,
+            additional_row_fields,
+        )
+        # This validation isn't override-able by the skip option.
+        # If a field is the wrong type, the pipeline will likely hard-fail downstream.
+        # NB: throws SeqrValidationError
+        validate_imported_field_types(
+            mt,
+            self.dataset_type,
+            additional_row_fields,
+        )
+        if self.dataset_type.has_multi_allelic_variants:
             # NB: throws SeqrValidationError
-            validate_imported_field_types(
-                mt,
-                self.dataset_type,
-                additional_row_fields,
-            )
-            if self.dataset_type.has_multi_allelic_variants:
-                # NB: throws SeqrValidationError
-                mt = split_multi_hts(mt, self.skip_validation)
-            # Special handling of variant-level filter annotation for VETs filters.
-            # The annotations are present on the sample-level FT field but are
-            # expected upstream on "filters".
-            mt = annotate_vets(mt)
-            return mt.select_globals(
-                callset_path=self.callset_path,
-                filters_path=filters_path or hl.missing(hl.tstr),
-            )
-        except SeqrValidationError as e:
-            write_validation_errors_for_run_task = self.clone(
-                WriteValidationErrorsForRunTask,
-                error_messages=[str(e)],
-            )
-            write_validation_errors_for_run_task.run()
-            raise SeqrValidationError(
-                write_validation_errors_for_run_task.to_single_error_message(),
-            ) from e
+            mt = split_multi_hts(mt, self.skip_validation)
+        # Special handling of variant-level filter annotation for VETs filters.
+        # The annotations are present on the sample-level FT field but are
+        # expected upstream on "filters".
+        mt = annotate_vets(mt)
+        return mt.select_globals(
+            callset_path=self.callset_path,
+            filters_path=filters_path or hl.missing(hl.tstr),
+        )
diff --git a/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset.py b/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset.py
index f4c934662..f099dcfcc 100644
--- a/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset.py
+++ b/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset.py
@@ -2,7 +2,6 @@
 import luigi
 import luigi.util
 
-from v03_pipeline.lib.logger import get_logger
 from v03_pipeline.lib.misc.family_loading_failures import (
     get_families_failed_missing_samples,
     get_families_failed_relatedness_check,
@@ -16,7 +15,8 @@
 )
 from v03_pipeline.lib.misc.pedigree import parse_pedigree_ht_to_families
 from v03_pipeline.lib.misc.sample_ids import remap_sample_ids, subset_samples
-from v03_pipeline.lib.model.environment import Env
+from v03_pipeline.lib.misc.validation import SeqrValidationError
+from v03_pipeline.lib.model.feature_flag import FeatureFlag
 from v03_pipeline.lib.paths import (
     relatedness_check_table_path,
     remapped_and_subsetted_callset_path,
@@ -29,8 +29,19 @@
     WriteRelatednessCheckTsvTask,
 )
 from v03_pipeline.lib.tasks.write_sex_check_table import WriteSexCheckTableTask
+from v03_pipeline.lib.tasks.write_validation_errors_for_run import (
+    with_persisted_validation_errors,
+)
+
 
-logger = get_logger(__name__)
+def format_failures(failed_families):
+    return {
+        f.family_guid: {
+            'samples': sorted(f.samples.keys()),
+            'reasons': reasons,
+        }
+        for f, reasons in failed_families.items()
+    }
 
 
 @luigi.util.inherits(BaseLoadingRunParams)
@@ -62,7 +73,7 @@ def requires(self) -> list[luigi.Task]:
             RawFileTask(self.project_pedigree_paths[self.project_i]),
         ]
         if (
-            Env.CHECK_SEX_AND_RELATEDNESS
+            FeatureFlag.CHECK_SEX_AND_RELATEDNESS
             and self.dataset_type.check_sex_and_relatedness
             and not self.skip_check_sex_and_relatedness
         ):
@@ -73,6 +84,7 @@
             ]
         return requirements
 
+    @with_persisted_validation_errors
     def create_table(self) -> hl.MatrixTable:
         callset_mt = hl.read_matrix_table(self.input()[0].path)
         pedigree_ht = import_pedigree(self.input()[1].path)
@@ -98,7 +110,7 @@
         families_failed_relatedness_check = {}
         families_failed_sex_check = {}
         if (
-            Env.CHECK_SEX_AND_RELATEDNESS
+            FeatureFlag.CHECK_SEX_AND_RELATEDNESS
             and self.dataset_type.check_sex_and_relatedness
             and not self.skip_check_sex_and_relatedness
         ):
@@ -130,16 +142,21 @@
             - families_failed_sex_check.keys()
         )
         if not len(loadable_families):
-            msg = (
-                f'families_failed_missing_samples: {families_failed_missing_samples}\n'
-                f'families_failed_relatedness_check: {families_failed_relatedness_check}\n'
-                f'families_failed_sex_check: {families_failed_sex_check}'
-            )
-            logger.info(
+            msg = 'All families failed validation checks'
+            raise SeqrValidationError(
                 msg,
+                {
+                    'failed_family_samples': {
+                        'missing_samples': format_failures(
+                            families_failed_missing_samples,
+                        ),
+                        'relatedness_check': format_failures(
+                            families_failed_relatedness_check,
+                        ),
+                        'sex_check': format_failures(families_failed_sex_check),
+                    },
+                },
             )
-            msg = 'All families failed checks'
-            raise RuntimeError(msg)
 
         mt = subset_samples(
             callset_mt,
@@ -172,33 +189,15 @@
             ),
             failed_family_samples=hl.Struct(
                 missing_samples=(
-                    {
-                        f.family_guid: {
-                            'samples': sorted(f.samples.keys()),
-                            'reasons': reasons,
-                        }
-                        for f, reasons in families_failed_missing_samples.items()
-                    }
+                    format_failures(families_failed_missing_samples)
                     or hl.empty_dict(hl.tstr, hl.tdict(hl.tstr, hl.tarray(hl.tstr)))
                 ),
                 relatedness_check=(
-                    {
-                        f.family_guid: {
-                            'samples': sorted(f.samples.keys()),
-                            'reasons': reasons,
-                        }
-                        for f, reasons in families_failed_relatedness_check.items()
-                    }
+                    format_failures(families_failed_relatedness_check)
                     or hl.empty_dict(hl.tstr, hl.tdict(hl.tstr, hl.tarray(hl.tstr)))
                 ),
                 sex_check=(
-                    {
-                        f.family_guid: {
-                            'samples': sorted(f.samples.keys()),
-                            'reasons': reasons,
-                        }
-                        for f, reasons in families_failed_sex_check.items()
-                    }
+                    format_failures(families_failed_sex_check)
                     or hl.empty_dict(hl.tstr, hl.tdict(hl.tstr, hl.tarray(hl.tstr)))
                 ),
             ),
diff --git a/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset_test.py b/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset_test.py
index 4a0c84660..a5ed24799 100644
--- a/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset_test.py
+++ b/v03_pipeline/lib/tasks/write_remapped_and_subsetted_callset_test.py
@@ -1,3 +1,4 @@
+import json
 import shutil
 from unittest.mock import Mock, patch
 
@@ -10,12 +11,16 @@
 from v03_pipeline.lib.tasks.write_remapped_and_subsetted_callset import (
     WriteRemappedAndSubsettedCallsetTask,
 )
+from v03_pipeline.lib.tasks.write_validation_errors_for_run import (
+    WriteValidationErrorsForRunTask,
+)
 from v03_pipeline.lib.test.mocked_dataroot_testcase import MockedDatarootTestCase
 
 TEST_VCF = 'v03_pipeline/var/test/callsets/1kg_30variants.vcf'
 TEST_REMAP = 'v03_pipeline/var/test/remaps/test_remap_1.tsv'
 TEST_PEDIGREE_3 = 'v03_pipeline/var/test/pedigrees/test_pedigree_3.tsv'
 TEST_PEDIGREE_4 = 'v03_pipeline/var/test/pedigrees/test_pedigree_4.tsv'
+TEST_PEDIGREE_7 = 'v03_pipeline/var/test/pedigrees/test_pedigree_7.tsv'
 TEST_SEX_CHECK_1 = 'v03_pipeline/var/test/sex_check/test_sex_check_1.ht'
 TEST_RELATEDNESS_CHECK_1 = (
     'v03_pipeline/var/test/relatedness_check/test_relatedness_check_1.ht'
@@ -115,12 +120,12 @@ def test_write_remapped_and_subsetted_callset_task(
             ],
         )
 
-    @patch('v03_pipeline.lib.tasks.write_remapped_and_subsetted_callset.Env')
+    @patch('v03_pipeline.lib.tasks.write_remapped_and_subsetted_callset.FeatureFlag')
     def test_write_remapped_and_subsetted_callset_task_failed_sex_check_family(
         self,
-        mock_env: Mock,
+        mock_ff: Mock,
     ) -> None:
-        mock_env.CHECK_SEX_AND_RELATEDNESS = True
+        mock_ff.CHECK_SEX_AND_RELATEDNESS = True
         worker = luigi.worker.Worker()
         wrsc_task = WriteRemappedAndSubsettedCallsetTask(
             reference_genome=ReferenceGenome.GRCh38,
@@ -179,3 +184,159 @@
                 ),
             ],
         )
+
+    @patch('v03_pipeline.lib.tasks.write_remapped_and_subsetted_callset.FeatureFlag')
+    def test_write_remapped_and_subsetted_callset_task_all_families_failed(
+        self,
+        mock_ff: Mock,
+    ) -> None:
+        mock_ff.CHECK_SEX_AND_RELATEDNESS = True
+        worker = luigi.worker.Worker()
+        wrsc_task = WriteRemappedAndSubsettedCallsetTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            run_id=TEST_RUN_ID,
+            sample_type=SampleType.WGS,
+            callset_path=TEST_VCF,
+            project_guids=['R0114_project4'],
+            project_remap_paths=[TEST_REMAP],
+            project_pedigree_paths=[TEST_PEDIGREE_7],
+            project_i=0,
+            skip_validation=True,
+        )
+        worker.add(wrsc_task)
+        worker.run()
+        self.assertFalse(wrsc_task.complete())
+        write_validation_errors_task = WriteValidationErrorsForRunTask(
+            reference_genome=ReferenceGenome.GRCh38,
+            dataset_type=DatasetType.SNV_INDEL,
+            sample_type=SampleType.WES,
+            callset_path=TEST_VCF,
+            project_guids=['R0114_project4'],
+            skip_validation=True,
+            run_id=TEST_RUN_ID,
+        )
+        self.assertTrue(write_validation_errors_task.complete())
+        with write_validation_errors_task.output().open('r') as f:
+            self.assertDictEqual(
+                json.load(f),
+                {
+                    'project_guids': [
+                        'R0114_project4',
+                    ],
+                    'error_messages': [
+                        'All families failed validation checks',
+                    ],
+                    'failed_family_samples': {
+                        'missing_samples': {
+                            'efg_1': {
+                                'samples': [
+                                    'NA99999_1',
+                                ],
+                                'reasons': [
+                                    "Missing samples: {'NA99999_1'}",
+                                ],
+                            },
+                        },
+                        'relatedness_check': {},
+                        'sex_check': {
+                            '789_1': {
+                                'samples': [
+                                    'NA20875_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20875_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            '456_1': {
+                                'samples': [
+                                    'NA20870_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20870_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            '123_1': {
+                                'samples': [
+                                    'NA19675_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA19675_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            'cde_1': {
+                                'samples': [
+                                    'NA20881_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20881_1 has pedigree sex F but imputed sex M',
+                                ],
+                            },
+                            '901_1': {
+                                'samples': [
+                                    'NA20877_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20877_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            '678_1': {
+                                'samples': [
+                                    'NA20874_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20874_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            '345_1': {
+                                'samples': [
+                                    'NA19679_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA19679_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            '890_1': {
+                                'samples': [
+                                    'NA20876_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20876_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            'def_1': {
+                                'samples': [
+                                    'NA20885_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20885_1 has pedigree sex M but imputed sex F',
+                                ],
+                            },
+                            '234_1': {
+                                'samples': [
+                                    'NA19678_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA19678_1 has pedigree sex F but imputed sex M',
+                                ],
+                            },
+                            'bcd_1': {
+                                'samples': [
+                                    'NA20878_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20878_1 has pedigree sex F but imputed sex M',
+                                ],
+                            },
+                            '567_1': {
+                                'samples': [
+                                    'NA20872_1',
+                                ],
+                                'reasons': [
+                                    'Sample NA20872_1 has pedigree sex F but imputed sex M',
+                                ],
+                            },
+                        },
+                    },
+                },
+            )
diff --git a/v03_pipeline/lib/tasks/write_validation_errors_for_run.py b/v03_pipeline/lib/tasks/write_validation_errors_for_run.py
index 9149f6158..d99800f6f 100644
--- a/v03_pipeline/lib/tasks/write_validation_errors_for_run.py
+++ b/v03_pipeline/lib/tasks/write_validation_errors_for_run.py
@@ -1,8 +1,11 @@
 import json
+from collections.abc import Callable
 
 import luigi
+import luigi.freezing
 import luigi.util
 
+from v03_pipeline.lib.misc.validation import SeqrValidationError
 from v03_pipeline.lib.paths import validation_errors_for_run_path
 from v03_pipeline.lib.tasks.base.base_loading_run_params import BaseLoadingRunParams
 from v03_pipeline.lib.tasks.files import GCSorLocalTarget
@@ -12,6 +15,7 @@ class WriteValidationErrorsForRunTask(luigi.Task):
     project_guids = luigi.ListParameter()
     error_messages = luigi.ListParameter(default=[])
+    failed_family_samples = luigi.DictParameter(default={})
 
     def to_single_error_message(self) -> str:
         with self.output().open('r') as f:
@@ -33,6 +37,36 @@ def run(self) -> None:
         validation_errors_json = {
             'project_guids': self.project_guids,
             'error_messages': self.error_messages,
+            'failed_family_samples': luigi.freezing.recursively_unfreeze(
+                self.failed_family_samples,
+            ),
         }
         with self.output().open('w') as f:
             json.dump(validation_errors_json, f)
+
+
+def with_persisted_validation_errors(f: Callable) -> Callable:
+    def wrapper(self: luigi.Task):
+        try:
+            return f(self)
+        except SeqrValidationError as e:
+            # TODO: improve type checking with a pydantic model/typed dict
+            if len(e.args) > 1 and isinstance(e.args[1], dict):
+                write_validation_errors_for_run_task = self.clone(
+                    WriteValidationErrorsForRunTask,
+                    error_messages=[str(e.args[0])],
+                    failed_family_samples=e.args[1]['failed_family_samples'],
+                )
+            else:
+                write_validation_errors_for_run_task = self.clone(
+                    WriteValidationErrorsForRunTask,
+                    error_messages=[str(e)],
+                )
+            write_validation_errors_for_run_task.run()
+            raise SeqrValidationError(
+                write_validation_errors_for_run_task.to_single_error_message(),
+            ) from None
+
+    return wrapper
diff --git a/v03_pipeline/var/test/pedigrees/test_pedigree_7.tsv b/v03_pipeline/var/test/pedigrees/test_pedigree_7.tsv
new file mode 100644
index 000000000..fee3a7458
--- /dev/null
+++ b/v03_pipeline/var/test/pedigrees/test_pedigree_7.tsv
@@ -0,0 +1,14 @@
+Project_GUID	Family_GUID	Family_ID	Individual_ID	Paternal_ID	Maternal_ID	Sex
+R0114_project4	123_1	123	NA19675_1			M
+R0114_project4	234_1	234	NA19678_1			F
+R0114_project4	345_1	345	NA19679_1			M
+R0114_project4	456_1	456	NA20870_1			M
+R0114_project4	567_1	567	NA20872_1			F
+R0114_project4	678_1	678	NA20874_1			M
+R0114_project4	789_1	789	NA20875_1			M
+R0114_project4	890_1	890	NA20876_1			M
+R0114_project4	901_1	901	NA20877_1			M
+R0114_project4	bcd_1	bcd	NA20878_1			F
+R0114_project4	cde_1	cde	NA20881_1			F
+R0114_project4	def_1	def	NA20885_1			M
+R0114_project4	efg_1	efg	NA99999_1			F
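
Note on the validation-error flow introduced above: create_table implementations now simply raise SeqrValidationError, and the with_persisted_validation_errors decorator persists a structured JSON payload before re-raising. A minimal standalone sketch of the pattern, with simplified, hypothetical names (not the pipeline's actual classes):

    import json

    class SeqrValidationError(Exception):
        pass

    def with_persisted_errors(path: str):
        # Decorator factory: write validation failures to `path`, then re-raise.
        def decorator(f):
            def wrapper(*args, **kwargs):
                try:
                    return f(*args, **kwargs)
                except SeqrValidationError as e:
                    payload = {'error_messages': [str(e.args[0])]}
                    # Structured detail (e.g. failed_family_samples) may ride
                    # along as a second positional argument.
                    if len(e.args) > 1 and isinstance(e.args[1], dict):
                        payload.update(e.args[1])
                    with open(path, 'w') as f_out:
                        json.dump(payload, f_out)
                    raise
            return wrapper
        return decorator

The benefit of the decorator over the previous inlined try/except is that any loading task (here WriteImportedCallsetTask and WriteRemappedAndSubsettedCallsetTask) gets identical persist-then-reraise behavior without duplicating the error-handling block.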