diff --git a/src/mdio/constants.py b/src/mdio/constants.py
index bff76531..b6fff807 100644
--- a/src/mdio/constants.py
+++ b/src/mdio/constants.py
@@ -1,36 +1,59 @@
 """Constant values used across MDIO."""
 
-import numpy as np
+from numpy import finfo as np_finfo
+from numpy import iinfo as np_iinfo
+from numpy import nan as np_nan
 
-FLOAT16_MAX = np.finfo("float16").max
-FLOAT16_MIN = np.finfo("float16").min
+from mdio.schemas.dtype import ScalarType
 
-FLOAT32_MAX = np.finfo("float32").max
-FLOAT32_MIN = np.finfo("float32").min
+FLOAT16_MAX = np_finfo("float16").max
+FLOAT16_MIN = np_finfo("float16").min
 
-FLOAT64_MIN = np.finfo("float64").min
-FLOAT64_MAX = np.finfo("float64").max
+FLOAT32_MAX = np_finfo("float32").max
+FLOAT32_MIN = np_finfo("float32").min
 
-INT8_MIN = np.iinfo("int8").min
-INT8_MAX = np.iinfo("int8").max
+FLOAT64_MIN = np_finfo("float64").min
+FLOAT64_MAX = np_finfo("float64").max
 
-INT16_MIN = np.iinfo("int16").min
-INT16_MAX = np.iinfo("int16").max
+INT8_MIN = np_iinfo("int8").min
+INT8_MAX = np_iinfo("int8").max
 
-INT32_MIN = np.iinfo("int32").min
-INT32_MAX = np.iinfo("int32").max
+INT16_MIN = np_iinfo("int16").min
+INT16_MAX = np_iinfo("int16").max
 
-INT64_MIN = np.iinfo("int64").min
-INT64_MAX = np.iinfo("int64").max
+INT32_MIN = np_iinfo("int32").min
+INT32_MAX = np_iinfo("int32").max
+
+INT64_MIN = np_iinfo("int64").min
+INT64_MAX = np_iinfo("int64").max
 
 UINT8_MIN = 0
-UINT8_MAX = np.iinfo("uint8").max
+UINT8_MAX = np_iinfo("uint8").max
 
 UINT16_MIN = 0
-UINT16_MAX = np.iinfo("uint16").max
+UINT16_MAX = np_iinfo("uint16").max
 
 UINT32_MIN = 0
-UINT32_MAX = np.iinfo("uint32").max
+UINT32_MAX = np_iinfo("uint32").max
 
 UINT64_MIN = 0
-UINT64_MAX = np.iinfo("uint64").max
+UINT64_MAX = np_iinfo("uint64").max
+
+# Zarr fill values for different scalar types
+fill_value_map = {
+    ScalarType.BOOL: None,
+    ScalarType.FLOAT16: np_nan,
+    ScalarType.FLOAT32: np_nan,
+    ScalarType.FLOAT64: np_nan,
+    ScalarType.UINT8: 2**8 - 1,  # Max value for uint8
+    ScalarType.UINT16: 2**16 - 1,  # Max value for uint16
+    ScalarType.UINT32: 2**32 - 1,  # Max value for uint32
+    ScalarType.UINT64: 2**64 - 1,  # Max value for uint64
+    ScalarType.INT8: 2**7 - 1,  # Max value for int8
+    ScalarType.INT16: 2**15 - 1,  # Max value for int16
+    ScalarType.INT32: 2**31 - 1,  # Max value for int32
+    ScalarType.INT64: 2**63 - 1,  # Max value for int64
+    ScalarType.COMPLEX64: complex(np_nan, np_nan),
+    ScalarType.COMPLEX128: complex(np_nan, np_nan),
+    ScalarType.COMPLEX256: complex(np_nan, np_nan),
+}
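+
+# Illustrative: fill_value_map[ScalarType.INT16] is 32767 (INT16_MAX above), and the
+# float entries are NaN so uninitialized chunks read back as missing samples.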
diff --git a/src/mdio/schemas/v1/dataset_builder.py b/src/mdio/schemas/v1/dataset_builder.py
index facd0dc8..698b1874 100644
--- a/src/mdio/schemas/v1/dataset_builder.py
+++ b/src/mdio/schemas/v1/dataset_builder.py
@@ -208,7 +208,7 @@ def add_coordinate(  # noqa: PLR0913
         # Add a coordinate variable to the dataset
         self.add_variable(
             name=coord.name,
-            long_name=f"'{coord.name}' coordinate variable",
+            long_name=coord.long_name,
             dimensions=dimensions,  # dimension names (list[str])
             data_type=coord.data_type,
             compressor=compressor,
diff --git a/src/mdio/schemas/v1/dataset_serializer.py b/src/mdio/schemas/v1/dataset_serializer.py
new file mode 100644
index 00000000..932816f2
--- /dev/null
+++ b/src/mdio/schemas/v1/dataset_serializer.py
@@ -0,0 +1,296 @@
+"""Convert an MDIO v1 schema Dataset to an xarray Dataset and write it to Zarr."""
+
+from collections.abc import Mapping
+from typing import Any
+
+from dask import array as dask_array
+from numcodecs import Blosc as nc_Blosc
+from numpy import dtype as np_dtype
+from xarray import DataArray as xr_DataArray
+from xarray import Dataset as xr_Dataset
+from zarr.core.chunk_key_encodings import V2ChunkKeyEncoding
+
+try:
+    # zfpy is an optional dependency for ZFP compression.
+    # It is not installed by default, so we check for its presence and import it only if available.
+    from zfpy import ZFPY as zfpy_ZFPY  # noqa: N811
+except ImportError:
+    zfpy_ZFPY = None  # noqa: N816
+
+from mdio.constants import fill_value_map
+from mdio.schemas.compressors import ZFP as mdio_ZFP  # noqa: N811
+from mdio.schemas.compressors import Blosc as mdio_Blosc
+from mdio.schemas.dimension import NamedDimension
+from mdio.schemas.dtype import ScalarType
+from mdio.schemas.dtype import StructuredType
+from mdio.schemas.v1.dataset import Dataset
+from mdio.schemas.v1.dataset_builder import _to_dictionary
+from mdio.schemas.v1.variable import Coordinate
+from mdio.schemas.v1.variable import Variable
+
+
+def _get_all_named_dimensions(dataset: Dataset) -> dict[str, NamedDimension]:
+    """Get all NamedDimensions from the dataset variables.
+
+    This function returns a dictionary of NamedDimensions; dimensions that are
+    not resolvable will not be included in the result.
+
+    Args:
+        dataset: The MDIO Dataset to extract NamedDimensions from.
+
+    Note:
+        The Dataset Builder ensures that all dimensions are resolvable by always embedding
+        dimensions as NamedDimension and never as str.
+        If the dataset is created in a different way, some dimensions may be specified as
+        dimension names (str) instead of NamedDimension. In this case, we will try to resolve
+        them to NamedDimension, but if a dimension is not found, it will be skipped.
+        It is the responsibility of the Dataset creator to ensure that all dimensions are
+        resolvable at the Dataset level.
+
+    Returns:
+        A dictionary mapping dimension names to NamedDimension instances.
+    """
+    all_named_dims: dict[str, NamedDimension] = {}
+    for v in dataset.variables:
+        if v.dimensions is not None:
+            for d in v.dimensions:
+                if isinstance(d, NamedDimension):
+                    all_named_dims[d.name] = d
+    return all_named_dims
+
+
+def _get_dimension_names(var: Variable) -> list[str]:
+    """Get the names of dimensions for a variable.
+
+    Note:
+        We expect that Datasets produced by DatasetBuilder have all dimensions
+        embedded as NamedDimension, but we also support dimension name strings for
+        compatibility with Datasets produced in other ways.
+    """
+    dim_names: list[str] = []
+    if var.dimensions is not None:
+        for d in var.dimensions:
+            if isinstance(d, NamedDimension):
+                dim_names.append(d.name)
+            elif isinstance(d, str):
+                dim_names.append(d)
+    return dim_names
+
+
+def _get_coord_names(var: Variable) -> list[str]:
+    """Get the names of coordinates for a variable."""
+    coord_names: list[str] = []
+    if var.coordinates is not None:
+        for c in var.coordinates:
+            if isinstance(c, Coordinate):
+                coord_names.append(c.name)
+            elif isinstance(c, str):
+                coord_names.append(c)
+    return coord_names
+
+
+def _get_np_datatype(var: Variable) -> np_dtype:
+    """Get the numpy dtype for a variable."""
+    data_type = var.data_type
+    if isinstance(data_type, ScalarType):
+        return np_dtype(data_type.value)
+    if isinstance(data_type, StructuredType):
+        return np_dtype([(f.name, f.format.value) for f in data_type.fields])
+    err = f"Unsupported data type: {type(data_type)} in variable {var.name}"
+    raise TypeError(err)
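+
+
+# Illustrative example: a StructuredType with fields [("cdp-x", INT32), ("cdp-y", INT32)]
+# resolves via _get_np_datatype to np_dtype([("cdp-x", "int32"), ("cdp-y", "int32")]),
+# i.e. a numpy record dtype with one field per structured field.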
+
+
+def _get_zarr_shape(var: Variable, all_named_dims: dict[str, NamedDimension]) -> tuple[int, ...]:
+    """Get the shape of a variable for Zarr storage.
+
+    Note:
+        We expect that Datasets produced by DatasetBuilder have all dimensions
+        embedded as NamedDimension, but we also support dimension name strings for
+        compatibility with Datasets produced in other ways.
+    """
+    shape: list[int] = []
+    for dim in var.dimensions:
+        if isinstance(dim, NamedDimension):
+            shape.append(dim.size)
+        elif isinstance(dim, str):
+            named_dim = all_named_dims.get(dim)
+            if named_dim is None:
+                err = f"Dimension named '{dim}' can't be resolved to a NamedDimension."
+                raise ValueError(err)
+            shape.append(named_dim.size)
+    return tuple(shape)
+
+
+def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) -> tuple[int, ...]:
+    """Get the chunk shape for a variable, defaulting to its shape if no chunk grid is defined."""
+    if var.metadata is not None and var.metadata.chunk_grid is not None:
+        return tuple(var.metadata.chunk_grid.configuration.chunk_shape)
+    # Default to the full shape if no chunk grid is defined
+    return _get_zarr_shape(var, all_named_dims=all_named_dims)
+
+
+def _convert_compressor(
+    compressor: mdio_Blosc | mdio_ZFP | None,
+) -> nc_Blosc | zfpy_ZFPY | None:
+    """Convert an MDIO compressor model to a numcodecs-compatible codec."""
+    if compressor is None:
+        return None
+
+    if isinstance(compressor, mdio_Blosc):
+        return nc_Blosc(
+            cname=compressor.algorithm.value,
+            clevel=compressor.level,
+            shuffle=compressor.shuffle.value,
+            blocksize=compressor.blocksize if compressor.blocksize > 0 else 0,
+        )
+
+    if isinstance(compressor, mdio_ZFP):
+        if zfpy_ZFPY is None:
+            msg = "zfpy and numcodecs are required to use ZFP compression"
+            raise ImportError(msg)
+        return zfpy_ZFPY(
+            mode=compressor.mode.value,
+            tolerance=compressor.tolerance,
+            rate=compressor.rate,
+            precision=compressor.precision,
+        )
+
+    msg = f"Unsupported compressor model: {type(compressor)}"
+    raise TypeError(msg)
+
+
+def _get_fill_value(data_type: ScalarType | StructuredType | str) -> Any:
+    """Get the fill value for a given data type.
+
+    The Zarr fill_value is a scalar providing the default value to use for
+    uninitialized portions of the array, or null if no fill_value is to be used.
+    See https://zarr-specs.readthedocs.io/en/latest/v2/v2.0.html
+    """
+    if isinstance(data_type, ScalarType):
+        return fill_value_map.get(data_type)
+    if isinstance(data_type, StructuredType):
+        return tuple(fill_value_map.get(field.format) for field in data_type.fields)
+    if isinstance(data_type, str):
+        return ""
+    # If we do not have a fill value for this type, use None
+    return None
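+
+
+# Illustrative: _get_fill_value(ScalarType.FLOAT32) yields NaN, and a StructuredType
+# with INT32 and FLOAT16 fields yields (2147483647, nan), one fill value per field,
+# drawn from fill_value_map.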
+ """ + # See the xarray tutorial for more details on how to create datasets: + # https://tutorial.xarray.dev/fundamentals/01.1_creating_data_structures.html + + all_named_dims = _get_all_named_dimensions(mdio_ds) + + # First pass: Build all variables + data_arrays: dict[str, xr_DataArray] = {} + for v in mdio_ds.variables: + # Use dask array instead of numpy array for lazy evaluation + shape = _get_zarr_shape(v, all_named_dims=all_named_dims) + dtype = _get_np_datatype(v) + chunks = _get_zarr_chunks(v, all_named_dims=all_named_dims) + arr = dask_array.zeros(shape, dtype=dtype, chunks=chunks) + + # Create a DataArray for the variable. We will set coords in the second pass + dim_names = _get_dimension_names(v) + data_array = xr_DataArray(arr, dims=dim_names) + + # Add array attributes + if v.metadata is not None: + meta_dict = _to_dictionary(v.metadata) + # Exclude chunk_grid + del meta_dict["chunkGrid"] + # Remove empty attributes + meta_dict = {k: v for k, v in meta_dict.items() if v is not None} + # Add metadata to the data array attributes + data_array.attrs.update(meta_dict) + if v.long_name: + data_array.attrs["long_name"] = v.long_name + + # Create a custom chunk key encoding with "/" as separator + chunk_key_encoding = V2ChunkKeyEncoding(separator="/").to_dict() + encoding = { + # Is this a bug in Zarr? For datatype: + # dtype([('cdp-x', ' None: + """Write an XArray dataset to Zarr format. + + Args: + dataset: The XArray dataset to write. + store: The Zarr store to write to. If None, defaults to in-memory store. + *args: Additional positional arguments for the Zarr store. + **kwargs: Additional keyword arguments for the Zarr store. + + Notes: + It sets the zarr_format to 2, which is the default for XArray datasets. + Since we set kwargs["compute"], this method will return a dask.delayed.Delayed object + and the arrays will not be immediately written. + + References: + https://docs.xarray.dev/en/stable/user-guide/io.html + https://docs.xarray.dev/en/latest/generated/xarray.DataArray.to_zarr.html + + Returns: + None: The function writes the dataset as dask.delayed.Delayed object to the + specified Zarr store. + """ + kwargs["zarr_format"] = 2 + kwargs["compute"] = False + return dataset.to_zarr(*args, store=store, **kwargs) diff --git a/tests/unit/v1/helpers.py b/tests/unit/v1/helpers.py index 2058bdd5..d0ebe5d7 100644 --- a/tests/unit/v1/helpers.py +++ b/tests/unit/v1/helpers.py @@ -129,6 +129,7 @@ def _get_coordinate( def _get_all_coordinates(dataset: Dataset) -> list[Coordinate]: + """Get all coordinates from the dataset.""" all_coords: dict[str, Coordinate] = {} for v in dataset.variables: if v.coordinates is not None: @@ -138,8 +139,24 @@ def _get_all_coordinates(dataset: Dataset) -> list[Coordinate]: return list(all_coords.values()) -def make_campos_3d_dataset() -> Dataset: - """Create in-memory campos_3d dataset.""" +def output_path(file_dir: str, file_name: str, debugging: bool = False) -> str: + """Generate the output path for the test file-system output. + + Note: + Use debugging=True, if you need to retain the created files for debugging + purposes. Otherwise, the files will be created in-memory and not saved to disk. 
+ """ + if debugging: + # Use the following for debugging: + file_path = f"{file_dir}/mdio-tests/{file_name}.zarr" + else: + # Use the following for normal runs: + file_path = f"memory://path_to_zarr/mdio-tests/{file_name}.zarr" + return file_path + + +def make_seismic_poststack_3d_acceptance_dataset() -> Dataset: + """Create in-memory Seismic PostStack 3D Acceptance dataset.""" ds = MDIODatasetBuilder( "campos_3d", attributes=UserAttributes( @@ -163,7 +180,7 @@ def make_campos_3d_dataset() -> Dataset: ds.add_coordinate( "depth", dimensions=["depth"], - data_type=ScalarType.FLOAT64, + data_type=ScalarType.UINT32, metadata_info=[AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER))], ) # Add coordinates @@ -203,7 +220,7 @@ def make_campos_3d_dataset() -> Dataset: histogram=CenteredBinHistogram(binCenters=[1, 2], counts=[10, 15]), ) ), - UserAttributes(attributes={"fizz": "buzz", "UnitSystem": "Canonical"}), + UserAttributes(attributes={"fizz": "buzz"}), ], ) # Add velocity variable @@ -241,14 +258,19 @@ def make_campos_3d_dataset() -> Dataset: ds.add_variable( name="image_headers", dimensions=["inline", "crossline"], + coordinates=["cdp-x", "cdp-y"], data_type=StructuredType( fields=[ - StructuredField(name="cdp-x", format=ScalarType.FLOAT32), - StructuredField(name="cdp-y", format=ScalarType.FLOAT32), - StructuredField(name="inline", format=ScalarType.UINT32), - StructuredField(name="crossline", format=ScalarType.UINT32), + StructuredField(name="cdp-x", format=ScalarType.INT32), + StructuredField(name="cdp-y", format=ScalarType.INT32), + StructuredField(name="elevation", format=ScalarType.FLOAT16), + StructuredField(name="some_scalar", format=ScalarType.FLOAT16), ] ), - coordinates=["cdp-x", "cdp-y"], + metadata_info=[ + ChunkGridMetadata( + chunk_grid=RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=[128, 128])) + ) + ], ) return ds.build() diff --git a/tests/unit/v1/test_dataset_builder_add_coordinate.py b/tests/unit/v1/test_dataset_builder_add_coordinate.py index 46cb27cd..20f68602 100644 --- a/tests/unit/v1/test_dataset_builder_add_coordinate.py +++ b/tests/unit/v1/test_dataset_builder_add_coordinate.py @@ -93,7 +93,7 @@ def test_add_coordinate_with_defaults() -> None: coords=["cdp"], dtype=ScalarType.FLOAT32, ) - assert v.long_name == "'cdp' coordinate variable" # Default value + assert v.long_name is None # Default value assert v.compressor is None # Default value assert v.metadata is None # Default value diff --git a/tests/unit/v1/test_dataset_builder_add_variable.py b/tests/unit/v1/test_dataset_builder_add_variable.py index 84abfe3a..6f46db50 100644 --- a/tests/unit/v1/test_dataset_builder_add_variable.py +++ b/tests/unit/v1/test_dataset_builder_add_variable.py @@ -2,7 +2,7 @@ # PLR2004 Magic value used in comparison, consider replacing `3` with a constant variable # The above erroneous warning is generated for every numerical assert. # Thus, disable it for this file -"""Tests the schema v1 Variable public API.""" +"""Tests the schema v1 dataset_builder.add_variable() public API.""" import pytest diff --git a/tests/unit/v1/test_dataset_builder_build.py b/tests/unit/v1/test_dataset_builder_build.py index aa5fad55..2a68c833 100644 --- a/tests/unit/v1/test_dataset_builder_build.py +++ b/tests/unit/v1/test_dataset_builder_build.py @@ -2,7 +2,7 @@ # PLR2004 Magic value used in comparison, consider replacing `3` with a constant variable # The above erroneous warning is generated for every numerical assert. 
 # Thus, disable it for this file
-"""Tests the schema v1 dataset_builder.add_coordinate() public API."""
+"""Tests the schema v1 dataset_builder.build() public API."""
 
 from mdio.schemas.dtype import ScalarType
 from mdio.schemas.dtype import StructuredField
@@ -12,7 +12,7 @@
 from mdio.schemas.v1.units import LengthUnitEnum
 from mdio.schemas.v1.units import SpeedUnitEnum
 
-from .helpers import make_campos_3d_dataset
+from .helpers import make_seismic_poststack_3d_acceptance_dataset
 from .helpers import validate_variable
 
 
@@ -47,9 +47,9 @@ def test_build() -> None:
     assert next(v for v in dataset.variables if v.name == "data") is not None
 
 
-def test_build_campos_3d() -> None:  # noqa: PLR0915 Too many statements (57 > 50)
-    """Test building a Campos 3D dataset with multiple variables and attributes."""
-    dataset = make_campos_3d_dataset()
+def test_build_seismic_poststack_3d_acceptance_dataset() -> None:  # noqa: PLR0915 Too many statements (57 > 50)
+    """Test building a Seismic PostStack 3D Acceptance dataset."""
+    dataset = make_seismic_poststack_3d_acceptance_dataset()
 
     # Verify dataset structure
     assert dataset.metadata.name == "campos_3d"
@@ -75,7 +75,7 @@ def test_build_campos_3d() -> None:  # noqa: PLR0915 Too many statements (57 > 5
     )
 
     depth = validate_variable(
-        dataset, name="depth", dims=[("depth", 384)], coords=["depth"], dtype=ScalarType.FLOAT64
+        dataset, name="depth", dims=[("depth", 384)], coords=["depth"], dtype=ScalarType.UINT32
     )
     assert depth.metadata.units_v1.length == LengthUnitEnum.METER
 
@@ -146,10 +146,11 @@ def test_build_campos_3d() -> None:  # noqa: PLR0915 Too many statements (57 > 5
         coords=["cdp-x", "cdp-y"],
         dtype=StructuredType(
             fields=[
-                StructuredField(name="cdp-x", format=ScalarType.FLOAT32),
-                StructuredField(name="cdp-y", format=ScalarType.FLOAT32),
-                StructuredField(name="inline", format=ScalarType.UINT32),
-                StructuredField(name="crossline", format=ScalarType.UINT32),
+                StructuredField(name="cdp-x", format=ScalarType.INT32),
+                StructuredField(name="cdp-y", format=ScalarType.INT32),
+                StructuredField(name="elevation", format=ScalarType.FLOAT16),
+                StructuredField(name="some_scalar", format=ScalarType.FLOAT16),
             ]
         ),
     )
+    assert headers.metadata.chunk_grid.configuration.chunk_shape == [128, 128]
diff --git a/tests/unit/v1/test_dataset_serializer.py b/tests/unit/v1/test_dataset_serializer.py
new file mode 100644
index 00000000..aa715e06
--- /dev/null
+++ b/tests/unit/v1/test_dataset_serializer.py
@@ -0,0 +1,386 @@
+# ruff: noqa: PLR2004
+# PLR2004 Magic value used in comparison, consider replacing `3` with a constant variable
+# The above erroneous warning is generated for every numerical assert.
+# Thus, disable it for this file
+"""Tests the schema v1 dataset_serializer public API."""
+
+import pytest
+from numpy import dtype as np_dtype
+from numpy import isnan as np_isnan
+from numpy import nan as np_nan
+
+from mdio.constants import fill_value_map
+from mdio.schemas.chunk_grid import RegularChunkGrid
+from mdio.schemas.chunk_grid import RegularChunkShape
+from mdio.schemas.dimension import NamedDimension
+from mdio.schemas.dtype import ScalarType
+from mdio.schemas.dtype import StructuredField
+from mdio.schemas.dtype import StructuredType
+from mdio.schemas.metadata import ChunkGridMetadata
+from mdio.schemas.v1.dataset import Dataset
+from mdio.schemas.v1.dataset import DatasetInfo
+from mdio.schemas.v1.dataset_builder import MDIODatasetBuilder
+from mdio.schemas.v1.dataset_builder import _to_dictionary
+from mdio.schemas.v1.dataset_serializer import _convert_compressor
+from mdio.schemas.v1.dataset_serializer import _get_all_named_dimensions
+from mdio.schemas.v1.dataset_serializer import _get_coord_names
+from mdio.schemas.v1.dataset_serializer import _get_dimension_names
+from mdio.schemas.v1.dataset_serializer import _get_fill_value
+from mdio.schemas.v1.dataset_serializer import _get_np_datatype
+from mdio.schemas.v1.dataset_serializer import _get_zarr_chunks
+from mdio.schemas.v1.dataset_serializer import _get_zarr_shape
+from mdio.schemas.v1.dataset_serializer import to_xarray_dataset
+from mdio.schemas.v1.dataset_serializer import to_zarr
+from mdio.schemas.v1.variable import Coordinate
+from mdio.schemas.v1.variable import Variable
+
+from .helpers import make_seismic_poststack_3d_acceptance_dataset
+from .helpers import output_path
+
+try:
+    from zfpy import ZFPY as zfpy_ZFPY  # noqa: N811
+
+    HAS_ZFPY = True
+except ImportError:
+    zfpy_ZFPY = None  # noqa: N816
+    HAS_ZFPY = False
+
+from numcodecs import Blosc as nc_Blosc
+
+from mdio.schemas.compressors import ZFP as mdio_ZFP  # noqa: N811
+from mdio.schemas.compressors import Blosc as mdio_Blosc
+from mdio.schemas.compressors import BloscAlgorithm as mdio_BloscAlgorithm
+from mdio.schemas.compressors import BloscShuffle as mdio_BloscShuffle
+from mdio.schemas.compressors import ZFPMode as mdio_ZFPMode
+
+
+def test__get_all_named_dimensions() -> None:
+    """Test _get_all_named_dimensions function."""
+    dim1 = NamedDimension(name="inline", size=100)
+    dim2 = NamedDimension(name="crossline", size=200)
+    dim3 = NamedDimension(name="depth", size=300)
+    v1 = Variable(name="named_dims", data_type=ScalarType.FLOAT32, dimensions=[dim1, dim2, dim3])
+    v2 = Variable(
+        name="string_dims",
+        data_type=ScalarType.FLOAT32,
+        dimensions=["inline", "crossline", "depth"],
+    )
+    v3 = Variable(name="unresolved_dims", data_type=ScalarType.FLOAT32, dimensions=["x", "y", "z"])
+    ds = Dataset(
+        variables=[v1, v2, v3],
+        metadata=_to_dictionary(
+            [
+                DatasetInfo(
+                    name="test_dataset", api_version="1.0.0", created_on="2023-10-01T00:00:00Z"
+                )
+            ]
+        ),
+    )
+
+    all_dims = _get_all_named_dimensions(ds)
+    # Only 3 named dimensions could be resolved.
+    # The dimension names "x", "y", "z" are unresolvable.
+    assert set(all_dims) == {"inline", "crossline", "depth"}
+
+
+def test__get_dimension_names() -> None:
+    """Test _get_dimension_names function with various dimension types."""
+    dim1 = NamedDimension(name="inline", size=100)
+    dim2 = NamedDimension(name="crossline", size=200)
+
+    # Test case 1: Variable with NamedDimension
+    var_named_dims = Variable(
+        name="Variable with NamedDimension dimensions",
+        data_type=ScalarType.FLOAT32,
+        dimensions=[dim1, dim2],
+    )
+    assert set(_get_dimension_names(var_named_dims)) == {"inline", "crossline"}
+
+    # Test case 2: Variable with string dimensions
+    var_string_dims = Variable(
+        name="Variable with string dimensions",
+        data_type=ScalarType.FLOAT32,
+        dimensions=["x", "y", "z"],
+    )
+    assert set(_get_dimension_names(var_string_dims)) == {"x", "y", "z"}
+
+    # Test case 3: Mixed NamedDimension and string dimensions
+    # NOTE: mixing NamedDimension and string dimensions is not allowed by the Variable schema
+
+
+def test__get_coord_names() -> None:
+    """Comprehensive test for _get_coord_names function covering all scenarios."""
+    dim1 = NamedDimension(name="inline", size=100)
+    dim2 = NamedDimension(name="crossline", size=200)
+
+    # Test 1: Variable with Coordinate objects
+    coord1 = Coordinate(name="x_coord", dimensions=[dim1, dim2], data_type=ScalarType.FLOAT32)
+    coord2 = Coordinate(name="y_coord", dimensions=[dim1, dim2], data_type=ScalarType.FLOAT64)
+    variable_coords = Variable(
+        name="Variable with Coordinate objects",
+        data_type=ScalarType.FLOAT32,
+        dimensions=[dim1, dim2],
+        coordinates=[coord1, coord2],
+    )
+    assert set(_get_coord_names(variable_coords)) == {"x_coord", "y_coord"}
+
+    # Test 2: Variable with string coordinates
+    variable_strings = Variable(
+        name="Variable with string coordinates",
+        data_type=ScalarType.FLOAT32,
+        dimensions=[dim1, dim2],
+        coordinates=["lat", "lon", "time"],
+    )
+    assert set(_get_coord_names(variable_strings)) == {"lat", "lon", "time"}
+
+    # Test 3: Variable with mixed coordinate types
+    # NOTE: mixing Coordinate objects and coordinate name strings is not allowed by the
+    # Variable schema
+
+
+def test__get_np_datatype() -> None:
+    """Comprehensive test for _get_np_datatype function."""
+    # Test 1: ScalarType cases - all supported scalar types
+    scalar_type_tests = [
+        (ScalarType.FLOAT32, "float32"),
+        (ScalarType.FLOAT64, "float64"),
+        (ScalarType.INT8, "int8"),
+        (ScalarType.INT16, "int16"),
+        (ScalarType.INT32, "int32"),
+        (ScalarType.INT64, "int64"),
+        (ScalarType.UINT8, "uint8"),
+        (ScalarType.UINT16, "uint16"),
+        (ScalarType.UINT32, "uint32"),
+        (ScalarType.UINT64, "uint64"),
+        (ScalarType.COMPLEX64, "complex64"),
+        (ScalarType.COMPLEX128, "complex128"),
+        (ScalarType.BOOL, "bool"),
+    ]
+
+    for scalar_type, expected_numpy_type in scalar_type_tests:
+        variable = Variable(name="test_var", dimensions=[], data_type=scalar_type)
+
+        result = _get_np_datatype(variable)
+        expected = np_dtype(expected_numpy_type)
+
+        assert result == expected
+        assert isinstance(result, np_dtype)
+        assert result.name == expected.name
+
+    # Test 2: StructuredType with multiple fields
+    multi_fields = [
+        StructuredField(name="x", format=ScalarType.FLOAT64),
+        StructuredField(name="y", format=ScalarType.FLOAT64),
+        StructuredField(name="z", format=ScalarType.FLOAT64),
+        StructuredField(name="id", format=ScalarType.INT32),
+        StructuredField(name="valid", format=ScalarType.BOOL),
+    ]
+    structured_multi = StructuredType(fields=multi_fields)
+
+    variable_multi_struct = Variable(
+        name="multi_struct_var", dimensions=[], data_type=structured_multi
+    )
+
+    result_multi = _get_np_datatype(variable_multi_struct)
+    expected_multi = np_dtype(
+        [("x", "float64"), ("y", "float64"), ("z", "float64"), ("id", "int32"), ("valid", "bool")]
+    )
+
+    assert result_multi == expected_multi
+    assert isinstance(result_multi, np_dtype)
+    assert len(result_multi.names) == 5
+    assert set(result_multi.names) == {"x", "y", "z", "id", "valid"}
+
+
+def test__get_zarr_shape() -> None:
+    """Test for _get_zarr_shape function."""
+    d1 = NamedDimension(name="inline", size=100)
+    d2 = NamedDimension(name="crossline", size=200)
+    d3 = NamedDimension(name="depth", size=300)
+    all_named_dims = {d.name: d for d in (d1, d2, d3)}
+
+    v = Variable(name="seismic 3d var", data_type=ScalarType.FLOAT32, dimensions=[d1, d2, d3])
+    assert _get_zarr_shape(v, all_named_dims=all_named_dims) == (100, 200, 300)
+
+
+def test__get_zarr_chunks() -> None:
+    """Test for _get_zarr_chunks function."""
+    d1 = NamedDimension(name="inline", size=100)
+    d2 = NamedDimension(name="crossline", size=200)
+    d3 = NamedDimension(name="depth", size=300)
+    all_named_dims = {d.name: d for d in (d1, d2, d3)}
+
+    # Test 1: Variable with chunk shape defined in metadata
+    v = Variable(
+        name="seismic 3d var",
+        data_type=ScalarType.FLOAT32,
+        dimensions=[d1, d2, d3],
+        metadata=_to_dictionary(
+            ChunkGridMetadata(
+                chunk_grid=RegularChunkGrid(
+                    configuration=RegularChunkShape(chunk_shape=[10, 20, 30])
+                )
+            )
+        ),
+    )
+    assert _get_zarr_chunks(v, all_named_dims=all_named_dims) == (10, 20, 30)
+
+    # Test 2: Variable with no chunks defined
+    v = Variable(name="seismic 3d var", data_type=ScalarType.FLOAT32, dimensions=[d1, d2, d3])
+    assert _get_zarr_chunks(v, all_named_dims=all_named_dims) == (100, 200, 300)
+
+
+def test__get_fill_value() -> None:
+    """Test for _get_fill_value function."""
+    # Test 1: ScalarType cases - should return values from fill_value_map
+    scalar_types = [
+        ScalarType.BOOL,
+    ]
+    for scalar_type in scalar_types:
+        assert _get_fill_value(scalar_type) is None
+
+    scalar_types = [
+        ScalarType.FLOAT16,
+        ScalarType.FLOAT32,
+        ScalarType.FLOAT64,
+    ]
+    for scalar_type in scalar_types:
+        assert np_isnan(_get_fill_value(scalar_type))
+
+    scalar_types = [
+        ScalarType.UINT8,
+        ScalarType.UINT16,
+        ScalarType.UINT32,
+        ScalarType.INT8,
+        ScalarType.INT16,
+        ScalarType.INT32,
+    ]
+    for scalar_type in scalar_types:
+        assert fill_value_map[scalar_type] == _get_fill_value(scalar_type)
+
+    scalar_types = [
+        ScalarType.COMPLEX64,
+        ScalarType.COMPLEX128,
+        ScalarType.COMPLEX256,
+    ]
+    for scalar_type in scalar_types:
+        val = _get_fill_value(scalar_type)
+        assert isinstance(val, complex)
+        assert np_isnan(val.real)
+        assert np_isnan(val.imag)
+
+    # Test 2: StructuredType
+    f1 = StructuredField(name="cdp-x", format=ScalarType.INT32)
+    f2 = StructuredField(name="cdp-y", format=ScalarType.INT32)
+    f3 = StructuredField(name="elevation", format=ScalarType.FLOAT16)
+    f4 = StructuredField(name="some_scalar", format=ScalarType.FLOAT16)
+    structured_type = StructuredType(fields=[f1, f2, f3, f4])
+    result_structured = _get_fill_value(structured_type)
+    # The NaN fields compare equal here because both tuples hold the same np_nan
+    # object, and tuple comparison short-circuits on identity before using ==.
+    assert result_structured == (2147483647, 2147483647, np_nan, np_nan)
+
+    # Test 3: String type - should return empty string
+    result_string = _get_fill_value("string_type")
+    assert result_string == ""
+
+    # Test 4: Unknown type - should return None
+    result_none = _get_fill_value(42)  # Invalid type
+    assert result_none is None
+
+    # Test 5: None input - should return None
+    result_none_input = _get_fill_value(None)
+    assert result_none_input is None
+
+
+def test__convert_compressor() -> None:
+    """Simple test for _convert_compressor, covering basic scenarios."""
+    # Test 1: None input - should return None
+    result_none = _convert_compressor(None)
+    assert result_none is None
+
+    # Test 2: mdio_Blosc compressor - should return nc_Blosc
+    result_blosc = _convert_compressor(
+        mdio_Blosc(
+            algorithm=mdio_BloscAlgorithm.LZ4,
+            level=5,
+            shuffle=mdio_BloscShuffle.AUTOSHUFFLE,
+            blocksize=1024,
+        )
+    )
+    assert isinstance(result_blosc, nc_Blosc)
+    assert result_blosc.cname == "lz4"  # BloscAlgorithm.LZ4.value
+    assert result_blosc.clevel == 5
+    assert result_blosc.shuffle == -1  # BloscShuffle.AUTOSHUFFLE = -1
+    assert result_blosc.blocksize == 1024
+
+    # Test 3: mdio_Blosc with blocksize 0 - should use 0 as blocksize
+    result_blosc_zero = _convert_compressor(
+        mdio_Blosc(
+            algorithm=mdio_BloscAlgorithm.ZSTD,
+            level=3,
+            shuffle=mdio_BloscShuffle.AUTOSHUFFLE,
+            blocksize=0,
+        )
+    )
+    assert isinstance(result_blosc_zero, nc_Blosc)
+    assert result_blosc_zero.blocksize == 0
+
+    # Test 4: mdio_ZFP compressor - should return zfpy_ZFPY if available
+    zfp_compressor = mdio_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16)
+
+    if HAS_ZFPY:
+        result_zfp = _convert_compressor(zfp_compressor)
+        assert isinstance(result_zfp, zfpy_ZFPY)
+        assert result_zfp.mode == 1  # numeric mode for ZFPMode.FIXED_RATE
+        assert result_zfp.tolerance == 0.01
+        assert result_zfp.rate == 8.0
+        assert result_zfp.precision == 16
+    else:
+        # Test 5: mdio_ZFP without zfpy installed - should raise ImportError
+        with pytest.raises(ImportError) as exc_info:
+            _convert_compressor(zfp_compressor)
+
+        error_message = str(exc_info.value)
+        assert "zfpy and numcodecs are required to use ZFP compression" in error_message
+
+    # Test 6: Unsupported compressor type - should raise TypeError
+    unsupported_compressor = "invalid_compressor"
+    with pytest.raises(TypeError) as exc_info:
+        _convert_compressor(unsupported_compressor)
+    error_message = str(exc_info.value)
+    assert "Unsupported compressor model" in error_message
+    assert "str" in error_message
+
+
+def test_to_xarray_dataset(tmp_path) -> None:  # noqa: ANN001 - tmp_path is a pytest fixture
+    """Test building a complete xarray dataset and writing it to Zarr."""
+    dataset = (
+        MDIODatasetBuilder("test_dataset")
+        .add_dimension("inline", 100)
+        .add_dimension("crossline", 200)
+        .add_dimension("depth", 300)
+        .add_coordinate("inline", dimensions=["inline"], data_type=ScalarType.FLOAT64)
+        .add_coordinate("crossline", dimensions=["crossline"], data_type=ScalarType.FLOAT64)
+        .add_coordinate("x_coord", dimensions=["inline", "crossline"], data_type=ScalarType.FLOAT32)
+        .add_coordinate("y_coord", dimensions=["inline", "crossline"], data_type=ScalarType.FLOAT32)
+        .add_variable(
+            "data",
+            long_name="Test Data",
+            dimensions=["inline", "crossline", "depth"],
+            coordinates=["inline", "crossline", "x_coord", "y_coord"],
+            data_type=ScalarType.FLOAT32,
+        )
+        .build()
+    )
+
+    xr_ds = to_xarray_dataset(dataset)
+
+    file_path = output_path(tmp_path, xr_ds.attrs["name"], debugging=False)
+    to_zarr(xr_ds, file_path, mode="w")
+
+
+def test_seismic_poststack_3d_acceptance_to_xarray_dataset(tmp_path) -> None:  # noqa: ANN001
+    """Test converting the Seismic PostStack 3D Acceptance dataset and writing it to Zarr."""
+    dataset = make_seismic_poststack_3d_acceptance_dataset()
+
+    xr_ds = to_xarray_dataset(dataset)
+
+    file_path = output_path(tmp_path, xr_ds.attrs["name"], debugging=True)
+    to_zarr(xr_ds, file_path, mode="w")
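Usage sketch: a minimal end-to-end flow through the new public API, assuming the builder pattern from the tests above. The dataset name and the in-memory target path are illustrative, and `dask.compute` forces the deferred write that `to_zarr` returns because it sets `compute=False`.

```python
# Hedged sketch, not part of the patch: exercising the new serializer API.
import dask

from mdio.schemas.dtype import ScalarType
from mdio.schemas.v1.dataset_builder import MDIODatasetBuilder
from mdio.schemas.v1.dataset_serializer import to_xarray_dataset, to_zarr

# Build a tiny MDIO v1 dataset in memory.
mdio_ds = (
    MDIODatasetBuilder("sketch")
    .add_dimension("inline", 8)
    .add_dimension("crossline", 8)
    .add_variable("data", dimensions=["inline", "crossline"], data_type=ScalarType.FLOAT32)
    .build()
)

# Convert to an xarray Dataset backed by lazy dask arrays.
xr_ds = to_xarray_dataset(mdio_ds)

# to_zarr sets compute=False, so it returns a dask Delayed; compute to write.
delayed_write = to_zarr(xr_ds, "memory://sketch.zarr", mode="w")
dask.compute(delayed_write)
```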