diff --git a/workflow/Snakefile b/workflow/Snakefile
index 5a57486..ec5aa4c 100644
--- a/workflow/Snakefile
+++ b/workflow/Snakefile
@@ -22,6 +22,11 @@ if "variant-calls" in config:
                 benchmark=used_benchmarks,
                 vartype=["snvs", "indels"],
             ),
+            expand(
+                "results/report/fp-fn/callsets/{callset}/{classification}",
+                callset=used_callsets,
+                classification=["fp", "fn"],
+            ),
             get_fp_fn_reports,
             # collect the checkpoint inputs to avoid issues when
             # --all-temp is used: --all-temp leads to premature deletion
diff --git a/workflow/resources/datavzrd/fp-fn-per-callset-config.yte.yaml b/workflow/resources/datavzrd/fp-fn-per-callset-config.yte.yaml
new file mode 100644
index 0000000..7e307bf
--- /dev/null
+++ b/workflow/resources/datavzrd/fp-fn-per-callset-config.yte.yaml
@@ -0,0 +1,37 @@
+__use_yte__: true
+
+__variables__:
+  green: "#74c476"
+  orange: "#fd8d3c"
+
+name: ?f"{wildcards.classification} of {wildcards.callset}"
+
+webview-controls: true
+default-view: results-table
+
+datasets:
+  results:
+    path: ?input.table
+    separator: "\t"
+    offer-excel: true
+
+views:
+  results-table:
+    dataset: results
+    desc: |
+      ?f"""
+      Rows are sorted by coverage.
+      Benchmark version: {params.genome} {params.version}
+      """
+    page-size: 12
+    render-table:
+      columns:
+        coverage:
+          plot:
+            heatmap:
+              scale: ordinal
+        ?if params.somatic:
+          true_genotype:
+            display-mode: hidden
+          predicted_genotype:
+            display-mode: hidden
diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk
index e50f0ad..12188a9 100644
--- a/workflow/rules/common.smk
+++ b/workflow/rules/common.smk
@@ -19,6 +19,9 @@ callsets = config.get("variant-calls", dict())
 benchmarks.update(config.get("custom-benchmarks", dict()))
 used_benchmarks = {callset["benchmark"] for callset in callsets.values()}
 
+# names of all callsets given in the "variant-calls" section of the config
+used_callsets = set(callsets)
+
 used_genomes = {benchmarks[benchmark]["genome"] for benchmark in used_benchmarks}
 
 
@@ -403,6 +406,18 @@ def get_coverages(wildcards):
     return coverages
 
 
+def get_coverages_of_callset(callset):
+    """Return the coverage categories used for the benchmark of a callset.
+
+    Benchmarks flagged as "high-coverage" use high_coverages, all other
+    benchmarks use low_coverages.
+    """
+    benchmark = config["variant-calls"][callset]["benchmark"]
+    if benchmarks[benchmark].get("high-coverage", False):
+        return high_coverages
+    return low_coverages
+
+
 def get_somatic_status(wildcards):
     if hasattr(wildcards, "benchmark"):
         return genomes[benchmarks[wildcards.benchmark]["genome"]].get("somatic")
@@ -468,10 +483,18 @@ def get_collect_stratifications_input(wildcards):
     )
 
 
