diff --git a/ci/docs.yml b/ci/docs.yml index 8b6f7a2a..63a0162d 100644 --- a/ci/docs.yml +++ b/ci/docs.yml @@ -29,4 +29,5 @@ dependencies: - cdshealpix - pip - pip: + - healpix-geo - -e .. diff --git a/ci/environment.yml b/ci/environment.yml index d33d87d0..d45bdc07 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -16,3 +16,6 @@ dependencies: - arro3-core - cdshealpix - h3ronpy + - pip + - pip: + - healpix-geo diff --git a/docs/api.rst b/docs/api.rst index efb0f280..2122b206 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -38,6 +38,11 @@ Parameters Dataset.dggs.grid_info Dataset.dggs.params + +.. autosummary:: + :toctree: generated + :template: autosummary/accessor_method.rst + Dataset.dggs.decode @@ -62,7 +67,12 @@ Parameters DataArray.dggs.grid_info DataArray.dggs.params - DataArray.dggs.decode + +.. autosummary:: + :toctree: generated + :template: autosummary/accessor_method.rst + + Dataset.dggs.decode Data inference diff --git a/docs/changelog.md b/docs/changelog.md index 39283937..a99ee0d2 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,11 +4,13 @@ ### New features +- memory-efficient index implementation based on multi-order coverage maps (MOCs) ({pull}`151`) + ### Bug fixes ### Documentation -- Documentation Contributer Guide + Github Button ({pull}`137`) +- Documentation Contributor's Guide + Github Button ({pull}`137`) ### Internal changes diff --git a/docs/conf.py b/docs/conf.py index 8cc092e9..867ed5a3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -115,6 +115,7 @@ "healpy": ("https://healpy.readthedocs.io/en/latest", None), "cdshealpix-python": ("https://cds-astro.github.io/cds-healpix-python", None), "shapely": ("https://shapely.readthedocs.io/en/stable", None), + "healpix-geo": ("https://healpix-geo.readthedocs.io/en/latest", None), } # -- myst-nb options --------------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 94f0f46d..2b00bb88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,9 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "xarray", + "numpy>=2.0", "cdshealpix", + "healpix-geo>=0.0.3", "h3ronpy", "typing-extensions", "lonboard>=0.9.3", diff --git a/xdggs/accessor.py b/xdggs/accessor.py index 1852c981..7f92f01c 100644 --- a/xdggs/accessor.py +++ b/xdggs/accessor.py @@ -29,7 +29,9 @@ def __init__(self, obj: xr.Dataset | xr.DataArray): self._name = name self._index = index - def decode(self, grid_info=None, *, name="cell_ids") -> xr.Dataset | xr.DataArray: + def decode( + self, grid_info=None, *, name="cell_ids", index_options=None, **index_kwargs + ) -> xr.Dataset | xr.DataArray: """decode the DGGS cell ids Parameters @@ -39,6 +41,8 @@ def decode(self, grid_info=None, *, name="cell_ids") -> xr.Dataset | xr.DataArra the dataset. name : str, default: "cell_ids" The name of the coordinate containing the cell ids. + index_options, **index_kwargs : dict, optional + Additional options to forward to the index. Returns ------- @@ -51,7 +55,12 @@ def decode(self, grid_info=None, *, name="cell_ids") -> xr.Dataset | xr.DataArra if isinstance(grid_info, dict): var.attrs = grid_info - return self._obj.drop_indexes(name, errors="ignore").set_xindex(name, DGGSIndex) + if index_options is None: + index_options = {} + + return self._obj.drop_indexes(name, errors="ignore").set_xindex( + name, DGGSIndex, **(index_options | index_kwargs) + ) @property def index(self) -> DGGSIndex: diff --git a/xdggs/h3.py b/xdggs/h3.py index 0a5b103c..ec9a3dbc 100644 --- a/xdggs/h3.py +++ b/xdggs/h3.py @@ -23,7 +23,6 @@ cells_to_wkb_polygons, coordinates_to_cells, ) -from xarray.indexes import PandasIndex from xdggs.grid import DGGSInfo, translate_parameters from xdggs.index import DGGSIndex @@ -208,7 +207,7 @@ class H3Index(DGGSIndex): def __init__( self, - cell_ids: Any | PandasIndex, + cell_ids: Any | xr.Index, dim: str, grid_info: DGGSInfo, ): @@ -231,8 +230,8 @@ def from_variables( def grid_info(self) -> H3Info: return self._grid - def _replace(self, new_pd_index: PandasIndex): - return type(self)(new_pd_index, self._dim, self._grid) + def _replace(self, new_index: xr.Index): + return type(self)(new_index, self._dim, self._grid) def _repr_inline_(self, max_width: int): return f"H3Index(level={self._grid.level})" diff --git a/xdggs/healpix.py b/xdggs/healpix.py index 2da42bf5..2244f165 100644 --- a/xdggs/healpix.py +++ b/xdggs/healpix.py @@ -12,6 +12,7 @@ import cdshealpix.ring import numpy as np import xarray as xr +from healpix_geo.nested import RangeMOCIndex from xarray.indexes import PandasIndex from xdggs.grid import DGGSInfo, translate_parameters @@ -21,6 +22,13 @@ T = TypeVar("T") +try: + import dask.array as da + + dask_array_type = (da.Array,) +except ImportError: + dask_array_type = () + def polygons_shapely(vertices): import shapely @@ -112,12 +120,12 @@ class HealpixInfo(DGGSInfo): level: int """int : The hierarchical level of the grid""" - indexing_scheme: Literal["nested", "ring", "unique"] = "nested" + indexing_scheme: Literal["nested", "ring"] = "nested" """int : The indexing scheme of the grid""" valid_parameters: ClassVar[dict[str, Any]] = { "level": range(0, 29 + 1), - "indexing_scheme": ["nested", "ring", "unique"], + "indexing_scheme": ["nested", "ring"], } def __post_init__(self): @@ -309,18 +317,316 @@ def cell_boundaries(self, cell_ids: Any, backend="shapely") -> np.ndarray: return backend_func(vertices) +def construct_chunk_ranges(chunks, until): + start = 0 + + for chunksize in chunks: + stop = start + chunksize + if stop > until: + stop = until + if start == stop: + break + + if until - start < chunksize: + chunksize = until - start + + yield chunksize, slice(start, stop) + start = stop + + +def subset_chunks(chunks, indexer): + def _subset_slice(offset, chunk, indexer): + if offset >= indexer.stop or offset + chunk < indexer.start: + # outside slice + return 0 + elif offset >= indexer.start and offset + chunk < indexer.stop: + # full chunk + return chunk + else: + # partial chunk + left_trim = indexer.start - offset + right_trim = offset + chunk - indexer.stop + + if left_trim < 0: + left_trim = 0 + + if right_trim < 0: + right_trim = 0 + + return chunk - left_trim - right_trim + + def _subset_array(offset, chunk, indexer): + mask = (indexer >= offset) & (indexer < offset + chunk) + + return np.sum(mask.astype(int)) + + def _subset(offset, chunk, indexer): + if isinstance(indexer, slice): + return _subset_slice(offset, chunk, indexer) + else: + return _subset_array(offset, chunk, indexer) + + if chunks is None: + return None + + chunk_offsets = np.cumulative_sum(chunks, include_initial=True) + total_length = chunk_offsets[-1] + + if isinstance(indexer, slice): + indexer = slice(*indexer.indices(total_length)) + + trimmed_chunks = tuple( + _subset(offset, chunk, indexer) + for offset, chunk in zip(chunk_offsets[:-1], chunks) + ) + + return tuple(int(chunk) for chunk in trimmed_chunks if chunk > 0) + + +def extract_chunk(index, slice_): + return index.isel(slice_).cell_ids() + + +# optionally replaces the PandasIndex within HealpixIndex +class HealpixMocIndex(xr.Index): + """More efficient index for healpix cell ids based on a MOC + + This uses the rust `moc crate `_ to represent + cell ids as a set of disconnected ranges at level 29, vastly reducing the + memory footprint and computation time of set-like operations. + + .. warning:: + + Only supported for the ``nested`` scheme. + + See Also + -------- + healpix_geo.nested.RangeMOCIndex + The low-level implementation of the index functionality. + """ + + def __init__(self, index, *, dim, name, grid_info, chunksizes): + self._index = index + self._dim = dim + self._grid_info = grid_info + self._name = name + self._chunksizes = chunksizes + + @property + def size(self): + """The number of indexed cells.""" + return self._index.size + + @property + def nbytes(self): + """The number of bytes occupied by the index. + + .. note:: + This does not take any (constant) overhead into account. + """ + return self._index.nbytes + + @property + def chunksizes(self): + """The size of the chunks of the indexed coordinate.""" + return self._chunksizes + + @classmethod + def from_array(cls, array, *, dim, name, grid_info): + """Construct an index from a raw array. + + Parameters + ---------- + array : array-like + The array of cell ids as uint64. If the size is equal to the total + number of cells at the given refinement level, creates a full domain + index without looking at the cell ids. If a chunked array, it will + create indexes for each chunk and then merge the chunk indexes + in-memory. + dim : hashable + The dimension of the index. + name : hashable + The name of the indexed coordinate. + grid_info : xdggs.HealpixInfo + The grid parameters. + + Returns + ------- + index : HealpixMocIndex + The resulting index. + """ + if grid_info.indexing_scheme != "nested": + raise ValueError( + "The MOC index currently only supports the 'nested' scheme" + ) + + if array.ndim != 1: + raise ValueError("only 1D cell ids are supported") + + if array.size == 12 * 4**grid_info.level: + index = RangeMOCIndex.full_domain(grid_info.level) + elif isinstance(array, dask_array_type): + from functools import reduce + + import dask + + [indexes] = dask.compute( + dask.delayed(RangeMOCIndex.from_cell_ids)(grid_info.level, chunk) + for chunk in array.astype("uint64").to_delayed() + ) + index = reduce(RangeMOCIndex.union, indexes) + else: + index = RangeMOCIndex.from_cell_ids(grid_info.level, array.astype("uint64")) + + chunksizes = {dim: array.chunks[0] if hasattr(array, "chunks") else None} + return cls( + index, dim=dim, name=name, grid_info=grid_info, chunksizes=chunksizes + ) + + def _replace(self, index, chunksizes): + return type(self)( + index, + dim=self._dim, + name=self._name, + grid_info=self._grid_info, + chunksizes=chunksizes, + ) + + @classmethod + def from_variables(cls, variables, *, options): + """Create a new index object from the cell id coordinate variable + + Parameters + ---------- + variables : dict-like + Mapping of :py:class:`Variable` objects holding the coordinate labels + to index. + options : dict-like + Mapping of arbitrary options to pass to the HealpixInfo object. + + Returns + ------- + index : Index + A new Index object. + """ + name, var, dim = _extract_cell_id_variable(variables) + grid_info = HealpixInfo.from_dict(var.attrs | options) + + return cls.from_array(var.data, dim=dim, name=name, grid_info=grid_info) + + def create_variables(self, variables): + """Create new coordinate variables from this index + + Parameters + ---------- + variables : dict-like, optional + Mapping of :py:class:`Variable` objects. + + Returns + ------- + index_variables : dict-like + Dictionary of :py:class:`Variable` objects. + """ + name = self._name + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + encoding = var.encoding + else: + attrs = None + encoding = None + + chunks = self._chunksizes[self._dim] + if chunks is not None: + import dask + import dask.array as da + + chunk_arrays = [ + da.from_delayed( + dask.delayed(extract_chunk)(self._index, slice_), + shape=(chunksize,), + dtype="uint64", + name=f"chunk-{index}", + meta=np.array((), dtype="uint64"), + ) + for index, (chunksize, slice_) in enumerate( + construct_chunk_ranges(chunks, self._index.size) + ) + ] + data = da.concatenate(chunk_arrays, axis=0) + var = xr.Variable(self._dim, data, attrs=attrs, encoding=encoding) + else: + var = xr.Variable( + self._dim, self._index.cell_ids(), attrs=attrs, encoding=encoding + ) + + return {name: var} + + def isel(self, indexers): + """Subset the index using positional indexers. + + Parameters + ---------- + indexers : dict-like + A dictionary of positional indexers as passed from + :py:meth:`Dataset.isel` and where the entries have been filtered for + the current index. Note that the underlying index currently only + supports slices. + + Returns + ------- + maybe_index : Index + A new Index object or ``None``. + """ + indexer = indexers[self._dim] + if isinstance(indexer, np.ndarray): + if np.isdtype(indexer.dtype, "signed integer"): + indexer = np.where(indexer >= 0, indexer, self.size + indexer).astype( + "uint64" + ) + elif np.isdtype(indexer.dtype, "unsigned integer"): + indexer = indexer.astype("uint64") + else: + raise ValueError("can only index with integer arrays or slices") + + new_chunksizes = { + self._dim: subset_chunks(self._chunksizes[self._dim], indexer) + } + + return self._replace(self._index.isel(indexer), chunksizes=new_chunksizes) + + @register_dggs("healpix") class HealpixIndex(DGGSIndex): def __init__( self, - cell_ids: Any | PandasIndex, + cell_ids: Any | xr.Index, dim: str, grid_info: DGGSInfo, + index_kind: str = "pandas", ): if not isinstance(grid_info, HealpixInfo): raise ValueError(f"grid info object has an invalid type: {type(grid_info)}") - super().__init__(cell_ids, dim, grid_info) + self._dim = dim + + if isinstance(cell_ids, xr.Index): + self._index = cell_ids + elif index_kind == "pandas": + self._index = PandasIndex(cell_ids, dim) + elif index_kind == "moc": + self._index = HealpixMocIndex.from_array( + cell_ids, dim=dim, grid_info=grid_info, name="cell_ids" + ) + self._kind = index_kind + + self._grid = grid_info + + def values(self): + if self._kind == "moc": + return self._index._index.cell_ids() + else: + return self._index.index.values @classmethod def from_variables( @@ -331,16 +637,21 @@ def from_variables( ) -> "HealpixIndex": _, var, dim = _extract_cell_id_variable(variables) + index_kind = options.pop("index_kind", "pandas") + grid_info = HealpixInfo.from_dict(var.attrs | options) - return cls(var.data, dim, grid_info) + return cls(var.data, dim, grid_info, index_kind=index_kind) + + def create_variables(self, variables): + return self._index.create_variables(variables) - def _replace(self, new_pd_index: PandasIndex): - return type(self)(new_pd_index, self._dim, self._grid) + def _replace(self, new_index: xr.Index): + return type(self)(new_index, self._dim, self._grid, index_kind=self._kind) @property def grid_info(self) -> HealpixInfo: return self._grid def _repr_inline_(self, max_width: int): - return f"HealpixIndex(level={self._grid.level}, indexing_scheme={self._grid.indexing_scheme})" + return f"HealpixIndex(level={self._grid.level}, indexing_scheme={self._grid.indexing_scheme}, kind={self._kind})" diff --git a/xdggs/index.py b/xdggs/index.py index b0f71f28..def181ff 100644 --- a/xdggs/index.py +++ b/xdggs/index.py @@ -9,7 +9,7 @@ from xdggs.utils import GRID_REGISTRY, _extract_cell_id_variable -def decode(ds, grid_info=None, *, name="cell_ids"): +def decode(ds, grid_info=None, *, name="cell_ids", index_options=None, **index_kwargs): """ decode grid parameters and create a DGGS index @@ -23,6 +23,8 @@ def decode(ds, grid_info=None, *, name="cell_ids"): the dataset. name : str, default: "cell_ids" The name of the coordinate containing the cell ids. + index_options, **index_kwargs : dict, optional + Additional options to forward to the index. Returns ------- @@ -34,20 +36,25 @@ def decode(ds, grid_info=None, *, name="cell_ids"): xarray.Dataset.dggs.decode xarray.DataArray.dggs.decode """ - return ds.dggs.decode(name=name, grid_info=grid_info) + if index_options is None: + index_options = {} + + return ds.dggs.decode( + name=name, grid_info=grid_info, index_options=index_options | index_kwargs + ) class DGGSIndex(Index): _dim: str - _pd_index: PandasIndex + _index: xr.Index - def __init__(self, cell_ids: Any | PandasIndex, dim: str, grid_info: DGGSInfo): + def __init__(self, cell_ids: Any | xr.Index, dim: str, grid_info: DGGSInfo): self._dim = dim - if isinstance(cell_ids, PandasIndex): - self._pd_index = cell_ids + if isinstance(cell_ids, xr.Index): + self._index = cell_ids else: - self._pd_index = PandasIndex(cell_ids, dim) + self._index = PandasIndex(cell_ids, dim) self._grid = grid_info @@ -67,33 +74,36 @@ def from_variables( return cls.from_variables(variables, options=options) + def values(self): + return self._index.index.values + def create_variables( self, variables: Mapping[Any, xr.Variable] | None = None ) -> dict[Hashable, xr.Variable]: - return self._pd_index.create_variables(variables) + return self._index.create_variables(variables) def isel( self: "DGGSIndex", indexers: Mapping[Any, int | np.ndarray | xr.Variable] ) -> Union["DGGSIndex", None]: - new_pd_index = self._pd_index.isel(indexers) - if new_pd_index is not None: - return self._replace(new_pd_index) + new_index = self._index.isel(indexers) + if new_index is not None: + return self._replace(new_index) else: return None def sel(self, labels, method=None, tolerance=None): if method == "nearest": raise ValueError("finding nearest grid cell has no meaning") - return self._pd_index.sel(labels, method=method, tolerance=tolerance) + return self._index.sel(labels, method=method, tolerance=tolerance) - def _replace(self, new_pd_index: PandasIndex): + def _replace(self, new_index: PandasIndex): raise NotImplementedError() def cell_centers(self) -> tuple[np.ndarray, np.ndarray]: - return self._grid.cell_ids2geographic(self._pd_index.index.values) + return self._grid.cell_ids2geographic(self.values()) def cell_boundaries(self) -> np.ndarray: - return self.grid_info.cell_boundaries(self._pd_index.index.values) + return self.grid_info.cell_boundaries(self.values()) @property def grid_info(self) -> DGGSInfo: diff --git a/xdggs/tests/__init__.py b/xdggs/tests/__init__.py index 1a193487..1d767c0b 100644 --- a/xdggs/tests/__init__.py +++ b/xdggs/tests/__init__.py @@ -1,4 +1,7 @@ +from contextlib import nullcontext + import geoarrow.pyarrow as ga +import pytest import shapely from xdggs.tests.matchers import ( # noqa: F401 @@ -7,6 +10,52 @@ assert_exceptions_equal, ) +try: + import dask + import dask.array as da + + has_dask = True +except ImportError: + dask = None + + class da: + @staticmethod + def arange(*args, **kwargs): + pass + + has_dask = False + + +# vendored from xarray +class CountingScheduler: + """Simple dask scheduler counting the number of computes. + + Reference: https://stackoverflow.com/questions/53289286/""" + + def __init__(self, max_computes=0): + self.total_computes = 0 + self.max_computes = max_computes + + def __call__(self, dsk, keys, **kwargs): + self.total_computes += 1 + if self.total_computes > self.max_computes: + raise RuntimeError( + f"Too many computes. Total: {self.total_computes} > max: {self.max_computes}." + ) + return dask.get(dsk, keys, **kwargs) + + +requires_dask = pytest.mark.skipif(not has_dask, reason="requires dask") + def geoarrow_to_shapely(arr): return shapely.from_wkb(ga.as_wkb(arr)) + + +# vendored from xarray +def raise_if_dask_computes(max_computes=0): + # return a dummy context manager so that this can be used for non-dask objects + if not has_dask: + return nullcontext() + scheduler = CountingScheduler(max_computes) + return dask.config.set(scheduler=scheduler) diff --git a/xdggs/tests/test_h3.py b/xdggs/tests/test_h3.py index e1c2629d..c58c0205 100644 --- a/xdggs/tests/test_h3.py +++ b/xdggs/tests/test_h3.py @@ -221,8 +221,8 @@ def test_init(cell_ids, dim, level): assert index._dim == dim # TODO: how do we check the index, if at all? - assert index._pd_index.dim == dim - assert np.all(index._pd_index.index.values == cell_ids) + assert index._index.dim == dim + assert np.all(index._index.index.values == cell_ids) @pytest.mark.parametrize("level", levels) @@ -247,8 +247,8 @@ def test_from_variables(variable_name, variable, options): assert (index._dim,) == variable.dims # TODO: how do we check the index, if at all? - assert (index._pd_index.dim,) == variable.dims - assert np.all(index._pd_index.index.values == variable.data) + assert (index._index.dim,) == variable.dims + assert np.all(index._index.index.values == variable.data) @pytest.mark.parametrize(["old_variable", "new_variable"], variable_combinations) @@ -267,7 +267,7 @@ def test_replace(old_variable, new_variable): assert new_index._grid == index._grid assert new_index._dim == index._dim - assert new_index._pd_index == new_pandas_index + assert new_index._index == new_pandas_index @pytest.mark.parametrize("max_width", [20, 50, 80, 120]) diff --git a/xdggs/tests/test_healpix.py b/xdggs/tests/test_healpix.py index bc75eef7..5e31366b 100644 --- a/xdggs/tests/test_healpix.py +++ b/xdggs/tests/test_healpix.py @@ -12,7 +12,13 @@ from xarray.core.indexes import PandasIndex from xdggs import healpix -from xdggs.tests import assert_exceptions_equal, geoarrow_to_shapely +from xdggs.tests import ( + assert_exceptions_equal, + da, + geoarrow_to_shapely, + raise_if_dask_computes, + requires_dask, +) try: ExceptionGroup @@ -465,9 +471,9 @@ def test_init(self, cell_ids, dim, grid) -> None: assert index._grid == grid assert index._dim == dim - assert index._pd_index.dim == dim + assert index._index.dim == dim - np.testing.assert_equal(index._pd_index.index.values, cell_ids) + np.testing.assert_equal(index._index.index.values, cell_ids) @given(strategies.grids()) def test_grid(self, grid): @@ -491,7 +497,20 @@ def test_from_variables(variable_name, variable, options) -> None: assert index._grid.indexing_scheme == expected_scheme assert (index._dim,) == variable.dims - np.testing.assert_equal(index._pd_index.index.values, variable.data) + np.testing.assert_equal(index._index.index.values, variable.data) + + +def test_from_variables_moc() -> None: + level = 2 + grid_info = {"grid_name": "healpix", "level": level, "indexing_scheme": "nested"} + variables = {"cell_ids": xr.Variable("cells", np.arange(12 * 4**level), grid_info)} + + index = healpix.HealpixIndex.from_variables( + variables, options={"index_kind": "moc"} + ) + + assert isinstance(index._index, healpix.HealpixMocIndex) + assert index.grid_info.to_dict() == grid_info @pytest.mark.parametrize(["old_variable", "new_variable"], variable_combinations) @@ -511,7 +530,7 @@ def test_replace(old_variable, new_variable) -> None: new_index = index._replace(new_pandas_index) assert new_index._dim == index._dim - assert new_index._pd_index == new_pandas_index + assert new_index._index == new_pandas_index assert index._grid == grid @@ -526,3 +545,211 @@ def test_repr_inline(level, max_width) -> None: assert f"level={level}" in actual # ignore max_width for now # assert len(actual) <= max_width + + +class TestHealpixMocIndex: + @pytest.mark.parametrize( + ["level", "cell_ids", "max_computes"], + ( + pytest.param( + 2, np.arange(12 * 4**2, dtype="uint64"), 1, id="numpy-2-full_domain" + ), + pytest.param( + 2, + np.arange(3 * 4**2, 5 * 4**2, dtype="uint64"), + 1, + id="numpy-2-region", + ), + pytest.param( + 10, + da.arange(12 * 4**10, chunks=(4**6,), dtype="uint64"), + 0, + marks=requires_dask, + id="dask-10-full_domain", + ), + pytest.param( + 15, + da.arange(12 * 4**15, chunks=(4**10,), dtype="uint64"), + 0, + marks=requires_dask, + id="dask-15-full_domain", + ), + pytest.param( + 10, + da.arange(3 * 4**10, 5 * 4**10, chunks=(4**6,), dtype="uint64"), + 1, + marks=requires_dask, + id="dask-10-region", + ), + ), + ) + def test_from_array(self, level, cell_ids, max_computes): + grid_info = healpix.HealpixInfo(level=level, indexing_scheme="nested") + + with raise_if_dask_computes(max_computes=max_computes): + index = healpix.HealpixMocIndex.from_array( + cell_ids, dim="cells", name="cell_ids", grid_info=grid_info + ) + + assert isinstance(index, healpix.HealpixMocIndex) + chunks = index.chunksizes["cells"] + assert chunks is None or isinstance(chunks[0], int) + assert index.size == cell_ids.size + assert index.nbytes == 16 + + def test_from_array_unsupported_indexing_scheme(self): + level = 1 + cell_ids = np.arange(12 * 4**level, dtype="uint64") + grid_info = healpix.HealpixInfo(level=level, indexing_scheme="ring") + + with pytest.raises(ValueError, match=".*only supports the 'nested' scheme"): + healpix.HealpixMocIndex.from_array( + cell_ids, dim="cells", name="cell_ids", grid_info=grid_info + ) + + @pytest.mark.parametrize("dask", [False, pytest.param(True, marks=requires_dask)]) + @pytest.mark.parametrize( + ["level", "cell_ids"], + ( + ( + 1, + np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 22, 23, 24, 25, 43, 45, 46, 47], + dtype="uint64", + ), + ), + (4, np.arange(12 * 4**4, dtype="uint64")), + ), + ) + def test_from_variables(self, level, cell_ids, dask): + grid_info_mapping = { + "grid_name": "healpix", + "level": level, + "indexing_scheme": "nested", + } + variables = {"cell_ids": xr.Variable("cells", cell_ids, grid_info_mapping)} + if dask: + variables["cell_ids"] = variables["cell_ids"].chunk(4**level) + + actual = healpix.HealpixMocIndex.from_variables(variables, options={}) + + assert isinstance(actual, healpix.HealpixMocIndex) + assert actual.size == cell_ids.size + np.testing.assert_equal(actual._index.cell_ids(), cell_ids) + + @pytest.mark.parametrize( + "indexer", + ( + slice(None), + slice(None, 4**1), + slice(2 * 4**1, 7 * 4**1), + slice(7, 25), + np.array([-4, -3, -2], dtype="int64"), + np.array([12, 13, 14, 15, 16], dtype="uint64"), + np.array([1, 2, 3, 4, 5], dtype="uint32"), + ), + ) + @pytest.mark.parametrize( + "chunks", + [ + pytest.param(None, id="none"), + pytest.param((12, 12, 12, 12), marks=requires_dask, id="equally_sized"), + ], + ) + def test_isel(self, indexer, chunks): + from healpix_geo.nested import RangeMOCIndex + + grid_info = healpix.HealpixInfo(level=1, indexing_scheme="nested") + cell_ids = np.arange(12 * 4**grid_info.level, dtype="uint64") + if chunks is None: + input_chunks = None + expected_chunks = None + else: + import dask.array as da + + cell_ids_ = da.arange( + 12 * 4**grid_info.level, dtype="uint64", chunks=chunks + ) + input_chunks = cell_ids_.chunks[0] + expected_chunks = cell_ids_[indexer].chunks[0] + + index = healpix.HealpixMocIndex( + RangeMOCIndex.from_cell_ids(grid_info.level, cell_ids), + dim="cells", + name="cell_ids", + grid_info=grid_info, + chunksizes={"cells": input_chunks}, + ) + + actual = index.isel({"cells": indexer}) + expected = healpix.HealpixMocIndex( + RangeMOCIndex.from_cell_ids(grid_info.level, cell_ids[indexer]), + dim="cells", + name="cell_ids", + grid_info=grid_info, + chunksizes={"cells": expected_chunks}, + ) + + assert isinstance(actual, healpix.HealpixMocIndex) + assert actual.nbytes == expected.nbytes + assert actual.chunksizes == expected.chunksizes + np.testing.assert_equal(actual._index.cell_ids(), expected._index.cell_ids()) + + @pytest.mark.parametrize( + "chunks", + [ + pytest.param((12, 12, 12, 12), marks=requires_dask), + pytest.param((18, 10, 10, 10), marks=requires_dask), + pytest.param((8, 12, 14, 14), marks=requires_dask), + None, + ], + ) + def test_create_variables(self, chunks): + from healpix_geo.nested import RangeMOCIndex + + grid_info = healpix.HealpixInfo(level=1, indexing_scheme="nested") + cell_ids = np.arange(12 * 4**grid_info.level, dtype="uint64") + indexer = slice(3 * 4**grid_info.level, 7 * 4**grid_info.level) + index = healpix.HealpixMocIndex( + RangeMOCIndex.from_cell_ids(grid_info.level, cell_ids[indexer]), + dim="cells", + name="cell_ids", + grid_info=grid_info, + chunksizes={"cells": chunks}, + ) + + if chunks is not None: + variables = { + "cell_ids": xr.Variable("cells", cell_ids, grid_info.to_dict()).chunk( + {"cells": chunks} + ) + } + else: + variables = { + "cell_ids": xr.Variable("cells", cell_ids, grid_info.to_dict()) + } + + actual = index.create_variables(variables) + expected = {"cell_ids": variables["cell_ids"].isel(cells=indexer)} + + assert actual.keys() == expected.keys() + xr.testing.assert_equal(actual["cell_ids"], expected["cell_ids"]) + + def test_create_variables_new(self): + from healpix_geo.nested import RangeMOCIndex + + grid_info = healpix.HealpixInfo(level=1, indexing_scheme="nested") + cell_ids = np.arange(12 * 4**grid_info.level, dtype="uint64") + indexer = slice(3 * 4**grid_info.level, 7 * 4**grid_info.level) + index = healpix.HealpixMocIndex( + RangeMOCIndex.from_cell_ids(grid_info.level, cell_ids[indexer]), + dim="cells", + name="cell_ids", + grid_info=grid_info, + chunksizes={"cells": None}, + ) + actual = index.create_variables({}) + expected = {"cell_ids": xr.Variable("cells", cell_ids[indexer])} + + assert actual.keys() == expected.keys() + xr.testing.assert_equal(actual["cell_ids"], expected["cell_ids"])