
Commit c56aa8f

refactor + chore: shuffle end of pipeline steps to enable dataproc for local users (#1009)
* Support gcs dirs in rsync
* ws
* Add create dataproc cluster task
* add dataproc
* ruff
* requirements
* still struggling
* Gencode refactor to remove gcs
* bump reqs
* Run dataproc job
* lib
* running
* merge requirements
* Flip'em
* Better exception handling
* Cleaner approach if less generalizable
* write a test
* Fix tests
* lint
* Add test for success
* refactor to use a base class... better for adding support for multiple jobs
* cleanup
* ruff
* Fix missing mock
* Fix flapping test
* first commit
* Finish test and cleanup
* Allow any order
* First commit
* ruff
* ruff
* finish off
* A few minor tweaks
1 parent db64b2d commit c56aa8f

12 files changed (+99 −72 lines)

v03_pipeline/lib/model/feature_flag.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 INCLUDE_PIPELINE_VERSION_IN_PREFIX = (
     os.environ.get('INCLUDE_PIPELINE_VERSION_IN_PREFIX') == '1'
 )
+RUN_PIPELINE_ON_DATAPROC = os.environ.get('RUN_PIPELINE_ON_DATAPROC') == '1'
 SHOULD_TRIGGER_HAIL_BACKEND_RELOAD = (
     os.environ.get('SHOULD_TRIGGER_HAIL_BACKEND_RELOAD') == '1'
 )
@@ -23,4 +24,5 @@ class FeatureFlag:
     EXPECT_TDR_METRICS: bool = EXPECT_TDR_METRICS
     EXPECT_WES_FILTERS: bool = EXPECT_WES_FILTERS
     INCLUDE_PIPELINE_VERSION_IN_PREFIX: bool = INCLUDE_PIPELINE_VERSION_IN_PREFIX
+    RUN_PIPELINE_ON_DATAPROC: bool = RUN_PIPELINE_ON_DATAPROC
     SHOULD_TRIGGER_HAIL_BACKEND_RELOAD: bool = SHOULD_TRIGGER_HAIL_BACKEND_RELOAD
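Note: the diff above only defines the flag. The commit does not show the call site that consumes it, but an env-gated choice between local and Dataproc execution would look roughly like this minimal sketch (the pipeline_task_class helper is hypothetical, not part of this commit):

from v03_pipeline.lib.model.feature_flag import FeatureFlag
from v03_pipeline.lib.tasks.dataproc.run_pipeline_on_dataproc import (
    RunPipelineOnDataprocTask,
)
from v03_pipeline.lib.tasks.run_pipeline import RunPipelineTask


def pipeline_task_class() -> type:  # hypothetical helper, not in this commit
    # FeatureFlag.RUN_PIPELINE_ON_DATAPROC is True iff the
    # RUN_PIPELINE_ON_DATAPROC env var was set to '1' at import time.
    if FeatureFlag.RUN_PIPELINE_ON_DATAPROC:
        return RunPipelineOnDataprocTask
    return RunPipelineTask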

v03_pipeline/lib/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from v03_pipeline.lib.tasks.reference_data.update_variant_annotations_table_with_updated_reference_dataset import (
     UpdateVariantAnnotationsTableWithUpdatedReferenceDataset,
 )
+from v03_pipeline.lib.tasks.run_pipeline import RunPipelineTask
 from v03_pipeline.lib.tasks.update_lookup_table import (
     UpdateLookupTableTask,
 )
@@ -46,6 +47,7 @@
     'DeleteProjectTablesTask',
     'MigrateAllLookupTablesTask',
     'MigrateAllVariantAnnotationsTablesTask',
+    'RunPipelineTask',
     'UpdateProjectTableTask',
     'UpdateProjectTablesWithDeletedFamiliesTask',
     'UpdateLookupTableTask',

v03_pipeline/lib/tasks/dataproc/base_run_job_on_dataproc.py

Lines changed: 5 additions & 5 deletions
@@ -33,12 +33,12 @@ def __init__(self, *args, **kwargs):
         )
 
     @property
-    def task_name(self):
-        return self.get_task_family().split('.')[-1]
+    def task(self):
+        raise NotImplementedError
 
     @property
     def job_id(self):
-        return f'{self.task_name}-{self.run_id}'
+        return f'{self.task.task_family}-{self.run_id}'
 
     def requires(self) -> [luigi.Task]:
         return [self.clone(CreateDataprocClusterTask)]
@@ -58,7 +58,7 @@ def complete(self) -> bool:
         except google.api_core.exceptions.NotFound:
             return False
         if job.status.state == ERROR_STATE:
-            msg = f'Job {self.task_name}-{self.run_id} entered ERROR state'
+            msg = f'Job {self.task.task_family}-{self.run_id} entered ERROR state'
             logger.error(msg)
             logger.error(job.status.details)
         return job.status.state == DONE_STATE
@@ -81,7 +81,7 @@ def run(self):
             'pyspark_job': {
                 'main_python_file_uri': f'{SEQR_PIPELINE_RUNNER_BUILD}/bin/run_task.py',
                 'args': [
-                    self.task_name,
+                    self.task.task_family,
                     '--local-scheduler',
                     *to_kebab_str_args(self),
                 ],
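Note: the refactor replaces the string-valued task_name (derived from the wrapper's own class name) with an abstract task property returning the luigi Task class to submit; job_id and the pyspark args now use that class's task_family. A minimal sketch of the new subclassing contract, with hypothetical names (the commit's real example is RunPipelineOnDataprocTask below):

import luigi

from v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc import (
    BaseRunJobOnDataprocTask,
)


class SomeTask(luigi.WrapperTask):  # hypothetical task to run on the cluster
    pass


class SomeTaskOnDataprocTask(BaseRunJobOnDataprocTask):  # hypothetical wrapper
    @property
    def task(self) -> luigi.Task:
        return SomeTask


# job_id becomes f'{SomeTask.task_family}-{run_id}', e.g.
# 'SomeTask-manual__2024-04-03', matching the log assertions in the tests.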

v03_pipeline/lib/tasks/dataproc/misc_test.py

Lines changed: 3 additions & 3 deletions
@@ -3,8 +3,8 @@
 
 from v03_pipeline.lib.model import DatasetType, ReferenceGenome, SampleType
 from v03_pipeline.lib.tasks.dataproc.misc import to_kebab_str_args
-from v03_pipeline.lib.tasks.dataproc.write_success_file_on_dataproc import (
-    WriteSuccessFileOnDataprocTask,
+from v03_pipeline.lib.tasks.dataproc.rsync_to_seqr_app_dirs import (
+    RsyncToSeqrAppDirsTask,
 )
 
 
@@ -13,7 +13,7 @@
 )
 class MiscTest(unittest.TestCase):
     def test_to_kebab_str_args(self, _: Mock):
-        t = WriteSuccessFileOnDataprocTask(
+        t = RsyncToSeqrAppDirsTask(
             reference_genome=ReferenceGenome.GRCh38,
             dataset_type=DatasetType.SNV_INDEL,
             sample_type=SampleType.WGS,
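Note: to_kebab_str_args itself is not shown in this diff (it lives in v03_pipeline/lib/tasks/dataproc/misc.py); the test only swaps which task it is exercised against. As a rough, assumed sketch of the shape of such a serializer, built on luigi's real Task.to_str_params() API:

import luigi


def to_kebab_str_args(task: luigi.Task) -> list[str]:
    # Assumed behavior: turn each luigi parameter into a '--kebab-case'
    # flag/value pair, suitable for the pyspark_job args list above.
    args = []
    for name, value in task.to_str_params().items():
        args.extend([f'--{name.replace("_", "-")}', value])
    return args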

v03_pipeline/lib/tasks/dataproc/rsync_to_seqr_app_dirs.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,9 @@
 from v03_pipeline.lib.tasks.base.base_loading_run_params import (
     BaseLoadingRunParams,
 )
+from v03_pipeline.lib.tasks.dataproc.run_pipeline_on_dataproc import (
+    RunPipelineOnDataprocTask,
+)
 
 
 def hail_search_value(value: str) -> str:
@@ -38,6 +41,9 @@ def output(self) -> None:
     def complete(self) -> bool:
         return self.done
 
+    def requires(self) -> luigi.Task:
+        return self.clone(RunPipelineOnDataprocTask)
+
     def run(self) -> None:
         if not (
             Env.SEQR_APP_HAIL_SEARCH_DATA_DIR and Env.SEQR_APP_REFERENCE_DATASETS_DIR
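Note: the new requires() is what moves RunPipelineOnDataprocTask ahead of the rsync in the dependency graph. self.clone(...) is the standard luigi idiom here: it instantiates the upstream task with the parameter values the two classes share (the BaseLoadingRunParams set), so both tasks describe the same run. A stripped-down sketch of the idiom, with stand-in names:

import luigi
import luigi.util


class RunParams(luigi.Config):  # stand-in for BaseLoadingRunParams
    run_id = luigi.Parameter()


@luigi.util.inherits(RunParams)
class Upstream(luigi.WrapperTask):
    pass


@luigi.util.inherits(RunParams)
class Downstream(luigi.Task):
    def requires(self) -> luigi.Task:
        return self.clone(Upstream)  # run_id carries over automatically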

v03_pipeline/lib/tasks/dataproc/rsync_to_seqr_app_dirs_test.py

Lines changed: 11 additions & 0 deletions
@@ -13,14 +13,20 @@
 from v03_pipeline.lib.tasks.dataproc.rsync_to_seqr_app_dirs import (
     RsyncToSeqrAppDirsTask,
 )
+from v03_pipeline.lib.test.mock_complete_task import MockCompleteTask
 
 
 class RsyncToSeqrAppDirsTaskTest(unittest.TestCase):
+    @patch(
+        'v03_pipeline.lib.tasks.dataproc.rsync_to_seqr_app_dirs.RunPipelineOnDataprocTask',
+    )
     @patch('v03_pipeline.lib.tasks.dataproc.rsync_to_seqr_app_dirs.subprocess')
     def test_rsync_to_seqr_app_dirs_no_sync(
         self,
         mock_subprocess: Mock,
+        mock_run_pipeline_task: Mock,
     ) -> None:
+        mock_run_pipeline_task.return_value = MockCompleteTask()
         worker = luigi.worker.Worker()
         task = RsyncToSeqrAppDirsTask(
             reference_genome=ReferenceGenome.GRCh38,
@@ -37,6 +43,9 @@ def test_rsync_to_seqr_app_dirs_no_sync(
         self.assertTrue(task.complete())
         mock_subprocess.call.assert_not_called()
 
+    @patch(
+        'v03_pipeline.lib.tasks.dataproc.rsync_to_seqr_app_dirs.RunPipelineOnDataprocTask',
+    )
     @patch('v03_pipeline.lib.tasks.dataproc.rsync_to_seqr_app_dirs.subprocess')
     @patch.object(Env, 'HAIL_SEARCH_DATA_DIR', 'gs://test-hail-search-dir')
     @patch.object(Env, 'REFERENCE_DATASETS_DIR', 'gs://test-reference-data-dir')
@@ -58,7 +67,9 @@ def test_rsync_to_seqr_app_dirs_no_sync(
     def test_rsync_to_seqr_app_dirs_sync(
         self,
         mock_subprocess: Mock,
+        mock_run_pipeline_task: Mock,
     ) -> None:
+        mock_run_pipeline_task.return_value = MockCompleteTask()
         worker = luigi.worker.Worker()
         task = RsyncToSeqrAppDirsTask(
             reference_genome=ReferenceGenome.GRCh38,
v03_pipeline/lib/tasks/dataproc/run_pipeline_on_dataproc.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import luigi
+
+from v03_pipeline.lib.tasks.base.base_loading_run_params import (
+    BaseLoadingRunParams,
+)
+from v03_pipeline.lib.tasks.dataproc.base_run_job_on_dataproc import (
+    BaseRunJobOnDataprocTask,
+)
+from v03_pipeline.lib.tasks.run_pipeline import RunPipelineTask
+
+
+@luigi.util.inherits(BaseLoadingRunParams)
+class RunPipelineOnDataprocTask(BaseRunJobOnDataprocTask):
+    @property
+    def task(self) -> luigi.Task:
+        return RunPipelineTask
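Note: an illustrative kick-off of the new wrapper from Python (parameter names are taken from the tests in this commit; the exact required set comes from BaseLoadingRunParams). Per the base class's requires() above, this first satisfies CreateDataprocClusterTask, then submits RunPipelineTask as a Dataproc pyspark job:

import luigi

from v03_pipeline.lib.model import DatasetType, ReferenceGenome, SampleType
from v03_pipeline.lib.tasks.dataproc.run_pipeline_on_dataproc import (
    RunPipelineOnDataprocTask,
)

luigi.build(
    [
        RunPipelineOnDataprocTask(
            reference_genome=ReferenceGenome.GRCh38,
            dataset_type=DatasetType.SNV_INDEL,
            sample_type=SampleType.WGS,
        ),
    ],
    local_scheduler=True,
)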

v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc_test.py renamed to v03_pipeline/lib/tasks/dataproc/run_pipeline_on_dataproc_test.py

Lines changed: 8 additions & 8 deletions
@@ -6,8 +6,8 @@
 import luigi
 
 from v03_pipeline.lib.model import DatasetType, ReferenceGenome, SampleType
-from v03_pipeline.lib.tasks.dataproc.write_success_file_on_dataproc import (
-    WriteSuccessFileOnDataprocTask,
+from v03_pipeline.lib.tasks.dataproc.run_pipeline_on_dataproc import (
+    RunPipelineOnDataprocTask,
 )
 from v03_pipeline.lib.test.mock_complete_task import MockCompleteTask
 
@@ -38,7 +38,7 @@ def test_job_already_exists_failed(
             google.api_core.exceptions.AlreadyExists('job exists')
         )
         worker = luigi.worker.Worker()
-        task = WriteSuccessFileOnDataprocTask(
+        task = RunPipelineOnDataprocTask(
             reference_genome=ReferenceGenome.GRCh38,
             dataset_type=DatasetType.SNV_INDEL,
             sample_type=SampleType.WGS,
@@ -54,7 +54,7 @@ def test_job_already_exists_failed(
         mock_logger.error.assert_has_calls(
             [
                 call(
-                    'Job WriteSuccessFileOnDataprocTask-manual__2024-04-03 entered ERROR state',
+                    'Job RunPipelineTask-manual__2024-04-03 entered ERROR state',
                 ),
             ],
         )
@@ -70,7 +70,7 @@ def test_job_already_exists_success(
             status=SimpleNamespace(state='DONE'),
         )
         worker = luigi.worker.Worker()
-        task = WriteSuccessFileOnDataprocTask(
+        task = RunPipelineOnDataprocTask(
             reference_genome=ReferenceGenome.GRCh38,
             dataset_type=DatasetType.SNV_INDEL,
             sample_type=SampleType.WGS,
@@ -102,7 +102,7 @@ def test_job_failed(
             'FailedPrecondition: 400 Job failed with message',
         )
         worker = luigi.worker.Worker()
-        task = WriteSuccessFileOnDataprocTask(
+        task = RunPipelineOnDataprocTask(
             reference_genome=ReferenceGenome.GRCh38,
             dataset_type=DatasetType.SNV_INDEL,
             sample_type=SampleType.WGS,
@@ -118,7 +118,7 @@ def test_job_failed(
         mock_logger.info.assert_has_calls(
             [
                 call(
-                    'Waiting for job completion WriteSuccessFileOnDataprocTask-manual__2024-04-05',
+                    'Waiting for job completion RunPipelineTask-manual__2024-04-05',
                 ),
             ],
         )
@@ -141,7 +141,7 @@ def test_job_success(
         operation = mock_client.submit_job_as_operation.return_value
         operation.done.side_effect = [False, True]
         worker = luigi.worker.Worker()
-        task = WriteSuccessFileOnDataprocTask(
+        task = RunPipelineOnDataprocTask(
             reference_genome=ReferenceGenome.GRCh38,
             dataset_type=DatasetType.SNV_INDEL,
             sample_type=SampleType.WGS,

v03_pipeline/lib/tasks/dataproc/write_success_file_on_dataproc.py

Lines changed: 0 additions & 22 deletions
This file was deleted.
v03_pipeline/lib/tasks/run_pipeline.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import luigi
+import luigi.util
+
+from v03_pipeline.lib.tasks.base.base_loading_run_params import (
+    BaseLoadingRunParams,
+)
+from v03_pipeline.lib.tasks.update_variant_annotations_table_with_new_samples import (
+    UpdateVariantAnnotationsTableWithNewSamplesTask,
+)
+from v03_pipeline.lib.tasks.write_metadata_for_run import WriteMetadataForRunTask
+from v03_pipeline.lib.tasks.write_project_family_tables import (
+    WriteProjectFamilyTablesTask,
+)
+
+
+@luigi.util.inherits(BaseLoadingRunParams)
+class RunPipelineTask(luigi.WrapperTask):
+    def requires(self):
+        requirements = [
+            self.clone(WriteMetadataForRunTask),
+            self.clone(UpdateVariantAnnotationsTableWithNewSamplesTask),
+        ]
+        return [
+            *requirements,
+            *[
+                self.clone(
+                    WriteProjectFamilyTablesTask,
+                    project_i=i,
+                )
+                for i in range(len(self.project_guids))
+            ],
+        ]
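Note: RunPipelineTask is a luigi.WrapperTask, so it writes nothing itself and is complete once every requirement is. For a hypothetical run with two entries in project_guids, requires() fans out as sketched here:

# Hypothetical expansion for project_guids=['R0001_a', 'R0002_b'];
# every clone() shares this run's BaseLoadingRunParams values.
#
#   WriteMetadataForRunTask(...)
#   UpdateVariantAnnotationsTableWithNewSamplesTask(...)
#   WriteProjectFamilyTablesTask(..., project_i=0)
#   WriteProjectFamilyTablesTask(..., project_i=1)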
