@@ -271,26 +271,33 @@ def remap_pedigree_hash(remap_path: str, pedigree_path: str) -> hl.Int32Expressi
271
271
return hl .int32 (int (sha256 .hexdigest ()[:8 ], 16 ))
272
272
273
273
274
- def checkpoint (t : hl .Table | hl .MatrixTable ) -> tuple [hl .Table | hl .MatrixTable , str ]:
274
+ def checkpoint (
275
+ t : hl .Table | hl .MatrixTable ,
276
+ repartition_factor : int = 1 ,
277
+ ) -> tuple [hl .Table | hl .MatrixTable , str ]:
275
278
suffix = 'mt' if isinstance (t , hl .MatrixTable ) else 'ht'
276
279
read_fn = hl .read_matrix_table if isinstance (t , hl .MatrixTable ) else hl .read_table
277
280
checkpoint_path = os .path .join (
278
281
Env .HAIL_TMP_DIR ,
279
282
f'{ uuid .uuid4 ()} .{ suffix } ' ,
280
283
)
281
- t .write (checkpoint_path )
284
+ t .write (checkpoint_path , repartition_factor = repartition_factor )
282
285
return read_fn (checkpoint_path ), checkpoint_path
283
286
284
287
285
288
def write (
286
289
t : hl .Table | hl .MatrixTable ,
287
290
destination_path : str ,
288
291
repartition : bool = True ,
292
+ # May be used to increase the number of partitions beyond
293
+ # the optimally computed number. A higher number will
294
+ # shard the table into more partitions.
295
+ repartition_factor : int = 1 ,
289
296
) -> hl .Table | hl .MatrixTable :
290
297
t , path = checkpoint (t )
291
298
if repartition :
292
299
t = t .repartition (
293
- compute_hail_n_partitions (file_size_bytes (path )),
300
+ ( compute_hail_n_partitions (file_size_bytes (path )) * repartition_factor ),
294
301
shuffle = False ,
295
302
)
296
303
return t .write (destination_path , overwrite = True )
0 commit comments