Skip to content

Commit 7a8b751

Browse files
Chunked live mask (#527)
* Add minimal support for live mask configuration
* formatting
* Formatting
* Reduce possibility of integer overflow
* Automatic chunking of live_mask for grids that exceed blosc's maximum elements.
* Remove old test
* Formatting
* Resolve pre-commit issues
* Helps protect against integer overflow and logs when type promotion is performed.
* linting
* Use number of samples instead of live samples

Co-authored-by: Altay Sansal <tasansal@users.noreply.github.com>

* Explicitly use the live mask shape for base case
* Remove integer as a return type hint
* Rework chunking computation and expand test coverage
* Update to use Dask's chunk generation and ~500MiB chunk sizes
* Mock input arrays to increase speed and avoid pipeline OOM
* Cleanup, mocking, and relocation of autochunking
* Linting
* Clean tests up
* Use auto-chunking for live mask in factory.
* Fix import
* Update to handle live_mask not existing in Grid
* Use numpy empty instead of mocking class. Add test for Grid without live_mask.
* Linting
* Avoid bare except
* Mock the grid array info.
* simplify dtype determination logic
* use numpy instead of hand coding
* refactor auto chunking and optimize live mask and grid map creation
* consolidate tests and adjust for refactor
* only compare grid dims, because live mask is expected to be different
* move types to type checking block
* remove unnecessary comment
* remove comment repeating var name
* Use the pre-calculated constant
* Avoid magic numbers
* fix broken comment

---------

Co-authored-by: Altay Sansal <tasansal@users.noreply.github.com>
1 parent 8e4af5b commit 7a8b751

File tree

7 files changed

+216
-30
lines changed

7 files changed

+216
-30
lines changed

src/mdio/constants.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,26 @@
1212
FLOAT64_MIN = np.finfo("float64").min
1313
FLOAT64_MAX = np.finfo("float64").max
1414

15-
INT8_MIN = -0x80
16-
INT8_MAX = 0x7F
15+
INT8_MIN = np.iinfo("int8").min
16+
INT8_MAX = np.iinfo("int8").max
1717

18-
INT16_MIN = -0x8000
19-
INT16_MAX = 0x7FFF
18+
INT16_MIN = np.iinfo("int16").min
19+
INT16_MAX = np.iinfo("int16").max
2020

21-
INT32_MIN = -0x80000000
22-
INT32_MAX = 0x7FFFFFFF
21+
INT32_MIN = np.iinfo("int32").min
22+
INT32_MAX = np.iinfo("int32").max
2323

24-
UINT8_MIN = 0x0
25-
UINT8_MAX = 0xFF
24+
INT64_MIN = np.iinfo("int64").min
25+
INT64_MAX = np.iinfo("int64").max
2626

27-
UINT16_MIN = 0x0
28-
UINT16_MAX = 0xFFFF
27+
UINT8_MIN = 0
28+
UINT8_MAX = np.iinfo("uint8").max
2929

30-
UINT32_MIN = 0x0
31-
UINT32_MAX = 0xFFFFFFFF
30+
UINT16_MIN = 0
31+
UINT16_MAX = np.iinfo("uint16").max
32+
33+
UINT32_MIN = 0
34+
UINT32_MAX = np.iinfo("uint32").max
35+
36+
UINT64_MIN = 0
37+
UINT64_MAX = np.iinfo("uint64").max

src/mdio/core/factory.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from mdio import MDIOWriter
3535
from mdio.api.io_utils import process_url
3636
from mdio.core import Grid
37+
from mdio.core.utils_write import get_live_mask_chunksize
3738
from mdio.core.utils_write import write_attribute
3839
from mdio.segy.helpers_segy import create_zarr_hierarchy
3940

@@ -145,10 +146,12 @@ def create_empty(
145146
write_attribute(name="text_header", zarr_group=meta_group, attribute=DEFAULT_TEXT)
146147
write_attribute(name="binary_header", zarr_group=meta_group, attribute={})
147148

149+
live_shape = config.grid.shape[:-1]
150+
live_chunks = get_live_mask_chunksize(live_shape)
148151
meta_group.create_dataset(
149152
name="live_mask",
150-
shape=config.grid.shape[:-1],
151-
chunks=-1,
153+
shape=live_shape,
154+
chunks=live_chunks,
152155
dtype="bool",
153156
dimension_separator="/",
154157
)

src/mdio/core/grid.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,20 @@
44

55
import inspect
66
from dataclasses import dataclass
7+
from typing import TYPE_CHECKING
78

89
import numpy as np
910
import zarr
1011

1112
from mdio.constants import UINT32_MAX
1213
from mdio.core import Dimension
1314
from mdio.core.serialization import Serializer
15+
from mdio.core.utils_write import get_constrained_chunksize
16+
17+
18+
if TYPE_CHECKING:
19+
from segy.arrays import HeaderArray
20+
from zarr import Array as ZarrArray
1421

1522

1623
@dataclass
@@ -25,10 +32,14 @@ class Grid:
2532
2633
Args:
2734
dims: List of dimension instances.
28-
2935
"""
3036

3137
dims: list[Dimension]
38+
map: ZarrArray | None = None
39+
live_mask: ZarrArray | None = None
40+
41+
_TARGET_MEMORY_PER_BATCH = 1 * 1024**3 # 1GB target for batch process map
42+
_INTERNAL_CHUNK_SIZE_TARGET = 10 * 1024**2 # 10MB target for internal chunks
3243

3344
def __post_init__(self):
3445
"""Initialize convenience properties."""
@@ -77,23 +88,51 @@ def from_zarr(cls, zarr_root: zarr.Group):
7788

7889
return cls(dims_list)
7990

80-
def build_map(self, index_headers):
91+
def build_map(self, index_headers: HeaderArray) -> None:
8192
"""Build a map for live traces based on `index_headers`.
8293
8394
Args:
8495
index_headers: Headers to be normalized (indexed)
8596
"""
86-
live_dim_indices = tuple()
87-
for dim in self.dims[:-1]:
88-
dim_hdr = index_headers[dim.name]
89-
live_dim_indices += (np.searchsorted(dim, dim_hdr),)
90-
91-
# We set dead traces to uint32 max. Should be far away from actual trace counts.
92-
self.map = zarr.full(self.shape[:-1], dtype="uint32", fill_value=UINT32_MAX)
93-
self.map.vindex[live_dim_indices] = range(len(live_dim_indices[0]))
94-
95-
self.live_mask = zarr.zeros(self.shape[:-1], dtype="bool")
96-
self.live_mask.vindex[live_dim_indices] = 1
97+
# Determine data type for the map based on grid size
98+
grid_size = np.prod(self.shape[:-1])
99+
map_dtype = "uint64" if grid_size > UINT32_MAX else "uint32"
100+
fill_value = np.iinfo(map_dtype).max
101+
102+
# Initialize Zarr arrays for the map and live mask
103+
live_shape = self.shape[:-1]
104+
chunks = get_constrained_chunksize(
105+
shape=live_shape,
106+
dtype=map_dtype,
107+
max_bytes=self._INTERNAL_CHUNK_SIZE_TARGET,
108+
)
109+
# Temporary zarrs for ingestion.
110+
self.map = zarr.full(live_shape, fill_value, dtype=map_dtype, chunks=chunks)
111+
self.live_mask = zarr.zeros(live_shape, dtype="bool", chunks=chunks)
112+
113+
# Calculate batch size for processing
114+
memory_per_trace_index = index_headers.itemsize
115+
batch_size = int(self._TARGET_MEMORY_PER_BATCH / memory_per_trace_index)
116+
total_live_traces = index_headers.size
117+
118+
# Process live traces in batches
119+
for start in range(0, total_live_traces, batch_size):
120+
end = min(start + batch_size, total_live_traces)
121+
122+
# Compute indices for the current batch
123+
live_dim_indices = []
124+
for dim in self.dims[:-1]:
125+
dim_hdr = index_headers[dim.name][start:end]
126+
indices = np.searchsorted(dim, dim_hdr).astype(np.uint32)
127+
live_dim_indices.append(indices)
128+
live_dim_indices = tuple(live_dim_indices)
129+
130+
# Generate trace indices for the batch
131+
trace_indices = np.arange(start, end, dtype=np.uint64)
132+
133+
# Update Zarr arrays for the batch
134+
self.map.vindex[live_dim_indices] = trace_indices
135+
self.live_mask.vindex[live_dim_indices] = True
97136

98137

99138
class GridSerializer(Serializer):

src/mdio/core/utils_write.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
"""Convenience utilities for writing to Zarr."""
22

3+
from typing import TYPE_CHECKING
34
from typing import Any
45

5-
import zarr
6+
from dask.array.core import normalize_chunks
7+
from dask.array.rechunk import _balance_chunksizes
68

79

8-
def write_attribute(name: str, attribute: Any, zarr_group: zarr.Group) -> None:
10+
if TYPE_CHECKING:
11+
from numpy.typing import DTypeLike
12+
from zarr import Group
13+
14+
15+
MAX_SIZE_LIVE_MASK = 512 * 1024**2
16+
17+
18+
def write_attribute(name: str, attribute: Any, zarr_group: "Group") -> None:
919
"""Write a mappable to Zarr array or group attribute.
1020
1121
Args:
@@ -14,3 +24,34 @@ def write_attribute(name: str, attribute: Any, zarr_group: zarr.Group) -> None:
1424
zarr_group: Output group or array.
1525
"""
1626
zarr_group.attrs[name] = attribute
27+
28+
29+
def get_constrained_chunksize(
30+
shape: tuple[int, ...],
31+
dtype: "DTypeLike",
32+
max_bytes: int,
33+
) -> tuple[int]:
34+
"""Calculate the optimal chunk size for N-D array based on max_bytes.
35+
36+
Args:
37+
shape: The shape of the array.
38+
dtype: The data dtype to be used in calculation.
39+
max_bytes: The maximum allowed number of bytes per chunk.
40+
41+
Returns:
42+
A sequence of integers of calculated chunk sizes.
43+
"""
44+
chunks = normalize_chunks("auto", shape, dtype=dtype, limit=max_bytes)
45+
return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)
46+
47+
48+
def get_live_mask_chunksize(shape: tuple[int, ...]) -> tuple[int]:
49+
"""Given a live_mask shape, calculate the optimal write chunk size.
50+
51+
Args:
52+
shape: The shape of the array.
53+
54+
Returns:
55+
A sequence of integers of calculated chunk sizes.
56+
"""
57+
return get_constrained_chunksize(shape, "bool", MAX_SIZE_LIVE_MASK)

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
"""Test configuration before everything runs."""
22

3+
import warnings
34
from os import path
45
from urllib.request import urlretrieve
56

67
import pytest
78

89

10+
# Suppress Dask's chunk balancing warning
11+
warnings.filterwarnings(
12+
"ignore",
13+
message="Could not balance chunks to be equal",
14+
category=UserWarning,
15+
module="dask.array.rechunk",
16+
)
17+
18+
919
@pytest.fixture(scope="session")
1020
def fake_segy_tmp(tmp_path_factory):
1121
"""Make a temp file for the fake SEG-Y files we are going to create."""

tests/unit/test_auto_chunking.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Test live mask chunk size calculation."""
2+
3+
from typing import TYPE_CHECKING
4+
5+
import numpy as np
6+
import pytest
7+
8+
from mdio.core.utils_write import MAX_SIZE_LIVE_MASK
9+
from mdio.core.utils_write import get_constrained_chunksize
10+
from mdio.core.utils_write import get_live_mask_chunksize
11+
12+
13+
if TYPE_CHECKING:
14+
from numpy.typing import DTypeLike
15+
16+
17+
@pytest.mark.parametrize(
18+
("shape", "dtype", "limit", "expected_chunks"),
19+
[
20+
((100,), "int8", 100, (100,)), # 1D full chunk
21+
((8, 6), "int8", 20, (4, 4)), # 2D adjusted int8
22+
((6, 8), "int16", 96, (6, 8)), # 2D small int16
23+
((9, 6, 4), "int8", 100, (5, 5, 4)), # 3D adjusted
24+
((4, 5), "int32", 4, (1, 1)), # test minimum edge case
25+
((10, 10), "int8", 1000, (10, 10)), # big limit
26+
((7, 5), "int8", 35, (7, 5)), # test full primes
27+
((7, 5), "int8", 23, (4, 4)), # test adjusted primes
28+
],
29+
)
30+
def test_auto_chunking(
31+
shape: tuple[int, ...],
32+
dtype: "DTypeLike",
33+
limit: int,
34+
expected_chunks: tuple[int, ...],
35+
) -> None:
36+
"""Test automatic chunking based on size limit and an array spec."""
37+
result = get_constrained_chunksize(shape, dtype, limit)
38+
assert result == expected_chunks
39+
40+
41+
class TestAutoChunkLiveMask:
42+
"""Test class for live mask auto chunking."""
43+
44+
@pytest.mark.parametrize(
45+
("shape", "expected_chunks"),
46+
[
47+
((100,), (100,)), # small 1d
48+
((100, 100), (100, 100)), # small 2d
49+
((50000, 50000), (25000, 25000)), # large 2d
50+
((1500, 1500, 1500), (750, 750, 750)), # large 3d
51+
((1000, 1000, 100, 36), (334, 334, 100, 36)), # large 4d
52+
],
53+
)
54+
def test_auto_chunk_live_mask(
55+
self,
56+
shape: tuple[int, ...],
57+
expected_chunks: tuple[int, ...],
58+
) -> None:
59+
"""Test auto chunked live mask is within expected number of bytes."""
60+
result = get_live_mask_chunksize(shape)
61+
assert result == expected_chunks
62+
63+
@pytest.mark.parametrize(
64+
"shape",
65+
[
66+
# Below are >500MiB. Smaller ones tested above
67+
(32768, 32768),
68+
(46341, 46341),
69+
(86341, 96341),
70+
(55000, 97500),
71+
(100000, 100000),
72+
(1024, 1024, 1024),
73+
(215, 215, 215, 215),
74+
(512, 216, 512, 400),
75+
(74, 74, 74, 74, 74),
76+
(512, 17, 43, 200, 50),
77+
],
78+
)
79+
def test_auto_chunk_live_mask_nbytes(self, shape: tuple[int, ...]) -> None:
80+
"""Test auto chunked live mask is within expected number of bytes."""
81+
result = get_live_mask_chunksize(shape)
82+
chunk_elements = np.prod(result)
83+
84+
# We want them to be 500MB +/- 25%
85+
assert chunk_elements > MAX_SIZE_LIVE_MASK * 0.75
86+
assert chunk_elements < MAX_SIZE_LIVE_MASK * 1.25

tests/unit/test_factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_create_empty_like(mock_reader: MDIOReader):
2020

2121
source_reader = mock_reader
2222
dest_reader = MDIOReader(dest_path)
23-
assert source_reader.grid == dest_reader.grid
23+
assert source_reader.grid.dims == dest_reader.grid.dims
24+
assert source_reader.live_mask != dest_reader.grid.live_mask
2425

2526
source_traces = source_reader._traces
2627
dest_traces = dest_reader._traces

0 commit comments

Comments
 (0)