+def get_collect_stratifications_fp_fn_input(wildcards):
+    """Return one FP/FN table path per entry of get_nonempty_coverages()."""
+    return expand(
+        "results/fp-fn/callsets/{{callset}}/{cov}.{{classification}}.tsv",
+        cov=get_nonempty_coverages(wildcards),
+    )
+
+
 def get_fp_fn_reports(wildcards):
     for genome in used_genomes:
         yield from expand(
-            "results/report/fp-fn/{genome}/{cov}/{classification}",
+            "results/report/fp-fn/genomes/{genome}/{cov}/{classification}",
             genome=genome,
             cov={
                 cov
@@ -482,6 +505,17 @@ def get_fp_fn_reports(wildcards):
         )
 
 
+def get_fp_fn_reports_benchmarks(wildcards):
+    """Yield the FP/FN report directories of all used benchmarks."""
+    # The targets do not depend on the genome, hence every benchmark is
+    # requested exactly once instead of once per used genome.
+    yield from expand(
+        "results/report/fp-fn/benchmarks/{benchmark}/{classification}",
+        benchmark=used_benchmarks,
+        classification=["fp", "fn"],
+    )
+
+
 def get_benchmark_callsets(benchmark):
     return [
         callset
@@ -497,6 +531,14 @@ def get_collect_precision_recall_input(wildcards):
     )
 
 
+def get_collect_fp_fn_benchmark_input(wildcards):
+    """Return the collected FP/FN tables of all callsets of the benchmark."""
+    return expand(
+        "results/fp-fn/callsets/{callset}.{{classification}}.tsv",
+        callset=get_benchmark_callsets(wildcards.benchmark),
+    )
+
+
 def get_genome_name(wildcards):
     if hasattr(wildcards, "benchmark"):
         return get_benchmark(wildcards.benchmark).get("genome")
@@ -546,19 +588,24 @@ def get_callset_label_entries(callsets):
 
 
 def get_collect_fp_fn_callsets(wildcards):
-    return get_genome_callsets(wildcards.genome)
+    # only keep callsets whose coverage categories contain the requested coverage
+    return [
+        callset
+        for callset in get_genome_callsets(wildcards.genome)
+        if wildcards.cov in get_coverages_of_callset(callset)
+    ]
 
 
 def get_collect_fp_fn_input(wildcards):
     callsets = get_collect_fp_fn_callsets(wildcards)
     return expand(
-        "results/fp-fn/callsets/{{cov}}/{callset}/{{classification}}.tsv",
+        "results/fp-fn/callsets/{callset}/{{cov}}.{{classification}}.tsv",
         callset=callsets,
     )
 
 
 def get_collect_fp_fn_labels(wildcards):
-    callsets = get_genome_callsets(wildcards.genome)
+    callsets = get_collect_fp_fn_callsets(wildcards)
     return get_callset_label_entries(callsets)
 
 
diff --git a/workflow/rules/eval.smk b/workflow/rules/eval.smk
index b68b484..724f53a 100644
--- a/workflow/rules/eval.smk
+++ b/workflow/rules/eval.smk
@@ -338,9 +338,9 @@ rule extract_fp_fn:
         calls="results/vcfeval/{callset}/{cov}/output.vcf.gz",
         common_src=common_src,
     output:
-        "results/fp-fn/callsets/{cov}/{callset}/{classification}.tsv",
+        "results/fp-fn/callsets/{callset}/{cov}.{classification}.tsv",
     log:
-        "logs/extract-fp-fn/{cov}/{callset}/{classification}.log",
+        "logs/extract-fp-fn/{callset}/{cov}.{classification}.log",
     conda:
         "../envs/vembrane.yaml"
     script:
@@ -374,6 +374,45 @@ rule collect_fp_fn:
         "../scripts/collect-fp-fn.py"
 
 
+# Merge the per-coverage FP/FN tables of a callset into a single table.
+rule collect_stratifications_fp_fn:
+    input:
+        get_collect_stratifications_fp_fn_input,
+    # NOTE(review): this pattern can also match the per-coverage tables
+    # "results/fp-fn/callsets/{callset}/{cov}.{classification}.tsv" of
+    # extract_fp_fn unless wildcard constraints forbid "/" in callset -- confirm.
+    output:
+        "results/fp-fn/callsets/{callset}.{classification}.tsv",
+    params:
+        coverages=get_nonempty_coverages,
+        coverage_lower_bounds=get_coverages,
+    log:
+        "logs/fp-fn/callsets/{callset}.{classification}.log",
+    conda:
+        "../envs/stats.yaml"
+    # This has to happen after precision/recall has been computed, otherwise we risk
+    # extremely high memory usage if a callset does not match the truth at all.
+    priority: 1
+    script:
+        "../scripts/collect-stratifications-fp-fn.py"
+
+
+# Concatenate the collected FP/FN tables of all callsets of a benchmark.
+rule collect_fp_fn_benchmark:
+    input:
+        tables=get_collect_fp_fn_benchmark_input,
+    output:
+        "results/fp-fn/benchmarks/{benchmark}.{classification}.tsv",
+    params:
+        callsets=lambda w: get_benchmark_callsets(w.benchmark),
+    log:
+        "logs/fp-fn/benchmarks/{benchmark}.{classification}.log",
+    conda:
+        "../envs/stats.yaml"
+    script:
+        "../scripts/collect-fp-fn-benchmarks.py"
+
+
 rule report_fp_fn:
     input:
         main_dataset="results/fp-fn/genomes/{genome}/{cov}/{classification}/main.tsv",
@@ -381,11 +420,14 @@ rule report_fp_fn:
         config=workflow.source_path("../resources/datavzrd/fp-fn-config.yte.yaml"),
     output:
         report(
-            directory("results/report/fp-fn/{genome}/{cov}/{classification}"),
+            directory("results/report/fp-fn/genomes/{genome}/{cov}/{classification}"),
             htmlindex="index.html",
-            category="{classification} variants",
+            category="{classification} variants per genome",
             subcategory=lambda w: w.genome,
-            labels=lambda w: {"coverage": w.cov},
+            labels=lambda w: {
+                "coverage": w.cov,
+                "genome": w.genome,
+            },
         ),
     log:
         "logs/datavzrd/fp-fn/{genome}/{cov}/{classification}.log",
@@ -394,3 +436,33 @@ rule report_fp_fn:
         version=get_genome_version,
     wrapper:
         "v5.0.1/utils/datavzrd"
+
+
+# Render the collected FP/FN table of a callset as a datavzrd report.
+rule report_fp_fn_callset:
+    input:
+        table="results/fp-fn/callsets/{callset}.{classification}.tsv",
+        config=workflow.source_path(
+            "../resources/datavzrd/fp-fn-per-callset-config.yte.yaml"
+        ),
+    output:
+        report(
+            directory("results/report/fp-fn/callsets/{callset}/{classification}"),
+            htmlindex="index.html",
+            category="{classification} variants per benchmark",
+            subcategory=lambda w: config["variant-calls"][w.callset]["benchmark"],
+            labels=lambda w: {
+                "callset": w.callset,
+            },
+        ),
+    log:
+        "logs/datavzrd/fp-fn/{callset}/{classification}.log",
+    params:
+        labels=lambda w: get_callsets_labels(
+            get_benchmark_callsets(config["variant-calls"][w.callset]["benchmark"])
+        ),
+        genome=get_genome_name,
+        version=get_genome_version,
+        somatic=get_somatic_status,
+    wrapper:
+        "v5.0.1/utils/datavzrd"
diff --git a/workflow/scripts/collect-fp-fn-benchmarks.py b/workflow/scripts/collect-fp-fn-benchmarks.py
new file mode 100644
index 0000000..97d370b
--- /dev/null
+++ b/workflow/scripts/collect-fp-fn-benchmarks.py
@@ -0,0 +1,41 @@
+import sys
+
+sys.stderr = open(snakemake.log[0], "w")
+
+import pandas as pd
+
+
+def load_data(path, callset):
+    """Read a tab-separated table and prepend the callset name as column."""
+    d = pd.read_csv(path, sep="\t")
+    d.insert(0, "callset", callset)
+    return d
+
+
+def cov_key(cov_label):
+    """Return the lower bound of a coverage label as integer for sorting."""
+    if ".." in cov_label:
+        return int(cov_label.split("..")[0])
+    # label without upper bound: skip the single leading character
+    return int(cov_label[1:])
+
+
+def sort_key(col):
+    """Sort the coverage column by lower bound, any other column by value."""
+    if col.name == "coverage":
+        return col.apply(cov_key)
+    return col
+
+
+results = pd.concat(
+    [
+        load_data(f, callset)
+        for f, callset in zip(snakemake.input.tables, snakemake.params.callsets)
+    ],
+    axis="rows",
+)
+
+results.sort_values(["callset", "coverage"], inplace=True, key=sort_key)
+results["sort_index"] = results["coverage"].apply(cov_key)
+
+results.to_csv(snakemake.output[0], sep="\t", index=False)
diff --git a/workflow/scripts/collect-stratifications-fp-fn.py b/workflow/scripts/collect-stratifications-fp-fn.py
new file mode 100644
index 0000000..e188f6c
--- /dev/null
+++ b/workflow/scripts/collect-stratifications-fp-fn.py
@@ -0,0 +1,68 @@
+import sys
+
+sys.stderr = open(snakemake.log[0], "w")
+
+import pandas as pd
+
+# Columns of the header-only table written when there is no input to collect.
+# NOTE(review): "chromosome position" is a single column name here; check
+# whether the per-coverage tables use separate "chromosome" and "position"
+# columns instead.
+EMPTY_REPORT_COLUMNS = [
+    "coverage",
+    "class",
+    "chromosome position",
+    "ref_allele",
+    "alt_allele",
+    "true_genotype",
+    "predicted_genotype",
+]
+
+
+def get_cov_label(coverage):
+    """Return the label of the coverage range of the given coverage category.
+
+    The range spans from the lower bound of the category to the next larger
+    lower bound among all categories ("lower..upper"). If there is no larger
+    bound, the range is open-ended ("≥lower").
+    """
+    lower = snakemake.params.coverage_lower_bounds[coverage]
+    bounds = [
+        bound
+        for bound in snakemake.params.coverage_lower_bounds.values()
+        if bound > lower
+    ]
+    if bounds:
+        upper = min(bounds)
+        return f"{lower}..{upper}"
+    else:
+        return f"≥{lower}"
+
+
+def load_data(f, coverage):
+    """Read a tab-separated table and prepend the coverage label as column."""
+    d = pd.read_csv(f, sep="\t")
+    d.insert(0, "coverage", get_cov_label(coverage))
+    return d
+
+
+if snakemake.input:
+    report = pd.concat(
+        load_data(f, cov) for cov, f in zip(snakemake.params.coverages, snakemake.input)
+    )
+
+    # TODO With separate files for SNVs and indels with e.g. STRELKA no predicted variants for the other type are expected
+    # If later relevant, add annotation to the report
+    # if (report["tp_truth"] == 0).all():
+    #     raise ValueError(
+    #         f"The callset {snakemake.wildcards.callset} does not predict any variant from the truth. "
+    #         "This is likely a technical issue in the callset and should be checked before further evaluation."
+    #     )
+
+    report.to_csv(snakemake.output[0], sep="\t", index=False)
+else:
+    # Write a header-only table. The index is omitted as in the non-empty case
+    # so that both cases yield the same column layout.
+    pd.DataFrame(columns=EMPTY_REPORT_COLUMNS).to_csv(
+        snakemake.output[0], sep="\t", index=False
+    )