Skip to content

Commit 9c01e66

Browse files
committed
checkpoint doesn't partition
1 parent 3dd56d6 commit 9c01e66

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

v03_pipeline/lib/misc/io.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,31 +273,26 @@ def remap_pedigree_hash(remap_path: str, pedigree_path: str) -> hl.Int32Expressi
273273

274274
def checkpoint(
275275
t: hl.Table | hl.MatrixTable,
276-
repartition_factor: int = 1,
277276
) -> tuple[hl.Table | hl.MatrixTable, str]:
278277
suffix = 'mt' if isinstance(t, hl.MatrixTable) else 'ht'
279278
read_fn = hl.read_matrix_table if isinstance(t, hl.MatrixTable) else hl.read_table
280279
checkpoint_path = os.path.join(
281280
Env.HAIL_TMP_DIR,
282281
f'{uuid.uuid4()}.{suffix}',
283282
)
284-
t.write(checkpoint_path, repartition_factor=repartition_factor)
283+
t.write(checkpoint_path)
285284
return read_fn(checkpoint_path), checkpoint_path
286285

287286

288287
def write(
289288
t: hl.Table | hl.MatrixTable,
290289
destination_path: str,
291290
repartition: bool = True,
292-
# May be used to increase the number of partitions beyond
293-
# the optimally computed number. A higher number will
294-
# shard the table into more partitions.
295-
repartition_factor: int = 1,
296291
) -> hl.Table | hl.MatrixTable:
297292
t, path = checkpoint(t)
298293
if repartition:
299294
t = t.repartition(
300-
(compute_hail_n_partitions(file_size_bytes(path)) * repartition_factor),
295+
compute_hail_n_partitions(file_size_bytes(path)),
301296
shuffle=False,
302297
)
303298
return t.write(destination_path, overwrite=True)

v03_pipeline/lib/reference_datasets/splice_ai.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import hail as hl
22

3-
from v03_pipeline.lib.misc.io import checkpoint
3+
from v03_pipeline.lib.misc.io import (
4+
checkpoint,
5+
compute_hail_n_partitions,
6+
file_size_bytes,
7+
)
48
from v03_pipeline.lib.model import ReferenceGenome
59
from v03_pipeline.lib.reference_datasets.misc import vcf_to_ht
610

@@ -13,7 +17,12 @@ def get_ht(
1317
# of file descriptors on dataproc :/
1418
hl._set_flags(use_new_shuffle=None, no_whole_stage_codegen='1') # noqa: SLF001
1519
ht = vcf_to_ht(paths, reference_genome)
16-
ht, _ = checkpoint(ht, repartition_factor=2)
20+
ht, checkpoint_path = checkpoint(ht)
21+
# The default partitions are too big, leading to OOMs.
22+
ht = ht.repartition(
23+
compute_hail_n_partitions(file_size_bytes(checkpoint_path)),
24+
shuffle=False,
25+
)
1726

1827
# SpliceAI INFO field description from the VCF header: SpliceAIv1.3 variant annotation. These include
1928
# delta scores (DS) and delta positions (DP) for acceptor gain (AG), acceptor loss (AL), donor gain (DG), and

0 commit comments

Comments (0)