Skip to content

Commit 801fc32

Browse files
meyerkmPMBio
and
PMBio
authored
Create PyTests for config file setups (#130)
* incorporate config for train-only setup update quick start * Update docs to reflect config updates * add pytests for deeprvat_config generation * add pretrained_models dir link to pytest * bugfix smoke test base path * bugfix smoke test base path * fixup! Format Python code with psf/black pull_request * update alpha param from config * add in gene-file for train-only pipeline * fixup! Format Python code with psf/black pull_request --------- Co-authored-by: PMBio <PMBio@users.noreply.github.com>
1 parent d02b759 commit 801fc32

File tree

6 files changed

+297
-126
lines changed

6 files changed

+297
-126
lines changed

.github/workflows/pipeline-tests.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,29 @@ run-name: DeepRVAT Pipeline Tests 🧬🧪💻🧑‍🔬
33
on: [ push ]
44

55
jobs:
6+
# Config Setup
7+
Smoke-GenerateConfig-Training:
8+
uses: ./.github/workflows/run-pipeline.yml
9+
with:
10+
pipeline_file: ./pipelines/run_training.snakefile
11+
environment_file: ./deeprvat_env_no_gpu.yml
12+
prerun_cmd: cp ./example/config/deeprvat_input_training_config.yaml ./example/
13+
14+
Smoke-GenerateConfig-Training-AssociationTesting:
15+
uses: ./.github/workflows/run-pipeline.yml
16+
with:
17+
pipeline_file: ./pipelines/training_association_testing.snakefile
18+
environment_file: ./deeprvat_env_no_gpu.yml
19+
prerun_cmd: cp ./example/config/deeprvat_input_config.yaml ./example/
20+
21+
Smoke-GenerateConfig-PreTrained:
22+
uses: ./.github/workflows/run-pipeline.yml
23+
with:
24+
pipeline_file: ./pipelines/association_testing_pretrained.snakefile
25+
environment_file: ./deeprvat_env_no_gpu.yml
26+
prerun_cmd: cp ./example/config/deeprvat_input_pretrained_models_config.yaml ./example/ && ln -s $GITHUB_WORKSPACE/pretrained_models ./example/
27+
28+
629
# Training Pipeline
730
Smoke-RunTraining:
831
uses: ./.github/workflows/run-pipeline.yml

deeprvat/deeprvat/config.py

Lines changed: 139 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,23 @@ def create_main_config(
8383
"regenie_options",
8484
]
8585

86+
# Check if Training Only
87+
if input_config.get("training_only", False):
88+
train_only = True
89+
to_remove = {
90+
"phenotypes_for_association_testing",
91+
"association_testing_data_thresholds",
92+
"evaluation",
93+
}
94+
expected_input_keys = [
95+
item for item in expected_input_keys if item not in to_remove
96+
]
97+
input_config.pop("training_only", None)
98+
else:
99+
train_only = False
100+
86101
# CV setup parameters
87-
if not input_config["cv_options"]["cv_exp"]:
88-
logger.info("Not CV setup...removing CV pipeline parameters from config")
89-
full_config["cv_exp"] = False
90-
else: # CV experiment setup specified
102+
if input_config.get("cv_options", {}).get("cv_exp", False):
91103
if any(
92104
key not in input_config["cv_options"]
93105
for key in ["cv_exp", "cv_path", "n_folds"]
@@ -99,14 +111,14 @@ def create_main_config(
99111
full_config["cv_path"] = input_config["cv_options"]["cv_path"]
100112
full_config["n_folds"] = input_config["cv_options"]["n_folds"]
101113
full_config["cv_exp"] = True
114+
else:
115+
logger.info("Not CV setup...removing CV pipeline parameters from config")
116+
full_config["cv_exp"] = False
117+
expected_input_keys.remove("cv_options")
118+
input_config.pop("cv_options", None)
102119

103120
# REGENIE setup parameters
104-
if not input_config["regenie_options"]["regenie_exp"]:
105-
logger.info(
106-
"Not using REGENIE integration...removing REGENIE parameters from config"
107-
)
108-
full_config["regenie_exp"] = False
109-
else: # REGENIE integration
121+
if input_config.get("regenie_options", {}).get("regenie_exp", False):
110122
if any(
111123
key not in input_config["regenie_options"]
112124
for key in ["regenie_exp", "step_1", "step_2"]
@@ -124,58 +136,68 @@ def create_main_config(
124136
full_config["regenie_options"]["step_2"] = input_config["regenie_options"][
125137
"step_2"
126138
]
139+
else:
140+
logger.info(
141+
"Not using REGENIE integration...removing REGENIE parameters from config"
142+
)
143+
full_config["regenie_exp"] = False
144+
expected_input_keys.remove("regenie_options")
145+
input_config.pop("regenie_options", None)
127146

128147
no_pretrain = True
129-
if "use_pretrained_models" in input_config:
130-
if input_config["use_pretrained_models"]:
131-
no_pretrain = False
132-
logger.info("Pretrained Model setup specified.")
133-
to_remove = {"training", "phenotypes_for_training", "seed_gene_results"}
134-
expected_input_keys = [
135-
item for item in expected_input_keys if item not in to_remove
136-
]
148+
if input_config.get("use_pretrained_models", False):
149+
no_pretrain = False
150+
logger.info("Pretrained Model setup specified.")
151+
to_remove = {"training", "phenotypes_for_training", "seed_gene_results"}
152+
expected_input_keys = [
153+
item for item in expected_input_keys if item not in to_remove
154+
]
137155

138-
pretrained_model_path = Path(input_config["pretrained_model_path"])
156+
pretrained_model_path = Path(input_config["pretrained_model_path"])
139157

140-
expected_input_keys.extend(
141-
["use_pretrained_models", "model", "pretrained_model_path"]
142-
)
158+
expected_input_keys.extend(
159+
["use_pretrained_models", "model", "pretrained_model_path"]
160+
)
143161

144-
with open(f"{pretrained_model_path}/model_config.yaml") as f:
145-
pretrained_config = yaml.safe_load(f)
162+
with open(f"{pretrained_model_path}/model_config.yaml") as f:
163+
pretrained_config = yaml.safe_load(f)
146164

147-
required_keys = [
148-
"model",
149-
"rare_variant_annotations",
150-
"training_data_thresholds",
151-
]
152-
for k in pretrained_config:
153-
if k not in required_keys:
154-
raise KeyError(
155-
(
156-
f"Unexpected key in pretrained_model_path/model_config.yaml file : {k} "
157-
"Please review DEEPRVAT_DIR/pretrained_models/model_config.yaml for expected list of keys."
158-
)
165+
required_keys = [
166+
"model",
167+
"rare_variant_annotations",
168+
"training_data_thresholds",
169+
]
170+
for k in pretrained_config:
171+
if k not in required_keys:
172+
raise KeyError(
173+
(
174+
f"Unexpected key in pretrained_model_path/model_config.yaml file : {k} "
175+
"Please review DEEPRVAT_DIR/pretrained_models/model_config.yaml for expected list of keys."
159176
)
160-
else:
161-
input_config[k] = deepcopy(pretrained_config[k])
177+
)
178+
else:
179+
input_config[k] = deepcopy(pretrained_config[k])
162180

163181
if no_pretrain and "phenotypes_for_training" not in input_config:
164182
logger.info("Unspecified phenotype list for training.")
165-
logger.info(
166-
" Setting training phenotypes to be the same set as specified by phenotypes_for_association_testing."
167-
)
168-
input_config["phenotypes_for_training"] = input_config[
169-
"phenotypes_for_association_testing"
170-
]
183+
if train_only:
184+
raise KeyError(("Must specify phenotypes_for_training in config file!"))
185+
else:
186+
logger.info(
187+
" Setting training phenotypes to be the same set as specified by phenotypes_for_association_testing."
188+
)
189+
input_config["phenotypes_for_training"] = input_config[
190+
"phenotypes_for_association_testing"
191+
]
171192

172193
if "y_transformation" in input_config:
173194
full_config["training_data"]["dataset_config"]["y_transformation"] = (
174195
input_config["y_transformation"]
175196
)
176-
full_config["association_testing_data"]["dataset_config"][
177-
"y_transformation"
178-
] = input_config["y_transformation"]
197+
if not train_only:
198+
full_config["association_testing_data"]["dataset_config"][
199+
"y_transformation"
200+
] = input_config["y_transformation"]
179201
else:
180202
expected_input_keys.remove("y_transformation")
181203

@@ -186,7 +208,10 @@ def create_main_config(
186208
"Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."
187209
)
188210
)
189-
if "MAF" not in input_config["association_testing_data_thresholds"]:
211+
if (
212+
not train_only
213+
and "MAF" not in input_config["association_testing_data_thresholds"]
214+
):
190215
raise KeyError(
191216
(
192217
"Missing required MAF threshold in config['association_testing_data_thresholds']. "
@@ -223,51 +248,59 @@ def create_main_config(
223248
"Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."
224249
)
225250

226-
# Phenotypes
227-
full_config["phenotypes"] = input_config["phenotypes_for_association_testing"]
228-
# genotypes.h5
229-
full_config["training_data"]["gt_file"] = input_config["gt_filename"]
230-
full_config["association_testing_data"]["gt_file"] = input_config["gt_filename"]
231-
# variants.parquet
232-
full_config["training_data"]["variant_file"] = input_config["variant_filename"]
233-
full_config["association_testing_data"]["variant_file"] = input_config[
251+
full_config["training_data"]["gt_file"] = input_config[
252+
"gt_filename"
253+
] # genotypes.h5
254+
full_config["training_data"]["variant_file"] = input_config[
234255
"variant_filename"
235-
]
236-
# phenotypes.parquet
256+
] # variants.parquet
237257
full_config["training_data"]["dataset_config"]["phenotype_file"] = input_config[
238258
"phenotype_filename"
239-
]
240-
full_config["association_testing_data"]["dataset_config"]["phenotype_file"] = (
241-
input_config["phenotype_filename"]
242-
)
243-
# annotations.parquet
259+
] # phenotypes.parquet
244260
full_config["training_data"]["dataset_config"]["annotation_file"] = input_config[
245261
"annotation_filename"
246-
]
247-
full_config["association_testing_data"]["dataset_config"]["annotation_file"] = (
248-
input_config["annotation_filename"]
249-
)
250-
# protein_coding_genes.parquet
262+
] # annotations.parquet
251263
full_config["association_testing_data"]["dataset_config"]["gene_file"] = (
252264
input_config["gene_filename"]
253-
)
254-
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
255-
"config"
256-
]["gene_file"] = input_config["gene_filename"]
257-
# rare_variant_annotations
265+
) # protein_coding_genes.parquet
258266
full_config["training_data"]["dataset_config"]["rare_embedding"]["config"][
259267
"annotations"
260-
] = input_config["rare_variant_annotations"]
261-
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
262-
"config"
263-
]["annotations"] = input_config["rare_variant_annotations"]
264-
# covariates
268+
] = input_config[
269+
"rare_variant_annotations"
270+
] # rare_variant_annotations
265271
full_config["training_data"]["dataset_config"]["x_phenotypes"] = input_config[
266272
"covariates"
267-
]
268-
full_config["association_testing_data"]["dataset_config"]["x_phenotypes"] = (
269-
input_config["covariates"]
270-
)
273+
] # covariates
274+
if not train_only:
275+
full_config["phenotypes"] = input_config[
276+
"phenotypes_for_association_testing"
277+
] # Phenotypes
278+
full_config["association_testing_data"]["gt_file"] = input_config[
279+
"gt_filename"
280+
] # genotypes.h5
281+
full_config["association_testing_data"]["variant_file"] = input_config[
282+
"variant_filename"
283+
] # variants.parquet
284+
full_config["association_testing_data"]["dataset_config"]["phenotype_file"] = (
285+
input_config["phenotype_filename"]
286+
) # phenotypes.parquet
287+
full_config["association_testing_data"]["dataset_config"]["annotation_file"] = (
288+
input_config["annotation_filename"]
289+
) # annotations.parquet
290+
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
291+
"config"
292+
]["gene_file"] = input_config[
293+
"gene_filename"
294+
] # protein_coding_genes.parquet
295+
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
296+
"config"
297+
]["annotations"] = input_config[
298+
"rare_variant_annotations"
299+
] # rare_variant_annotations
300+
full_config["association_testing_data"]["dataset_config"]["x_phenotypes"] = (
301+
input_config["covariates"] # covariates
302+
)
303+
271304
# Thresholds & variant annotations
272305
anno_list = deepcopy(input_config["rare_variant_annotations"])
273306
full_config["training_data"]["dataset_config"]["rare_embedding"]["config"][
@@ -280,29 +313,29 @@ def create_main_config(
280313
][k] = f"{k} {v}"
281314
training_anno_list.insert(i + 1, k)
282315
full_config["training_data"]["dataset_config"]["annotations"] = training_anno_list
283-
284-
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
285-
"config"
286-
]["thresholds"] = {}
287-
association_anno_list = deepcopy(anno_list)
288-
for i, (k, v) in enumerate(
289-
input_config["association_testing_data_thresholds"].items()
290-
):
316+
if not train_only:
291317
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
292318
"config"
293-
]["thresholds"][k] = f"{k} {v}"
294-
association_anno_list.insert(i + 1, k)
295-
full_config["association_testing_data"]["dataset_config"][
296-
"annotations"
297-
] = association_anno_list
319+
]["thresholds"] = {}
320+
association_anno_list = deepcopy(anno_list)
321+
for i, (k, v) in enumerate(
322+
input_config["association_testing_data_thresholds"].items()
323+
):
324+
full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
325+
"config"
326+
]["thresholds"][k] = f"{k} {v}"
327+
association_anno_list.insert(i + 1, k)
328+
full_config["association_testing_data"]["dataset_config"][
329+
"annotations"
330+
] = association_anno_list
331+
# Results evaluation parameters; alpha parameter for significance threshold
332+
if "evaluation" not in full_config:
333+
full_config["evaluation"] = {}
334+
full_config["evaluation"]["correction_method"] = input_config["evaluation"][
335+
"correction_method"
336+
]
337+
full_config["evaluation"]["alpha"] = input_config["evaluation"]["alpha"]
298338

