
Commit f0a2e25

Authored by bfclarke, ThibaultBechtler, endast, and PMBio
Feature/chunked writer precomputed burdens testing combo (#135)
* make genotype and variant files optional
* specify burden file manually
* add pipelines
* add eval pipeline
* small bug fix removing overwriting of function option name
* add argument for additional REGENIE options
* keep beta in results
* add conditional analysis
* run REGENIE step 2 in long queue, later moved to verylong queue
* improve modularity
* remove .loco files from output/input
* add tests and test data for merging
* fix test path
* remove unused variable
* bug fixes (including removal of broken debugging code)
* remove duplicate rule
* update pipelines for new burden directory structure
* bug fixes for REGENIE pipelines
* add option for determinism in training
* add test for results of training_association_testing pipeline
* corrections to pipelines
* delete old example config
* bug fixes for new config file naming
* improve error message when repo is misspecified
* correct data key
* add command to compare association testing results
* cast trial ID to int to fix intermittent bug where it's stored as float
* expose deterministic option (for testing/debugging)
* remove deeprvat_repo_dir
* add chunked writer
* remove deeprvat from paths
* fix dtype
* add merge to snakemake pipeline
* remove x.zarr and y.zarr
* working test for 5 chunks
* add back sample_ids test
* fix pretrained path
* remove phenotypes key
* fix failing tests
* fix further config errors
* remove unfinished sections
* reduce number of phenotypes in example configs
* add example config for REGENIE
* adapt CV pipeline
* reduce uninformative logging
* modifications for running on compute cluster
* fix typos
* delete unused files
* code cleanup
* remove unneeded test files
* format Python code with psf/black

Co-authored-by: Brian Clarke <brian.clarke@dkfz.de>
Co-authored-by: Thibault <th.bechtler@gmail.com>
Co-authored-by: Magnus Wahlberg <endast@gmail.com>
Co-authored-by: PMBio <PMBio@users.noreply.github.com>
1 parent 801fc32 commit f0a2e25

39 files changed: +2068 −622 lines

deeprvat/cv_utils.py

Lines changed: 69 additions & 33 deletions
@@ -96,80 +96,117 @@ def generate_test_config(input_config, out_file, fold, n_folds):
 @cli.command()
-@click.option("--link-burdens", type=click.Path())
+@click.option("--skip-burdens", is_flag=True)
 @click.option("--burden-dirs", "-b", multiple=True)
-@click.argument("out_dir", type=click.Path(), default="./")
+@click.option("--xy-dirs", "-b", multiple=True)
+@click.argument("out_dir_burdens", type=click.Path(), default="./")
+@click.argument("out_dir_xy", type=click.Path(), default="./")
 @click.argument("config_file", type=click.Path(exists=True))
 def combine_test_set_burdens(
-    out_dir,
-    link_burdens,
+    out_dir_burdens,
+    out_dir_xy,
+    skip_burdens,
     burden_dirs,
+    xy_dirs,
     config_file,
 ):
+    assert len(burden_dirs) == len(xy_dirs)
+
     with open(config_file) as f:
         config = yaml.safe_load(f)
     compression_level = 1
-    skip_burdens = link_burdens is not None
     n_total_samples = []
-    for burden_dir in burden_dirs:
-        print(burden_dir)
-        this_y = zarr.open(f"{burden_dir}/y.zarr")
-        this_x = zarr.open(f"{burden_dir}/x.zarr")
+    for xy_dir, burden_dir in zip(xy_dirs, burden_dirs):
+        logger.debug(xy_dir)
+        this_y = zarr.open(f"{xy_dir}/y.zarr")
+        this_x = zarr.open(f"{xy_dir}/x.zarr")
+        this_sample_ids_xy = zarr.load(f"{xy_dir}/sample_ids.zarr")
         # this_burdens = zarr.open(f'{burden_dir}/burdens.zarr')
 
-        assert this_y.shape[0] == this_x.shape[0]  # == this_burdens.shape[0]
+        assert this_y.shape[0] == this_x.shape[0]
         n_total_samples.append(this_y.shape[0])
 
+        if not skip_burdens:
+            this_burdens = zarr.open(f"{burden_dir}/burdens.zarr")
+            this_sample_ids_burdens = zarr.load(f"{burden_dir}/sample_ids.zarr")
+            assert this_y.shape[0] == this_burdens.shape[0]
+            logger.debug(this_sample_ids_xy, this_sample_ids_burdens)
+            assert np.array_equal(this_sample_ids_xy, this_sample_ids_burdens)
+
     n_total_samples = np.sum(n_total_samples)
-    print(f"Total number of samples {n_total_samples}")
+    logger.info(f"Total number of samples: {n_total_samples}")
     if not skip_burdens:
-        this_burdens = zarr.open(
-            f"{burden_dir}/burdens.zarr"
-        )  # any burden tensor (here from the last file to get dims 1 -n)
         burdens = zarr.open(
-            Path(out_dir) / "burdens.zarr",
+            Path(out_dir_burdens) / "burdens.zarr",
             mode="a",
             shape=(n_total_samples,) + this_burdens.shape[1:],
             chunks=(1000, 1000),
             dtype=np.float32,
             compressor=Blosc(clevel=compression_level),
         )
-        print(f"burdens shape: {burdens.shape}")
-    else:
-        burdens = None
+        logger.info(f"burdens shape: {burdens.shape}")
+        sample_ids_burdens = zarr.open(
+            Path(out_dir_burdens) / "sample_ids.zarr",
+            mode="a",
+            shape=(n_total_samples),
+            chunks=(None),
+            dtype="U200",
+            compressor=Blosc(clevel=compression_level),
+        )
 
     y = zarr.open(
-        Path(out_dir) / "y.zarr",
+        Path(out_dir_xy) / "y.zarr",
         mode="a",
         shape=(n_total_samples,) + this_y.shape[1:],
         chunks=(None, None),
         dtype=np.float32,
         compressor=Blosc(clevel=compression_level),
     )
     x = zarr.open(
-        Path(out_dir) / "x.zarr",
+        Path(out_dir_xy) / "x.zarr",
         mode="a",
         shape=(n_total_samples,) + this_x.shape[1:],
         chunks=(None, None),
         dtype=np.float32,
         compressor=Blosc(clevel=compression_level),
     )
+    sample_ids_xy = zarr.open(
+        Path(out_dir_xy) / "sample_ids.zarr",
+        mode="a",
+        shape=(n_total_samples),
+        chunks=(None),
+        dtype="U200",
+        compressor=Blosc(clevel=compression_level),
+    )
 
     start_idx = 0
 
-    for burden_dir in burden_dirs:
-        this_y = zarr.open(f"{burden_dir}/y.zarr")[:]
+    for xy_dir, burden_dir in zip(xy_dirs, burden_dirs):
+        this_y = zarr.load(f"{xy_dir}/y.zarr")
         end_idx = start_idx + this_y.shape[0]
-        this_x = zarr.open(f"{burden_dir}/x.zarr")[:]
-        if not skip_burdens:
-            logger.info("writing burdens")
-            this_burdens = zarr.open(f"{burden_dir}/burdens.zarr")[:]
-            burdens[start_idx:end_idx] = this_burdens
         print((start_idx, end_idx))
+        this_x = zarr.load(f"{xy_dir}/x.zarr")
+        this_sample_ids_xy = zarr.load(f"{xy_dir}/sample_ids.zarr")
         y[start_idx:end_idx] = this_y
         x[start_idx:end_idx] = this_x
+        sample_ids_xy[start_idx:end_idx] = this_sample_ids_xy
+        if not skip_burdens:
+            logger.info("writing burdens")
+            this_burdens = zarr.load(f"{burden_dir}/burdens.zarr")
+            burdens[start_idx:end_idx] = this_burdens
+            this_sample_ids_burdens = zarr.load(f"{burden_dir}/sample_ids.zarr")
+            sample_ids_burdens[start_idx:end_idx] = this_sample_ids_burdens
         start_idx = end_idx
 
+    # sanity check
+    if not skip_burdens and not np.array_equal(sample_ids_xy[:], sample_ids_burdens[:]):
+        logger.error(
+            "sample_ids_xy, sample_ids_burdens do not match:\n"
+            + f"sample_ids_xy: {sample_ids_xy[:]}"
+            + f"sample_ids_burdens: {sample_ids_burdens[:]}"
+        )
+        raise RuntimeError("sample_ids_xy, sample_ids_burdens do not match")
+
     y_transformation = config["association_testing_data"]["dataset_config"].get(
         "y_transformation", None
     )
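The hunk above is the chunked-writer pattern at the heart of this PR: pre-allocate one on-disk zarr array sized for all samples, then copy each chunk into its slice. A minimal, self-contained sketch of the same idea, with illustrative array names, shapes, and output path rather than the real pipeline files:

import numpy as np
import zarr
from numcodecs import Blosc

# Three per-chunk arrays standing in for precomputed burden chunks.
chunks = [np.random.rand(100, 20).astype(np.float32) for _ in range(3)]
n_total = sum(c.shape[0] for c in chunks)

# Pre-allocate the combined on-disk array once.
merged = zarr.open(
    "merged_burdens.zarr",  # hypothetical output path
    mode="a",
    shape=(n_total,) + chunks[0].shape[1:],
    chunks=(1000, 1000),
    dtype=np.float32,
    compressor=Blosc(clevel=1),
)

# Copy each chunk into its slice; only one chunk is in memory at a time.
start = 0
for c in chunks:
    end = start + c.shape[0]
    merged[start:end] = c
    start = end

Writing slice-by-slice keeps peak memory bounded by the largest chunk, which is what makes merging precomputed burdens for large cohorts feasible.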
@@ -202,13 +239,12 @@ def combine_test_set_burdens(
             for col in range(this_y.shape[1]):
                 this_y[:, col] = my_quantile_transform(this_y[:, col])
     y[:] = this_y
+
+    if not skip_burdens:
+        genes = np.load(f"{burden_dirs[0]}/genes.npy")
+        np.save(Path(out_dir_burdens) / "genes.npy", genes)
+
     print("done")
-    if link_burdens is not None:
-        source_path = Path(out_dir) / "burdens.zarr"
-        source_path.unlink(missing_ok=True)
-        source_path.symlink_to(link_burdens)
-    genes = np.load(f"{burden_dirs[0]}/genes.npy")
-    np.save(Path(out_dir) / "genes.npy", genes)
 
 
 if __name__ == "__main__":
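The command is registered on the module's click group, so it can be exercised from Python with click's test runner. A hedged usage sketch with hypothetical paths; note that --burden-dirs and --xy-dirs both declare the short flag -b in this diff, so the long names are used here:

from click.testing import CliRunner

from deeprvat.cv_utils import combine_test_set_burdens

runner = CliRunner()
result = runner.invoke(
    combine_test_set_burdens,
    [
        "--burden-dirs", "chunk0/burdens",  # hypothetical chunk directories,
        "--burden-dirs", "chunk1/burdens",  # zipped with xy-dirs in order
        "--xy-dirs", "chunk0/xy",
        "--xy-dirs", "chunk1/xy",
        "out/burdens",  # out_dir_burdens
        "out/xy",       # out_dir_xy
        "config.yaml",  # config_file (must exist: click.Path(exists=True))
    ],
)
print(result.exit_code, result.output)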

deeprvat/data/dense_gt.py

Lines changed: 53 additions & 49 deletions
@@ -65,7 +65,7 @@ def __init__(
         split: str = "",
         train_dataset: Optional[Dataset] = None,
         chromosomes: List[str] = None,
-        phenotype_file: str = None,
+        phenotype_file: Optional[str] = None,
         standardize_xpheno: bool = True,
         standardize_anno: bool = False,
         standardize_rare_anno: bool = False,
@@ -106,13 +106,14 @@ def __init__(
         zarr_dir: Optional[str] = None,
         cache_matrices: bool = False,
         verbose: bool = False,
+        return_genotypes: bool = True,
     ):
         if verbose:
             logger.setLevel(logging.DEBUG)
         else:
             logger.setLevel(logging.INFO)
 
-        self.check_samples = True  # TODO undo
+        self.check_samples = False  # NOTE: Set to True for debugging
         self.split = split
         self.train_dataset = train_dataset
         self.chromosomes = (
@@ -134,13 +135,10 @@ def __init__(
             f"Using phenotypes: x: {self.x_phenotypes}, " f"y: {self.y_phenotypes}"
         )
 
-        if gt_file is None:
-            raise ValueError("gt_file must be specified")
-        self.gt_filename = gt_file
-        if variant_file is None:
-            raise ValueError("variant_file must be specified")
         if phenotype_file is None:
             raise ValueError("phenotype_file must be specified")
+        self.gt_filename = gt_file
+        self.return_genotypes = return_genotypes
         self.variant_filename = variant_file
         self.variant_matrix = None
         self.genotype_matrix = None
@@ -154,9 +152,6 @@ def __init__(
                 self.variant_matrix = f["variant_matrix"][:]
                 self.genotype_matrix = f["genotype_matrix"][:]
 
-        logger.info(
-            f"Using phenotype file {phenotype_file} and genotype file {self.gt_filename}"
-        )
         self.setup_phenotypes(
             phenotype_file, sim_phenotype_file, skip_y_na, skip_x_na, sample_file
         )
@@ -204,45 +199,54 @@ def __init__(
         else:
             self.variants_to_keep = variants_to_keep
 
-        self.setup_annotations(
-            annotation_file, annotation_aggregation, precomputed_annotations
-        )
+        if self.return_genotypes:
+            self.setup_annotations(
+                annotation_file, annotation_aggregation, precomputed_annotations
+            )
 
         self.transform_data()
-        self.setup_variants(min_common_variant_count, min_common_af, variants)
 
-        self.get_variant_metadata(grouping_level)
+        if self.return_genotypes:
+            self.setup_variants(min_common_variant_count, min_common_af, variants)
+
+            self.get_variant_metadata(grouping_level)
 
-        if rare_embedding is not None:
+        if rare_embedding is not None and self.return_genotypes:
             self.rare_embedding = getattr(rare_embedders, rare_embedding["type"])(
                 self, **rare_embedding["config"]
             )
-
         else:
             self.rare_embedding = None
 
     def __getitem__(self, idx: int) -> torch.tensor:
-        if self.variant_matrix is None:
-            gt_file = h5py.File(self.gt_filename, "r")
-            self.variant_matrix = gt_file["variant_matrix"]
-            self.genotype_matrix = gt_file["genotype_matrix"]
-            if self.cache_matrices:
-                self.variant_matrix = self.variant_matrix[:]
-                self.genotype_matrix = self.genotype_matrix[:]
-
-        # idx_pheno = self.index_map_pheno[idx] #samples and phenotype is already subset so can use idx
-        idx_geno = self.index_map_geno[idx]
-        sparse_variants = self.variant_matrix[idx_geno, :]
-        sparse_genotype = self.genotype_matrix[idx_geno, :]
-        (
-            common_variants,
-            all_sparse_variants,
-            sparse_genotype,
-        ) = self.get_common_variants(sparse_variants, sparse_genotype)
-
-        rare_variant_annotations = self.get_rare_variants(
-            idx, all_sparse_variants, sparse_genotype
-        )
+        if self.return_genotypes:
+            if self.variant_matrix is None or self.genotype_matrix is None:
+                gt_file = h5py.File(self.gt_filename, "r")
+                self.variant_matrix = gt_file["variant_matrix"]
+                self.genotype_matrix = gt_file["genotype_matrix"]
+                if self.cache_matrices:
+                    self.variant_matrix = self.variant_matrix[:]
+                    self.genotype_matrix = self.genotype_matrix[:]
+
+            idx_geno = self.index_map_geno[idx]
+            if self.check_samples:
+                # sanity check, can be removed in future
+                assert self.samples_gt[idx_geno] == self.samples[idx]
+
+            sparse_variants = self.variant_matrix[idx_geno, :]
+            sparse_genotype = self.genotype_matrix[idx_geno, :]
+            (
+                common_variants,
+                all_sparse_variants,
+                sparse_genotype,
+            ) = self.get_common_variants(sparse_variants, sparse_genotype)
+
+            rare_variant_annotations = self.get_rare_variants(
+                idx, all_sparse_variants, sparse_genotype
+            )
+        else:
+            common_variants = torch.tensor([], dtype=torch.float)
+            rare_variant_annotations = torch.tensor([], dtype=torch.float)
 
         phenotypes = self.phenotype_df.iloc[
             idx, :
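The rewritten __getitem__ opens the genotype HDF5 lazily on first access and only materializes it in memory when cache_matrices is set. Lazy opening is the usual way to keep a Dataset usable with multi-worker DataLoaders, since open h5py handles generally cannot be pickled across worker forks. A stripped-down sketch of the pattern, assuming an illustrative file layout:

import h5py
import numpy as np
from torch.utils.data import Dataset

class LazyH5Dataset(Dataset):
    """Sketch: open the HDF5 file on first __getitem__, not in __init__."""

    def __init__(self, h5_path: str, cache: bool = False):
        self.h5_path = h5_path
        self.cache = cache
        self.matrix = None  # opened lazily below

    def __len__(self):
        with h5py.File(self.h5_path, "r") as f:
            return f["genotype_matrix"].shape[0]  # illustrative dataset name

    def __getitem__(self, idx):
        if self.matrix is None:
            f = h5py.File(self.h5_path, "r")  # handle stays open for reads
            self.matrix = f["genotype_matrix"]
            if self.cache:
                # Materialize once; subsequent reads are in-memory slices.
                self.matrix = self.matrix[:]
        return np.asarray(self.matrix[idx, :])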
@@ -255,9 +259,7 @@ def __getitem__(self, idx: int) -> torch.tensor:
         y = torch.tensor(
             phenotypes[self.y_phenotypes].to_numpy(dtype=np.float32), dtype=torch.float
         )
-        if self.check_samples:
-            # sanity check, can be removed in future
-            assert self.samples_gt[idx_geno] == self.samples[idx]
+
         return {
             "sample": self.samples[idx],
             "x_phenotypes": x_phenotype_tensor,
@@ -287,11 +289,6 @@ def setup_phenotypes(
     ):
         logger.debug("Reading phenotype dataframe")
         self.phenotype_df = pd.read_parquet(phenotype_file, engine="pyarrow")
-        with h5py.File(self.gt_filename, "r") as f:
-            samples_gt = f["samples"][:]
-            samples_gt = np.array([item.decode("utf-8") for item in samples_gt])
-        if self.check_samples:
-            self.samples_gt = samples_gt
         samples_phenotype_df = np.array(self.phenotype_df.index)
         # phenotypes_df has first to be sorted in the same order as samples_gt
         if sim_phenotype_file is not None:
@@ -315,14 +312,21 @@ def setup_phenotypes(
                 logger.warning(
                     "Some samples from the sample file were not found in the data"
                 )
-            sample_to_keep = shared_samples
+            samples_to_keep = shared_samples
             logger.info(
                 f"Number of samples in sample file and in phenotype_df: {len(samples_to_keep)}"
             )
         else:
             logger.info("Using all samples in phenotype df")
             samples_to_keep = copy.deepcopy(samples_phenotype_df)
 
+        # if self.return_genotypes:
+        with h5py.File(self.gt_filename, "r") as f:
+            samples_gt = f["samples"][:]
+            samples_gt = np.array([item.decode("utf-8") for item in samples_gt])
+        if self.check_samples:
+            self.samples_gt = samples_gt
+
         logger.info("Removing samples that are not in genotype file")
 
         samples_to_keep = np.array(
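The added block reads sample IDs from the genotype HDF5 and intersects them with the phenotype index before building the mask. Reduced to its essentials, with toy data standing in for the real files:

import numpy as np
import pandas as pd

# Toy stand-ins for samples in genotypes.h5 and the phenotype parquet.
samples_gt = np.array(["s1", "s2", "s4"])
phenotype_df = pd.DataFrame(
    {"height": [1.70, 1.65, 1.80, 1.75]},
    index=["s1", "s2", "s3", "s4"],
)

# Keep only phenotype samples that also have genotypes.
gt_set = set(samples_gt)
samples_to_keep = np.array([s for s in phenotype_df.index if s in gt_set])
mask = phenotype_df.index.isin(samples_to_keep)
phenotype_df = phenotype_df[mask]
print(phenotype_df.index.tolist())  # ['s1', 's2', 's4']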
@@ -353,11 +357,11 @@ def setup_phenotypes(
         mask_cols += self.x_phenotypes
         mask = (self.phenotype_df[mask_cols].notna()).all(axis=1)
         mask &= samples_to_keep_mask
-        samples_to_keep = self.phenotype_df.index[mask]
+        self.samples = self.phenotype_df.index[mask]
         self.n_samples = mask.sum()
         logger.info(f"Final number of kept samples: {self.n_samples}")
+
         self.phenotype_df = self.phenotype_df[mask]
-        self.samples = self.phenotype_df.index.to_numpy()
 
         # account for the fact that genotypes.h5 and phenotype_df can have different
         # orders of their samples
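The closing comment flags that genotypes.h5 and phenotype_df may order their samples differently. The usual resolution, and presumably what index_map_geno in __getitem__ encodes, is a positional lookup from each kept sample to its genotype row. A hedged sketch with toy data:

import numpy as np

samples_gt = np.array(["s2", "s1", "s4"])  # order in genotypes.h5
samples = np.array(["s1", "s2", "s4"])     # order after phenotype filtering

# Map each kept sample to its row in the genotype matrices.
gt_position = {s: i for i, s in enumerate(samples_gt)}
index_map_geno = np.array([gt_position[s] for s in samples])

assert index_map_geno[0] == 1  # samples[0] ("s1") lives at genotype row 1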
