@@ -65,7 +65,7 @@ def __init__(
         split: str = "",
         train_dataset: Optional[Dataset] = None,
         chromosomes: List[str] = None,
-        phenotype_file: str = None,
+        phenotype_file: Optional[str] = None,
         standardize_xpheno: bool = True,
         standardize_anno: bool = False,
         standardize_rare_anno: bool = False,
@@ -106,13 +106,14 @@ def __init__(
         zarr_dir: Optional[str] = None,
         cache_matrices: bool = False,
         verbose: bool = False,
+        return_genotypes: bool = True,
     ):
         if verbose:
             logger.setLevel(logging.DEBUG)
         else:
             logger.setLevel(logging.INFO)
 
-        self.check_samples = True  # TODO undo
+        self.check_samples = False  # NOTE: Set to True for debugging
         self.split = split
         self.train_dataset = train_dataset
         self.chromosomes = (
@@ -134,13 +135,10 @@ def __init__(
             f"Using phenotypes: x: {self.x_phenotypes}, " f"y: {self.y_phenotypes}"
         )
 
-        if gt_file is None:
-            raise ValueError("gt_file must be specified")
-        self.gt_filename = gt_file
-        if variant_file is None:
-            raise ValueError("variant_file must be specified")
         if phenotype_file is None:
             raise ValueError("phenotype_file must be specified")
+        self.gt_filename = gt_file
+        self.return_genotypes = return_genotypes
         self.variant_filename = variant_file
         self.variant_matrix = None
         self.genotype_matrix = None
@@ -154,9 +152,6 @@ def __init__(
                 self.variant_matrix = f["variant_matrix"][:]
                 self.genotype_matrix = f["genotype_matrix"][:]
 
-        logger.info(
-            f"Using phenotype file {phenotype_file} and genotype file {self.gt_filename}"
-        )
         self.setup_phenotypes(
             phenotype_file, sim_phenotype_file, skip_y_na, skip_x_na, sample_file
         )
@@ -204,45 +199,54 @@ def __init__(
         else:
             self.variants_to_keep = variants_to_keep
 
-        self.setup_annotations(
-            annotation_file, annotation_aggregation, precomputed_annotations
-        )
+        if self.return_genotypes:
+            self.setup_annotations(
+                annotation_file, annotation_aggregation, precomputed_annotations
+            )
 
         self.transform_data()
-        self.setup_variants(min_common_variant_count, min_common_af, variants)
 
-        self.get_variant_metadata(grouping_level)
+        if self.return_genotypes:
+            self.setup_variants(min_common_variant_count, min_common_af, variants)
+
+            self.get_variant_metadata(grouping_level)
 
-        if rare_embedding is not None:
+        if rare_embedding is not None and self.return_genotypes:
             self.rare_embedding = getattr(rare_embedders, rare_embedding["type"])(
                 self, **rare_embedding["config"]
             )
-
         else:
             self.rare_embedding = None
 
     def __getitem__(self, idx: int) -> torch.tensor:
-        if self.variant_matrix is None:
-            gt_file = h5py.File(self.gt_filename, "r")
-            self.variant_matrix = gt_file["variant_matrix"]
-            self.genotype_matrix = gt_file["genotype_matrix"]
-            if self.cache_matrices:
-                self.variant_matrix = self.variant_matrix[:]
-                self.genotype_matrix = self.genotype_matrix[:]
-
-        # idx_pheno = self.index_map_pheno[idx]  # samples and phenotypes are already subset, so idx can be used directly
-        idx_geno = self.index_map_geno[idx]
-        sparse_variants = self.variant_matrix[idx_geno, :]
-        sparse_genotype = self.genotype_matrix[idx_geno, :]
-        (
-            common_variants,
-            all_sparse_variants,
-            sparse_genotype,
-        ) = self.get_common_variants(sparse_variants, sparse_genotype)
-
-        rare_variant_annotations = self.get_rare_variants(
-            idx, all_sparse_variants, sparse_genotype
-        )
+        if self.return_genotypes:
+            if self.variant_matrix is None or self.genotype_matrix is None:
+                gt_file = h5py.File(self.gt_filename, "r")
+                self.variant_matrix = gt_file["variant_matrix"]
+                self.genotype_matrix = gt_file["genotype_matrix"]
+                if self.cache_matrices:
+                    self.variant_matrix = self.variant_matrix[:]
+                    self.genotype_matrix = self.genotype_matrix[:]
+
+            idx_geno = self.index_map_geno[idx]
+            if self.check_samples:
+                # sanity check, can be removed in future
+                assert self.samples_gt[idx_geno] == self.samples[idx]
+
+            sparse_variants = self.variant_matrix[idx_geno, :]
+            sparse_genotype = self.genotype_matrix[idx_geno, :]
+            (
+                common_variants,
+                all_sparse_variants,
+                sparse_genotype,
+            ) = self.get_common_variants(sparse_variants, sparse_genotype)
+
+            rare_variant_annotations = self.get_rare_variants(
+                idx, all_sparse_variants, sparse_genotype
+            )
+        else:
+            common_variants = torch.tensor([], dtype=torch.float)
+            rare_variant_annotations = torch.tensor([], dtype=torch.float)
 
         phenotypes = self.phenotype_df.iloc[
             idx, :
@@ -255,9 +259,7 @@ def __getitem__(self, idx: int) -> torch.tensor:
         y = torch.tensor(
             phenotypes[self.y_phenotypes].to_numpy(dtype=np.float32), dtype=torch.float
         )
-        if self.check_samples:
-            # sanity check, can be removed in future
-            assert self.samples_gt[idx_geno] == self.samples[idx]
+
         return {
             "sample": self.samples[idx],
             "x_phenotypes": x_phenotype_tensor,
@@ -287,11 +289,6 @@ def setup_phenotypes(
     ):
         logger.debug("Reading phenotype dataframe")
         self.phenotype_df = pd.read_parquet(phenotype_file, engine="pyarrow")
-        with h5py.File(self.gt_filename, "r") as f:
-            samples_gt = f["samples"][:]
-            samples_gt = np.array([item.decode("utf-8") for item in samples_gt])
-            if self.check_samples:
-                self.samples_gt = samples_gt
         samples_phenotype_df = np.array(self.phenotype_df.index)
         # phenotypes_df has first to be sorted in the same order as samples_gt
         if sim_phenotype_file is not None:
@@ -315,14 +312,21 @@ def setup_phenotypes(
                 logger.warning(
                     "Some samples from the sample file were not found in the data"
                 )
-            sample_to_keep = shared_samples
+            samples_to_keep = shared_samples
             logger.info(
                 f"Number of samples in sample file and in phenotype_df: {len(samples_to_keep)}"
             )
         else:
             logger.info("Using all samples in phenotype df")
             samples_to_keep = copy.deepcopy(samples_phenotype_df)
 
+        # if self.return_genotypes:
+        with h5py.File(self.gt_filename, "r") as f:
+            samples_gt = f["samples"][:]
+            samples_gt = np.array([item.decode("utf-8") for item in samples_gt])
+            if self.check_samples:
+                self.samples_gt = samples_gt
+
         logger.info("Removing samples that are not in genotype file")
 
         samples_to_keep = np.array(
@@ -353,11 +357,11 @@ def setup_phenotypes(
             mask_cols += self.x_phenotypes
         mask = (self.phenotype_df[mask_cols].notna()).all(axis=1)
         mask &= samples_to_keep_mask
-        samples_to_keep = self.phenotype_df.index[mask]
+        self.samples = self.phenotype_df.index[mask]
         self.n_samples = mask.sum()
         logger.info(f"Final number of kept samples: {self.n_samples}")
+
         self.phenotype_df = self.phenotype_df[mask]
-        self.samples = self.phenotype_df.index.to_numpy()
 
         # account for the fact that genotypes.h5 and phenotype_df can have different
         # orders of their samples
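
A minimal usage sketch of the new `return_genotypes` option. The class name `DenseGTDataset` and the file paths below are assumptions for illustration; the diff itself only shows that `phenotype_file` is required, that the genotype HDF5 file is still opened in `setup_phenotypes` for the sample list, and that `__getitem__` returns empty tensors for the genotype-derived fields when `return_genotypes=False`.

```python
# Hypothetical usage sketch — class name and paths are assumptions, not taken from the diff.
dataset = DenseGTDataset(
    phenotype_file="phenotypes.parquet",  # required: ValueError if None
    gt_file="genotypes.h5",               # still read in setup_phenotypes to align samples
    return_genotypes=False,               # skip annotation/variant setup and genotype loading
)

item = dataset[0]
print(item["sample"], item["x_phenotypes"])
# With return_genotypes=False, the common-variant and rare-variant-annotation
# entries come back as empty float tensors instead of being read from genotypes.h5.
```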