Skip to content

Commit 11d24f8

Browse files
authored
Add an Xarray dataset accessor for Cubed with a visualize method (#22)
* Add an Xarray dataset accessor for Cubed with a visualize method * Update test dependencies * Install graphviz for CI
1 parent 848b351 commit 11d24f8

File tree

4 files changed

+41
-0
lines changed

4 files changed

+41
-0
lines changed

.github/workflows/main.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ jobs:
3434
with:
3535
python-version: ${{ matrix.python-version }}
3636

37+
- name: Setup Graphviz
38+
uses: ts-graphviz/setup-graphviz@v2
39+
3740
- name: Install uv
3841
run: |
3942
curl -LsSf https://astral.sh/uv/install.sh | sh

cubed_xarray/cubedmanager.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Callable, Iterable, Union
55

66
import numpy as np
7+
import xarray as xr
78
from tlz import partition
89
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
910

@@ -227,3 +228,28 @@ def store(
227228
targets,
228229
**kwargs,
229230
)
231+
232+
233+
@xr.register_dataset_accessor("cubed")
234+
class DatasetAccessor:
235+
def __init__(self, ds):
236+
self.ds = ds
237+
238+
def visualize(
239+
self,
240+
filename="cubed",
241+
format=None,
242+
optimize_graph=True,
243+
optimize_function=None,
244+
show_hidden=False,
245+
):
246+
import cubed
247+
248+
cubed.visualize(
249+
*(self.ds[var].data for var in self.ds.data_vars.keys()),
250+
filename=filename,
251+
format=format,
252+
optimize_graph=optimize_graph,
253+
optimize_function=optimize_function,
254+
show_hidden=show_hidden,
255+
)

cubed_xarray/tests/test_wrapping.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,14 @@ def test_to_zarr(tmpdir, executor):
6161
assert isinstance(restored.var1.data, cubed.Array)
6262
computed = restored.compute()
6363
assert_allclose(original, computed)
64+
65+
66+
def test_dataset_accessor_visualize(tmp_path):
67+
spec = cubed.Spec(allowed_mem="200MB")
68+
69+
ds = create_test_data().chunk(
70+
chunked_array_type="cubed", from_array_kwargs={"spec": spec}
71+
)
72+
assert not (tmp_path / "cubed.svg").exists()
73+
ds.cubed.visualize(filename=tmp_path / "cubed")
74+
assert (tmp_path / "cubed.svg").exists()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030

3131
[project.optional-dependencies]
3232
test = [
33+
"cubed[diagnostics]",
3334
"dill",
3435
"pre-commit",
3536
"ruff",

0 commit comments

Comments
 (0)