299-
# Results evaluation parameters; alpha parameter for significance threshold
300-
if "evaluation" not in full_config:
301-
full_config["evaluation"] = {}
302-
full_config["evaluation"]["correction_method"] = input_config["evaluation"][
303-
"correction_method"
304-
]
305-
full_config["evaluation"]["alpha"] = input_config["evaluation"]["alpha"]
306339
# DeepRVAT model
307340
full_config["n_repeats"] = input_config["n_repeats"]
308341

@@ -585,9 +618,10 @@ def update_config(
585618
else:
586619
logger.info("Not performing EAC filtering of baseline results")
587620
logger.info(f" Correcting p-values using {correction_method} method")
588-
alpha = config["baseline_results"].get(
589-
"alpha_seed_genes", config["evaluation"].get("alpha")
590-
)
621+
if config["baseline_results"].get("alpha_seed_genes", False):
622+
alpha = config["baseline_results"]["alpha_seed_genes"]
623+
else:
624+
alpha = config["evaluation"].get("alpha")
591625
baseline_df = pval_correction(
592626
baseline_df, alpha, correction_type=correction_method
593627
)

docs/input_data.md

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,12 @@ Configuration for all pipelines is specified in the file `deeprvat_input_config.
77

88
In the following, we describe the parameters (both optional and required) that can be specified in the `deeprvat_input_config.yaml` by way of an [example file](https://github.com/PMBio/deeprvat/blob/main/example/config/deeprvat_input_config.yaml), which we explain block by block.
99

10-
```
11-
deeprvat_repo_dir: ../..
12-
```
13-
14-
_Required._ This specifies the path to your copy of the DeepRVAT repository.
15-
1610
```
1711
use_pretrained_models: True
18-
pretrained_model_path : ../../pretrained_models
12+
pretrained_model_path : pretrained_models
1913
```
2014

21-
These parameters are relevant when using pretrained models. `use_pretrained_models` defaults to `False` if not specified.
15+
These parameters are relevant when using pretrained models. `use_pretrained_models` defaults to `False` if not specified. If the `pretrained_models` directory is not located in your current experiment directory, set `pretrained_model_path` to the path where that directory actually resides.
2216

2317
```
2418
phenotypes_for_association_testing:

docs/quickstart.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ snakemake -j 1 --snakefile [path_to_deeprvat]/pipelines/association_testing_pret
5252
### Run the training pipeline on some example data
5353

5454
```shell
55+
DEEPRVAT_REPO_PATH="[path_to_deeprvat]"
5556
mkdir deeprvat_train
5657
cd deeprvat_train
57-
ln -s [path_to_deeprvat]/example/* .
58-
snakemake -j 1 --snakefile [path_to_deeprvat]/pipelines/run_training.snakefile
58+
ln -s "$DEEPRVAT_REPO_PATH"/example/* .
59+
ln -s config/deeprvat_input_training_config.yaml . #get the corresponding config.
60+
snakemake -j 1 --snakefile "$DEEPRVAT_REPO_PATH"/pipelines/run_training.snakefile
5961
```
6062

6163

0 commit comments

Comments
 (0)