Commit 38c09db

meyerkm and PMBio authored
Add in Snakemake Log-Files (#147)
* remove double logging
* add in logging redirect
* add in logging redirect
  - remove logging from train rule (already performed with parallel)
  - remove params.prefix bug from log paths
* bugfix logging redirect for associate.py
  - add in logging handlers.clear()
  - define logging level
* output additional fields in model_config.yaml, to be used for pretrained_models setup
* adding in logging directive to cv pipeline
* add in log to final regenie pipeline
* fixup! Format Python code with psf/black pull_request

---------

Co-authored-by: PMBio <PMBio@users.noreply.github.com>
1 parent 8d8e0cb commit 38c09db
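
The pattern added throughout the pipelines is the same: each rule gets a log: directive naming per-rule stdout/stderr files, and a shared logging_redirct string is appended to the shell command so the tool's output lands in those files. The rule below is a minimal sketch, not one from the repository; in particular, logging_redirct is not defined anywhere in this diff, and the redirect string shown here is an assumption.

# Hypothetical Snakefile fragment illustrating the pattern in this commit.
# The definition of logging_redirct is not part of this diff; a plain shell
# redirect is assumed here.
logging_redirct = "> {log.stdout} 2> {log.stderr}"

rule example_rule:
    input:
        data_config="config.yaml"
    output:
        "result.txt"
    log:
        stdout="logs/example_rule/example.stdout",
        stderr="logs/example_rule/example.stderr"
    shell:
        'some_tool {input.data_config} '
        '{output} '
        + logging_redirct

Note the trailing space after '{output} ': the shell command is built by string concatenation, so without it the redirect would fuse with the last argument.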

31 files changed (+240, -58 lines)

deeprvat/cv_utils.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)
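
This same one-line change recurs in most of the modules below: the string "INFO" is swapped for the logging.INFO constant. basicConfig accepts either form, so this is a consistency fix rather than a behaviour change; a standalone sketch of the resulting setup:

import logging
import sys

logging.basicConfig(
    format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
    level=logging.INFO,  # numeric constant; the string "INFO" is also accepted
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger.info("logging configured")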

deeprvat/data/dense_gt.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/data/rare.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/deeprvat/associate.py

Lines changed: 12 additions & 11 deletions
@@ -25,10 +25,10 @@
 from tqdm import tqdm, trange
 import zarr
 import re
-
 import deeprvat.deeprvat.models as deeprvat_models
 from deeprvat.data import DenseGTDataset

+logging.root.handlers.clear()  # Remove all handlers associated with the root logger object
 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
     level=logging.INFO,
@@ -48,6 +48,11 @@
 AGG_FCT = {"mean": np.mean, "max": np.max}


+@click.group()
+def cli():
+    pass
+
+
 def get_burden(
     batch: Dict,
     agg_models: Dict[str, List[nn.Module]],
@@ -99,11 +104,6 @@ def separate_parallel_results(results: List) -> Tuple[List, ...]:
     return tuple(map(list, zip(*results)))


-@click.group()
-def cli():
-    pass
-
-
 def make_dataset_(
     config: Dict,
     debug: bool = False,
@@ -306,7 +306,6 @@ def make_regenie_input_(
     gene_metadata_file: Path,
     gtf: Path,
 ):
-    logger.setLevel(logging.INFO)

     ## Check options
     if not skip_burdens and burdens_genes_samples is None:
@@ -420,7 +419,7 @@ def make_regenie_input_(
     if average_repeats:
         logger.info("Averaging burdens across all repeats")
         burdens = np.zeros((n_samples, n_genes))
-        for repeat in trange(burdens_zarr.shape[2]):
+        for repeat in trange(burdens_zarr.shape[2], file=sys.stdout):
             burdens += burdens_zarr[:n_samples, :, repeat]
         burdens = burdens / burdens_zarr.shape[2]
     else:
@@ -448,7 +447,7 @@ def make_regenie_input_(
         n_samples,
         samples=list(sample_ids.astype(str)),
     ) as f:
-        for i in trange(n_genes):
+        for i in trange(n_genes, file=sys.stdout):
             varid = f"pseudovariant_gene_{ensgids[i]}"
             this_burdens = burdens[:, i]  # Rescale scores to be in range (0, 2)
             genotypes = np.stack(
@@ -746,7 +745,7 @@ def load_models(
     }

     if len(checkpoint_files[first_repeat]) > 1:
-        logging.info(
+        logger.info(
             f" Averaging results from {len(checkpoint_files[first_repeat])} models for each repeat"
         )

@@ -1064,7 +1063,9 @@ def combine_burden_chunks_(
     end_id = 0

     for i, chunk in tqdm(
-        enumerate(range(0, n_chunks)), desc=f"Merging {n_chunks} chunks"
+        enumerate(range(0, n_chunks)),
+        desc=f"Merging {n_chunks} chunks",
+        file=sys.stdout,
     ):
         chunk_dir = burdens_chunks_dir / f"chunk_{chunk}"

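
Two recurring changes in associate.py are worth spelling out: clearing any handlers already attached to the root logger (basicConfig silently does nothing when the root logger already has handlers, which is how duplicated log lines arise), and passing file=sys.stdout to tqdm/trange so progress bars go to the redirected stdout log rather than tqdm's default stderr stream. A minimal, self-contained sketch of both, independent of the DeepRVAT functions shown above:

import logging
import sys

from tqdm import trange

# basicConfig is a no-op if the root logger already has handlers,
# so clear them first to guarantee a single stdout handler.
logging.root.handlers.clear()
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# tqdm writes to stderr by default; file=sys.stdout keeps the progress
# bar in the same stream (and log file) as the log messages.
for _ in trange(3, file=sys.stdout, desc="example"):
    pass
logger.info("done")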

deeprvat/deeprvat/common_variant_condition_utils.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/deeprvat/config.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/deeprvat/evaluate.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/deeprvat/models.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/deeprvat/train.py

Lines changed: 21 additions & 8 deletions
@@ -9,7 +9,7 @@
 from pprint import pformat, pprint
 from tempfile import TemporaryDirectory
 from typing import Dict, Optional, Tuple, Union
-
+import re
 import click
 import math
 import numpy as np
@@ -37,10 +37,9 @@
 from torch.utils.data import DataLoader, Dataset, Subset
 from tqdm import tqdm

-
 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)
@@ -872,20 +871,20 @@ def run_bagging(
             trainer.fit(model, dm)
         except RuntimeError as e:
             # if batch_size is choosen to big, it will be reduced until it fits the GPU
-            logging.error(f"Caught RuntimeError: {e}")
+            logger.error(f"Caught RuntimeError: {e}")
             if str(e).find("CUDA out of memory") != -1:
                 if dm.hparams.batch_size > 4:
-                    logging.error(
+                    logger.error(
                         "Retrying training with half the original batch size"
                     )
                     gc.collect()
                     torch.cuda.empty_cache()
                     dm.hparams.batch_size = dm.hparams.batch_size // 2
                 else:
-                    logging.error("Batch size is already <= 4, giving up")
+                    logger.error("Batch size is already <= 4, giving up")
                     raise RuntimeError("Could not find small enough batch size")
             else:
-                logging.error(f"Caught unknown error: {e}")
+                logger.error(f"Caught unknown error: {e}")
                 raise e
         else:
             break
@@ -1167,7 +1166,21 @@ def best_training_run(
         config = yaml.safe_load(f)

     with open(config_file_out, "w") as f:
-        yaml.dump({"model": config["model"]}, f)
+        yaml.dump(
+            {
+                "model": config["model"],
+                "rare_variant_annotations": config["training_data"]["dataset_config"][
+                    "rare_embedding"
+                ]["config"]["annotations"],
+                "training_data_thresholds": {
+                    k: str(re.sub(f"^{k} ", "", v))
+                    for k, v in config["training_data"]["dataset_config"][
+                        "rare_embedding"
+                    ]["config"]["thresholds"].items()
+                },
+            },
+            f,
+        )

     n_bags = config["training"]["n_bags"] if not debug else 3
     for k in range(n_bags):
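
The best_training_run change above writes two extra fields into model_config.yaml so a pretrained model can be set up later: the list of rare-variant annotations and the training-data thresholds with the leading column name stripped off by re.sub. The snippet below reproduces just that transformation on made-up config values; the threshold format "<column> <operator> <value>" is an assumption inferred from the re.sub call, not taken from the repository.

import re

import yaml

# Made-up stand-in for the training config loaded from disk.
config = {
    "model": {"type": "example_model"},
    "training_data": {
        "dataset_config": {
            "rare_embedding": {
                "config": {
                    "annotations": ["annotation_a", "annotation_b"],
                    "thresholds": {"MAF": "MAF < 1e-2"},
                }
            }
        }
    },
}

rare_cfg = config["training_data"]["dataset_config"]["rare_embedding"]["config"]
out = {
    "model": config["model"],
    "rare_variant_annotations": rare_cfg["annotations"],
    # Strip the leading "<key> " from each threshold string: "MAF < 1e-2" -> "< 1e-2"
    "training_data_thresholds": {
        k: str(re.sub(f"^{k} ", "", v)) for k, v in rare_cfg["thresholds"].items()
    },
}
print(yaml.dump(out))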

deeprvat/metrics.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/seed_gene_discovery/evaluate.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/seed_gene_discovery/seed_gene_discovery.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

deeprvat/utils.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@

 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
-    level="INFO",
+    level=logging.INFO,
     stream=sys.stdout,
 )
 logger = logging.getLogger(__name__)

pipelines/association_testing/association_dataset.snakefile

Lines changed: 10 additions & 2 deletions
@@ -16,12 +16,16 @@ rule association_dataset:
     resources:
         mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1),
     priority: 30
+    log:
+        stdout="logs/association_dataset/{phenotype}.stdout",
+        stderr="logs/association_dataset/{phenotype}.stderr"
     shell:
         'deeprvat_associate make-dataset '
         + debug +
         "--skip-genotypes "
         '{input.data_config} '
-        '{output}'
+        '{output} '
+        + logging_redirct


 rule association_dataset_burdens:
@@ -33,8 +37,12 @@ rule association_dataset_burdens:
     resources:
         mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1)
     priority: 30
+    log:
+        stdout=f"logs/association_dataset_burdens/{phenotypes[0]}.stdout",
+        stderr=f"logs/association_dataset_burdens/{phenotypes[0]}.stderr"
     shell:
         'deeprvat_associate make-dataset '
         + debug +
         '{input.data_config} '
-        '{output}'
+        '{output} '
+        + logging_redirct
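
Because the shell directive is assembled by string concatenation, the redirect simply becomes the tail of the command. A sketch of the expansion for the first rule above, assuming logging_redirct is a plain shell redirect and debug is either an empty string or a debug flag (both assumptions; neither definition appears in this diff):

# Assumed values, for illustration only.
debug = ""
logging_redirct = "> {log.stdout} 2> {log.stderr}"

cmd = (
    'deeprvat_associate make-dataset '
    + debug +
    "--skip-genotypes "
    '{input.data_config} '
    '{output} '
    + logging_redirct
)
print(cmd)
# deeprvat_associate make-dataset --skip-genotypes {input.data_config} {output} > {log.stdout} 2> {log.stderr}

The trailing space added to '{output} ' in the diff is what keeps the redirect from fusing with the output path.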

pipelines/association_testing/burdens.snakefile

Lines changed: 22 additions & 5 deletions
@@ -16,12 +16,16 @@ rule combine_burdens:
     threads: 1
     resources:
         mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098,
+    log:
+        stdout="logs/combine_burdens/combine_burdens.stdout",
+        stderr="logs/combine_burdens/combine_burdens.stderr"
     shell:
         ' '.join([
             'deeprvat_associate combine-burden-chunks',
             '{params.prefix}/burdens/chunks/',
             ' --n-chunks ' + str(n_burden_chunks),
-            '{params.prefix}/burdens',
+            '{params.prefix}/burdens ',
+            logging_redirct
         ])

 rule all_xy:
@@ -42,14 +46,18 @@ rule compute_xy:
     threads: 8
     resources:
         mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098,
+    log:
+        stdout="logs/compute_xy/{phenotype}.stdout",
+        stderr="logs/compute_xy/{phenotype}.stderr"
     shell:
         ' && '.join([
             ('deeprvat_associate compute-xy '
              '--dataset-file {input.dataset} '
              '{input.data_config} '
              "{output.samples} "
              "{output.x} "
-             "{output.y}")
+             "{output.y} "
+             + logging_redirct)
         ])


@@ -73,6 +81,9 @@ rule compute_burdens:
     resources:
         mem_mb = 32000,
         gpus = 1
+    log:
+        stdout="logs/compute_burdens/compute_burdens_{chunk}.stdout",
+        stderr="logs/compute_burdens/compute_burdens_{chunk}.stderr"
     shell:
         ' '.join([
             'deeprvat_associate compute-burdens '
@@ -83,7 +94,8 @@ rule compute_burdens:
             '{input.data_config} '
             '{input.model_config} '
             '{input.checkpoints} '
-            '{params.prefix}/burdens'],
+            '{params.prefix}/burdens '
+            + logging_redirct ],
         )


@@ -98,11 +110,16 @@ rule reverse_models:
     threads: 4
     resources:
         mem_mb = 20480,
+    log:
+        stdout="logs/reverse_models/reverse_models.stdout",
+        stderr="logs/reverse_models/reverse_models.stderr"
     shell:
         " && ".join([
             ("deeprvat_associate reverse-models "
              "{input.model_config} "
              "{input.data_config} "
-             "{input.checkpoints}"),
-            "touch {output}"
+             "{input.checkpoints} "
+             + logging_redirct),
+            "touch {output} "
+
         ])