Skip to content

Commit fd1b744

Browse files
authored
Partial Aggregation Set handling, refactor, and testing (#114)
1 parent 3adc0f1 commit fd1b744

32 files changed

+2994
-1110
lines changed

ImputationPipeline/AggregatePRSResults.wdl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ workflow AggregatePRSResults {
1010
String population_name = "Reference Population"
1111
File expected_control_results
1212
String lab_batch
13+
Int group_n
1314
}
1415

1516
call AggregateResults {
@@ -57,8 +58,10 @@ task AggregateResults {
5758
Array[File] results
5859
Array[File] missing_sites_shifts
5960
String lab_batch
61+
Int group_n
6062
}
6163

64+
String output_prefix = lab_batch + if group_n > 1 then "_group_" + group_n else ""
6265
command <<<
6366
Rscript - <<- "EOF"
6467
library(dplyr)
@@ -87,9 +90,9 @@ task AggregateResults {
8790
stop(paste0("There are ", num_control_samples, " control samples in the input tables, however, only 1 is expected."))
8891
}
8992
90-
write_tsv(results, paste0(lab_batch, "_all_results.tsv"))
93+
write_tsv(results, "~{output_prefix}_all_results.tsv")
9194
92-
write_tsv(results %>% filter(is_control_sample), paste0(lab_batch, "_control_results.tsv"))
95+
write_tsv(results %>% filter(is_control_sample), "~{output_prefix}_control_results.tsv")
9396
9497
results_pivoted <- results %>% filter(!is_control_sample) %>% pivot_longer(!c(sample_id, lab_batch, is_control_sample), names_to=c("condition",".value"), names_pattern="([^_]+)_(.+)")
9598
results_pivoted <- results_pivoted %T>% {options(warn=-1)} %>% mutate(adjusted = as.numeric(adjusted),
@@ -104,20 +107,20 @@ task AggregateResults {
104107
num_not_high = sum(risk=="NOT_HIGH", na.rm=TRUE),
105108
num_not_resulted = sum(risk=="NOT_RESULTED", na.rm = TRUE))
106109
107-
write_tsv(results_summarised, paste0(lab_batch, "_summarised_results.tsv"))
110+
write_tsv(results_summarised, "~{output_prefix}_summarised_results.tsv")
108111
109112
ggplot(results_pivoted, aes(x=adjusted)) +
110113
geom_density(aes(color=condition), fill=NA, position = "identity") +
111114
xlim(-5,5) + theme_bw() + xlab("z-score") + geom_function(fun=dnorm) +
112115
ylab("density")
113-
ggsave(filename = paste0(lab_batch, "_score_distribution.png"), dpi=300, width = 6, height = 6)
116+
ggsave(filename = "~{output_prefix}_score_distribution.png", dpi=300, width = 6, height = 6)
114117
115-
write_tsv(results_pivoted, paste0(lab_batch, "_pivoted_results.tsv"))
118+
write_tsv(results_pivoted, "~{output_prefix}_pivoted_results.tsv")
116119
117120
writeLines(lab_batch, "lab_batch.txt")
118121
119122
missing_sites_shifts <- c("~{sep='","' missing_sites_shifts}") %>% map(read_tsv) %>% reduce(bind_rows)
120-
write_tsv(missing_sites_shifts, paste0(lab_batch, "_missing_sites_shifts.tsv"))
123+
write_tsv(missing_sites_shifts, "~{output_prefix}_missing_sites_shifts.tsv")
121124
122125
EOF
123126
>>>
@@ -129,12 +132,12 @@ task AggregateResults {
129132
}
130133
131134
output {
132-
File batch_all_results = "~{lab_batch}_all_results.tsv"
133-
File batch_control_results = "~{lab_batch}_control_results.tsv"
134-
File batch_summarised_results = "~{lab_batch}_summarised_results.tsv"
135-
File batch_pivoted_results = "~{lab_batch}_pivoted_results.tsv"
136-
File batch_score_distribution = "~{lab_batch}_score_distribution.png"
137-
File batch_missing_sites_shifts = "~{lab_batch}_missing_sites_shifts.tsv"
135+
File batch_all_results = "~{output_prefix}_all_results.tsv"
136+
File batch_control_results = "~{output_prefix}_control_results.tsv"
137+
File batch_summarised_results = "~{output_prefix}_summarised_results.tsv"
138+
File batch_pivoted_results = "~{output_prefix}_pivoted_results.tsv"
139+
File batch_score_distribution = "~{output_prefix}_score_distribution.png"
140+
File batch_missing_sites_shifts = "~{output_prefix}_missing_sites_shifts.tsv"
138141
}
139142
}
140143
@@ -143,9 +146,11 @@ task PlotPCA {
143146
Array[File] target_pc_projections
144147
File population_pc_projections
145148
String lab_batch
149+
Int group_n
146150
String population_name
147151
}
148152
153+
String output_prefix = lab_batch + if group_n > 1 then "_group_" + group_n else ""
149154
command <<<
150155
Rscript - <<- "EOF"
151156
library(dplyr)
@@ -161,7 +166,7 @@ task PlotPCA {
161166
geom_point(data=target_pcs, aes(color="~{lab_batch}")) +
162167
theme_bw()
163168
164-
ggsave(filename = "~{lab_batch}_PCA_plot.png", dpi=300, width = 6, height = 6)
169+
ggsave(filename = "~{output_prefix}_PCA_plot.png", dpi=300, width = 6, height = 6)
165170
166171
EOF
167172
@@ -174,7 +179,7 @@ task PlotPCA {
174179
}
175180
176181
output {
177-
File pc_plot = "~{lab_batch}_PCA_plot.png"
182+
File pc_plot = "~{output_prefix}_PCA_plot.png"
178183
}
179184
}
180185
@@ -190,14 +195,17 @@ task BuildHTMLReport {
190195
File population_pc_projections
191196
String population_name
192197
String lab_batch
198+
Int group_n
193199
}
194200
201+
String output_prefix = lab_batch + if group_n > 1 then "_group_" + group_n else ""
202+
String title_batch = lab_batch + if group_n > 1 then " (group " + group_n + ")" else ""
195203
command <<<
196204
set -xeo pipefail
197205
198-
cat << EOF > ~{lab_batch}_report.Rmd
206+
cat << EOF > ~{output_prefix}_report.Rmd
199207
---
200-
title: "Batch ~{lab_batch} PRS Summary"
208+
title: "Batch ~{title_batch} PRS Summary"
201209
output:
202210
html_document:
203211
df_print: paged
@@ -386,7 +394,7 @@ task BuildHTMLReport {
386394
\`\`\`
387395
EOF
388396
389-
Rscript -e "library(rmarkdown); rmarkdown::render('~{lab_batch}_report.Rmd', 'html_document')"
397+
Rscript -e "library(rmarkdown); rmarkdown::render('~{output_prefix}_report.Rmd', 'html_document')"
390398
>>>
391399
392400
runtime {
@@ -396,6 +404,6 @@ task BuildHTMLReport {
396404
}
397405
398406
output {
399-
File report = "~{lab_batch}_report.html"
407+
File report = "~{output_prefix}_report.html"
400408
}
401409
}
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import firecloud.api as fapi
2+
import argparse
3+
from datetime import datetime
4+
import pytz
5+
from dataclasses import dataclass, field
6+
from io import StringIO
7+
8+
9+
@dataclass
class AggregationSet:
    """A lab_batch's aggregation set within a Terra workspace.

    Attributes:
        lab_batch: Name of the lab batch the set belongs to.
        group: 1-based group number within the lab batch; group 1 is the
            original (unsuffixed) set.
        delivered: Whether this set's results have already been delivered.
        contains_control: Whether the set already includes a control sample.
        set_id: Derived identifier; computed in __post_init__, never passed in.
    """
    lab_batch: str
    group: int = 1
    delivered: bool = False
    contains_control: bool = False
    set_id: str = field(init=False)

    def __post_init__(self):
        # Reject nonsensical group numbers up front.
        if self.group < 1:
            raise RuntimeError(
                f'Group of aggregation set for lab_batch {self.lab_batch} is {self.group}, should not be less than 1')
        # Group 1 keeps the bare lab_batch name; later groups get a suffix.
        suffix = f'_group_{self.group}' if self.group > 1 else ''
        self.set_id = self.lab_batch + suffix
25+
26+
27+
def pre_existing_aggregation_set(lab_batch, group, delivered):
    """Build an AggregationSet representing a set already present in the workspace.

    Pre-existing sets always contain their control sample, so
    contains_control is fixed to True.
    """
    return AggregationSet(lab_batch, group, delivered, contains_control=True)
29+
30+
31+
def next_aggregation_set(agg_set):
    """Return a fresh, undelivered AggregationSet for the group after agg_set's."""
    return AggregationSet(agg_set.lab_batch, agg_set.group + 1)
34+
35+
36+
class GroupBuilder:
    """Groups ungrouped (or reworked) samples in a Terra workspace into
    per-lab_batch aggregation sets.

    Talks to the Terra/FireCloud API (via ``fapi``); network errors surface
    as RuntimeError carrying the API response text.
    """

    def __init__(self, workspace_namespace, workspace_name):
        """Discover the entity tables available in the workspace.

        Raises:
            RuntimeError: if the list-entity-types API call fails.
        """
        self.workspace_namespace = workspace_namespace
        self.workspace_name = workspace_name

        print('Finding tables to group by lab_batch')
        entity_types_response = fapi.list_entity_types(self.workspace_namespace, self.workspace_name)
        if not entity_types_response.ok:
            raise RuntimeError(f'ERROR: {entity_types_response.text}')
        self.entity_types_dict = entity_types_response.json()
        self.available_tables = self.entity_types_dict.keys()

    def build_groups(self):
        """Run grouping for every table that has the columns grouping needs."""
        for table_name, description in self.entity_types_dict.items():
            # Only tables recording both control status and lab batch can be grouped.
            if all(x in description['attributeNames'] for x in ['is_control_sample', 'lab_batch']):
                self.group_samples_into_batches(table_name)

    def group_samples_into_batches(self, table_name):
        """Assign new/reworked samples of one table into aggregation sets.

        Reads the existing ``{table_name}_set`` table (if any) to learn which
        samples are already aggregated and what each lab_batch's latest set
        is, then builds membership and rework-reset TSVs in memory and
        uploads them. Sets lacking a control sample are deferred.

        Raises:
            RuntimeError: on any failed API call, on duplicate control
                samples within a lab_batch, or if an undelivered set has a
                later set after it.
        """
        samples_already_in_aggregation_sets = set()  # samples already in any aggregation set
        lab_batch_sample_sets_dict = dict()  # lab_batch -> highest-group aggregation set seen

        if f'{table_name}_set' in self.available_tables:
            # Download the current set table to learn existing memberships.
            print(f'Downloading {table_name}_set table...')
            sample_set_response = fapi.get_entities(self.workspace_namespace, self.workspace_name, f'{table_name}_set')
            if not sample_set_response.ok:
                raise RuntimeError(f'ERROR: {sample_set_response.text}')
            sample_sets_dict = sample_set_response.json()

            for sample_set in sample_sets_dict:
                samples = [e['entityName'] for e in sample_set['attributes'][f'{table_name}s']['items']]
                samples_already_in_aggregation_sets.update(samples)

            for sample_set in sample_sets_dict:
                attributes = sample_set['attributes']
                lab_batch = attributes['lab_batch']
                # Pre-existing sets always contain their control sample.
                this_aggregation_set = AggregationSet(lab_batch, attributes['group'], attributes['delivered'], True)
                if lab_batch in lab_batch_sample_sets_dict:
                    if lab_batch_sample_sets_dict[lab_batch].group < this_aggregation_set.group:
                        # A later group existing implies the earlier one was delivered.
                        if not lab_batch_sample_sets_dict[lab_batch].delivered:
                            raise RuntimeError(
                                f'Aggregation set {lab_batch_sample_sets_dict[lab_batch].set_id}'
                                f' has not been delivered, '
                                f'but later set {this_aggregation_set.set_id} also exists')
                        lab_batch_sample_sets_dict[lab_batch] = this_aggregation_set
                else:
                    lab_batch_sample_sets_dict[lab_batch] = this_aggregation_set

        # Read samples from samples table
        print(f'Reading {table_name} table...')
        sample_response = fapi.get_entities(self.workspace_namespace, self.workspace_name, f'{table_name}')
        if not sample_response.ok:
            raise RuntimeError(f'ERROR: {sample_response.text}')

        samples = sample_response.json()
        # Build the new membership TSV in memory.
        added_sample_sets_dict = dict()  # lab_batch -> aggregation set receiving new samples
        control_samples_dict = dict()  # lab_batch -> sample id of its control sample
        added_samples_dict = dict()  # set_id -> list of samples to add to that set
        with StringIO() as new_membership_io, \
                StringIO() as samples_updated_io:
            # Headers for the membership and rework-reset TSVs.
            new_membership_io.write(f'membership:{table_name}_set_id\t{table_name}\n')
            samples_updated_io.write(f'entity:{table_name}_id\trework\n')
            for sample in samples:
                if 'lab_batch' not in sample['attributes']:
                    continue  # sample not yet assigned to a lab batch; skip
                sample_name = sample['name']
                lab_batch = sample['attributes']['lab_batch']
                is_control_sample = sample['attributes']['is_control_sample']
                rework = sample['attributes'].get('rework', False)
                if is_control_sample:
                    # More than one control per lab batch is an error.
                    if lab_batch in control_samples_dict:
                        raise RuntimeError(
                            f'Multiple control samples for lab_batch {lab_batch}: {sample_name}, '
                            f'{control_samples_dict[lab_batch]}')
                    control_samples_dict[lab_batch] = sample_name
                    # Aggregation sets are not created for the control alone;
                    # controls are appended to newly built sets later.
                    continue
                if rework or sample_name not in samples_already_in_aggregation_sets:
                    # This (non-control) sample needs an aggregation set.
                    # Find (or create) the set it should join.
                    if lab_batch not in added_sample_sets_dict:
                        if lab_batch in lab_batch_sample_sets_dict:
                            previous_aggregation_set = lab_batch_sample_sets_dict[lab_batch]
                            if previous_aggregation_set.delivered:
                                # Delivered sets are immutable; start the next group.
                                added_sample_sets_dict[lab_batch] = next_aggregation_set(previous_aggregation_set)
                            else:
                                # Undelivered set can still accept samples.
                                added_sample_sets_dict[lab_batch] = previous_aggregation_set
                        else:
                            # First aggregation set for this lab batch.
                            added_sample_sets_dict[lab_batch] = AggregationSet(lab_batch)
                    set_id = added_sample_sets_dict[lab_batch].set_id
                    added_samples_dict.setdefault(set_id, []).append(sample_name)
            # Write membership rows only for sets that have a control sample.
            lab_batches_without_controls = list()
            for lab_batch, agg_set in added_sample_sets_dict.items():
                set_id = agg_set.set_id
                if agg_set.contains_control:
                    # Set already contains its control; just add the new samples.
                    for sample in added_samples_dict[set_id]:
                        new_membership_io.write(f'{set_id}\t{sample}\n')
                        samples_updated_io.write(f'{sample}\tfalse\n')
                elif lab_batch in control_samples_dict:
                    # Found a control sample, so the set can be created;
                    # add the control to the set alongside the new samples.
                    added_samples_dict[set_id].append(control_samples_dict[lab_batch])
                    for sample in added_samples_dict[set_id]:
                        new_membership_io.write(f'{set_id}\t{sample}\n')
                        samples_updated_io.write(f'{sample}\tfalse\n')
                else:
                    # No control sample yet; defer aggregating this set.
                    del added_samples_dict[set_id]
                    lab_batches_without_controls.append(lab_batch)
            for lab_batch in lab_batches_without_controls:
                del added_sample_sets_dict[lab_batch]
            if len(added_samples_dict) == 0:
                print(f'No new {table_name}_sets to be added.')
            else:
                if f'{table_name}_set' not in self.available_tables:
                    print(f'Creating new table {table_name}_set')
                    # A set table must exist before membership rows can be uploaded.
                    with StringIO() as new_sample_sets_io:
                        new_sample_sets_io.write(f'entity:{table_name}_set_id\n')
                        for set_id in added_samples_dict:
                            new_sample_sets_io.write(f'{set_id}\n')
                        upload_new_table_response = fapi.upload_entities_tsv(self.workspace_namespace,
                                                                             self.workspace_name,
                                                                             new_sample_sets_io,
                                                                             "flexible")
                        if not upload_new_table_response.ok:
                            raise RuntimeError(f'ERROR: {upload_new_table_response.text}')
                print(f'Uploading new {table_name}_set table... ')
                upload_response = fapi.upload_entities_tsv(self.workspace_namespace, self.workspace_name,
                                                           new_membership_io,
                                                           "flexible")
                if not upload_response.ok:
                    raise RuntimeError(f'ERROR: {upload_response.text}')
                # Stamp creation metadata onto each newly touched set.
                print(f'Adding date and time to newly created {table_name}_sets...')

                now = str(datetime.now(pytz.timezone('US/Eastern')))
                for i, this_aggregation_set in enumerate(added_sample_sets_dict.values()):
                    update_response = fapi.update_entity(self.workspace_namespace, self.workspace_name,
                                                         f'{table_name}_set', this_aggregation_set.set_id,
                                                         [{"op": "AddUpdateAttribute",
                                                           "attributeName": "time_sample_set_updated",
                                                           "addUpdateAttribute": now},
                                                          {"op": "AddUpdateAttribute", "attributeName": "delivered",
                                                           "addUpdateAttribute": False},
                                                          {"op": "AddUpdateAttribute", "attributeName": "redeliver",
                                                           "addUpdateAttribute": False},
                                                          {"op": "AddUpdateAttribute", "attributeName": "group",
                                                           "addUpdateAttribute": this_aggregation_set.group},
                                                          {"op": "AddUpdateAttribute", "attributeName": "lab_batch",
                                                           "addUpdateAttribute": this_aggregation_set.lab_batch}
                                                          ])
                    if not update_response.ok:
                        raise RuntimeError(f'ERROR: {update_response.text}')
                    print(f' Completed {i + 1}/{len(added_samples_dict)}')

                # Clear the rework flag on every sample just (re)aggregated.
                print(f'Updating rework field in {table_name} table')
                upload_sample_rework_response = fapi.upload_entities_tsv(self.workspace_namespace, self.workspace_name,
                                                                         samples_updated_io,
                                                                         "flexible")
                if not upload_sample_rework_response.ok:
                    raise RuntimeError(f'ERROR: {upload_sample_rework_response.text}')
                print('SUCCESS')
                print(f'Printing update {table_name}_set_membership.tsv:')
                print(new_membership_io.getvalue())
218+
219+
220+
def run(workspace_namespace, workspace_name):
    """Build aggregation sets for every eligible table in the given workspace."""
    GroupBuilder(workspace_namespace, workspace_name).build_groups()
223+
224+
225+
if __name__ == '__main__':
    # Command-line entry point: both workspace coordinates are required.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--workspace_namespace", dest="workspace_namespace", required=True)
    arg_parser.add_argument("--workspace_name", dest="workspace_name", required=True)
    parsed_args = arg_parser.parse_args()
    run(parsed_args.workspace_namespace, parsed_args.workspace_name)

ImputationPipeline/CreateAggregationSets/__init__.py

Whitespace-only changes.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[build-system]
2+
requires = ["setuptools>=61.0", "setuptools_scm[toml]>=6.2"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[tool.setuptools_scm]
6+
write_to = "ImputationPipeline/CreateAggregationSets/_version.py"
7+
root = "../../"
8+
9+
[project]
10+
name = "CreateAggregationSets"
11+
dynamic = ["version"]
12+
dependencies = ['firecloud >= 0.16.33', 'pytz >= 2022.2.1']
13+
requires-python = ">=3.7"

0 commit comments

Comments
 (0)