
Commit f67d509

Refactor SEG-Y workers to open files instead of passing SegyFile from main process for safer multiprocessing. (#575)
1 parent: c7461e1

3 files changed: +36 -18 lines
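The core issue: an open SegyFile holds OS-level state (an fsspec filesystem, buffers, and possibly network connections to cloud stores) that does not transfer safely to child processes, which is why the old code forced the 'spawn' start method. Passing only picklable constructor arguments and opening the file inside each worker sidesteps that entirely. Below is a minimal, runnable sketch of the same pattern using a plain local file; `scan_chunk`, `example.sgy`, and the byte ranges are hypothetical stand-ins, not MDIO code.

```python
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat


def scan_chunk(file_kw: dict, byte_range: tuple[int, int]) -> bytes:
    # Each worker opens its own handle from plain, picklable arguments;
    # no OS-level state (locks, buffers, sockets) crosses the process boundary.
    with open(**file_kw) as fh:
        fh.seek(byte_range[0])
        return fh.read(byte_range[1] - byte_range[0])


if __name__ == "__main__":
    file_kw = {"file": "example.sgy", "mode": "rb"}  # picklable constructor args
    ranges = [(0, 3600), (3600, 7200)]  # hypothetical byte ranges
    with ProcessPoolExecutor(max_workers=2) as executor:
        chunks = list(executor.map(scan_chunk, repeat(file_kw), ranges))
```

The same shape appears in all three files below: build a small dict of picklable arguments in the parent, `repeat()` it across tasks, and reconstruct the heavy object inside the worker.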

src/mdio/segy/_workers.py (22 additions, 5 deletions)

```diff
@@ -5,31 +5,47 @@
 import os
 from typing import TYPE_CHECKING
 from typing import Any
+from typing import TypedDict
 from typing import cast
 
 import numpy as np
+from segy import SegyFile
 
 if TYPE_CHECKING:
-    from segy import SegyFile
     from segy.arrays import HeaderArray
+    from segy.config import SegySettings
+    from segy.schema import SegySpec
     from zarr import Array
 
     from mdio.core import Grid
 
 
-def header_scan_worker(segy_file: SegyFile, trace_range: tuple[int, int]) -> HeaderArray:
+class SegyFileArguments(TypedDict):
+    """Arguments needed to open a SegyFile instance."""
+
+    url: str
+    spec: SegySpec | None
+    settings: SegySettings | None
+
+
+def header_scan_worker(
+    segy_kw: SegyFileArguments,
+    trace_range: tuple[int, int],
+) -> HeaderArray:
     """Header scan worker.
 
     If SegyFile is not open, it can either accept a path string or a handle that was opened in
     a different context manager.
 
     Args:
-        segy_file: SegyFile instance.
+        segy_kw: Arguments to open a SegyFile instance.
         trace_range: Tuple consisting of the trace ranges to read.
 
     Returns:
         HeaderArray parsed from SEG-Y library.
     """
+    segy_file = SegyFile(**segy_kw)
+
     slice_ = slice(*trace_range)
 
     cloud_native_mode = os.getenv("MDIO__IMPORT__CLOUD_NATIVE", default="False")
@@ -52,7 +68,7 @@ def header_scan_worker(segy_file: SegyFile, trace_range: tuple[int, int]) -> Hea
 
 
 def trace_worker(
-    segy_file: SegyFile,
+    segy_kw: SegyFileArguments,
     data_array: Array,
     metadata_array: Array,
     grid: Grid,
@@ -68,7 +84,7 @@ def trace_worker(
     slices across the sample dimension since SEG-Y data isn't chunked, eliminating concern.
 
     Args:
-        segy_file: SegyFile instance.
+        segy_kw: Arguments to open a SegyFile instance.
         data_array: Handle for zarr.Array we are writing traces to
         metadata_array: Handle for zarr.Array we are writing trace headers
         grid: mdio.Grid instance
@@ -78,6 +94,7 @@ def trace_worker(
         Partial statistics for chunk, or None
     """
     # Special case where there are no traces inside chunk.
+    segy_file = SegyFile(**segy_kw)
    live_subset = grid.live_mask[chunk_indices[:-1]]
 
    if np.count_nonzero(live_subset) == 0:
```
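As a usage sketch, the new worker signature would be driven like this. The URL is hypothetical, and passing `spec=None` / `settings=None` assumes the segy library falls back to its defaults when they are not provided:

```python
from mdio.segy._workers import header_scan_worker

# Hypothetical local file; None values assume the segy library's defaults.
segy_kw = {"url": "file:///tmp/example.sgy", "spec": None, "settings": None}
headers = header_scan_worker(segy_kw, trace_range=(0, 1024))
```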

src/mdio/segy/blocked_io.py (7 additions, 7 deletions)

```diff
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import multiprocessing as mp
 import os
 from concurrent.futures import ProcessPoolExecutor
 from itertools import repeat
@@ -48,22 +47,23 @@ def to_zarr(segy_file: SegyFile, grid: Grid, data_array: Array, header_array: Ar
     chunker = ChunkIterator(data_array, chunk_samples=False)
     num_chunks = len(chunker)
 
-    # For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
-    # 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
     num_cpus = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
     num_workers = min(num_chunks, num_cpus)
-    context = mp.get_context("spawn")
-    executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
 
     # Chunksize here is for multiprocessing, not Zarr chunksize.
     pool_chunksize, extra = divmod(num_chunks, num_workers * 4)
     pool_chunksize += 1 if extra else pool_chunksize
 
+    segy_kw = {
+        "url": segy_file.fs.unstrip_protocol(segy_file.url),
+        "spec": segy_file.spec,
+        "settings": segy_file.settings,
+    }
     tqdm_kw = {"unit": "block", "dynamic_ncols": True}
-    with executor:
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
         lazy_work = executor.map(
             trace_worker,  # fn
-            repeat(segy_file),
+            repeat(segy_kw),
             repeat(data_array),
             repeat(header_array),
             repeat(grid),
```
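The `unstrip_protocol` call rebuilds a fully qualified URL from the filesystem object plus the protocol-stripped path that fsspec stores internally, so a child process can reopen the same object from scratch. A small sketch of the behavior, assuming fsspec is installed; the paths are hypothetical:

```python
import fsspec

fs = fsspec.filesystem("file")
fs.unstrip_protocol("/tmp/example.sgy")  # -> "file:///tmp/example.sgy"

fs = fsspec.filesystem("memory")
fs.unstrip_protocol("bucket/key.sgy")    # -> "memory://bucket/key.sgy"
```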

src/mdio/segy/parsers.py (7 additions, 6 deletions)

```diff
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import multiprocessing as mp
 import os
 from concurrent.futures import ProcessPoolExecutor
 from itertools import repeat
@@ -48,15 +47,17 @@ def parse_index_headers(
 
         trace_ranges.append((start, stop))
 
-    # For Unix async reads with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
-    # 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
     num_cpus = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
     num_workers = min(n_blocks, num_cpus)
-    context = mp.get_context("spawn")
 
+    segy_kw = {
+        "url": segy_file.fs.unstrip_protocol(segy_file.url),
+        "spec": segy_file.spec,
+        "settings": segy_file.settings,
+    }
     tqdm_kw = {"unit": "block", "dynamic_ncols": True}
-    with ProcessPoolExecutor(num_workers, mp_context=context) as executor:
-        lazy_work = executor.map(header_scan_worker, repeat(segy_file), trace_ranges)
+    with ProcessPoolExecutor(num_workers) as executor:
+        lazy_work = executor.map(header_scan_worker, repeat(segy_kw), trace_ranges)
 
     if progress_bar is True:
         lazy_work = tqdm(
```
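Both call sites share the `executor.map(fn, repeat(shared), per_task)` shape: `executor.map` yields results in submission order, so wrapping the lazy iterator in tqdm reports progress as each block completes. A self-contained sketch of that shape, with `count_traces` as a hypothetical stand-in for `header_scan_worker`:

```python
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

from tqdm.auto import tqdm


def count_traces(shared_kw: dict, trace_range: tuple[int, int]) -> int:
    # shared_kw is identical for every task; trace_range varies per block.
    return trace_range[1] - trace_range[0]


if __name__ == "__main__":
    trace_ranges = [(0, 1000), (1000, 2000), (2000, 2500)]
    with ProcessPoolExecutor(2) as executor:
        lazy_work = executor.map(count_traces, repeat({"url": "..."}), trace_ranges)
        totals = list(tqdm(lazy_work, total=len(trace_ranges), unit="block"))
    print(totals)  # [1000, 1000, 500], in submission order
```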
