Skip to content

Commit 9178a51

Browse files
authored
Merge pull request #138 from jrussell25/xarray-slicer
Xarray slicer
2 parents 81f477b + 888c5e0 commit 9178a51

File tree

7 files changed

+257
-19
lines changed

7 files changed

+257
-19
lines changed

docs/_static/images/hyperslicer4.gif

511 KB
Loading

examples/hyperslicer.ipynb

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
},
2020
{
2121
"cell_type": "code",
22-
"execution_count": 1,
22+
"execution_count": null,
2323
"metadata": {},
2424
"outputs": [],
2525
"source": [
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 2,
39+
"execution_count": null,
4040
"metadata": {},
4141
"outputs": [],
4242
"source": [
@@ -50,17 +50,9 @@
5050
},
5151
{
5252
"cell_type": "code",
53-
"execution_count": 3,
53+
"execution_count": null,
5454
"metadata": {},
55-
"outputs": [
56-
{
57-
"name": "stdout",
58-
"output_type": "stream",
59-
"text": [
60-
"(126, 512, 512)\n"
61-
]
62-
}
63-
],
55+
"outputs": [],
6456
"source": [
6557
"print(beads.shape) # (126, 512, 512)"
6658
]
@@ -241,6 +233,68 @@
241233
"controls8 = hyperslicer(beads4d, vmin=0, vmax=255, axes=((0, 1), \"wavenums\"))"
242234
]
243235
},
236+
{
237+
"cell_type": "markdown",
238+
"metadata": {},
239+
"source": [
240+
"### Hyperslicer with Xarray\n",
241+
"\n",
242+
"[Xarray](http://xarray.pydata.org/en/stable/index.html) is a library for having named dimensions on an array and hyperslicer supports them natively. So if you're going to go to the trouble of defining the `axes` argument you might think about just using xarray and doing it once per dataset and letting xarray keep track of them. Then hyperslicer will just access the information for you.\n",
243+
"\n",
244+
"Xarray also integrates with dask for lazy data loading so if your data is large this is a good way to process them and now you can selectively visualize these lazy arrays with hyperslicer. Here we will just demonstrate the basics with an in memory xarray but the out of memory case is similar albeit slower to render."
245+
]
246+
},
247+
{
248+
"cell_type": "code",
249+
"execution_count": null,
250+
"metadata": {},
251+
"outputs": [],
252+
"source": [
253+
"import xarray as xr"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": null,
259+
"metadata": {},
260+
"outputs": [],
261+
"source": [
262+
"# Define the coordinates for the xarray as a dict of name:array pairs\n",
263+
"# Intensity is arbiratrily made to be 0-1\n",
264+
"# Wns = Wns is relevant spectroscopic unit in cm^-1 as above\n",
265+
"# X,Y = actual dimensions of the images in microns from microscope metadata\n",
266+
"coords = {'linear':np.linspace(0,1,beads4d.shape[0]), 'wavenums':wns,\n",
267+
" 'X':np.linspace(0, 386.44, 512), 'Y':np.linspace(0, 386.44,512)}"
268+
]
269+
},
270+
{
271+
"cell_type": "code",
272+
"execution_count": null,
273+
"metadata": {},
274+
"outputs": [],
275+
"source": [
276+
"x_beads4d = xr.DataArray(beads4d, dims=coords.keys(),coords=coords)"
277+
]
278+
},
279+
{
280+
"cell_type": "code",
281+
"execution_count": null,
282+
"metadata": {
283+
"gif": "hyperslicer4.gif"
284+
},
285+
"outputs": [],
286+
"source": [
287+
"fig9, ax9 = plt.subplots()\n",
288+
"controls9 = hyperslicer(x_beads4d, vmin=0, vmax=255)"
289+
]
290+
},
291+
{
292+
"cell_type": "markdown",
293+
"metadata": {},
294+
"source": [
295+
"Hyperslicer also supports bare dask arrays with the same logic as numpy arrays."
296+
]
297+
},
244298
{
245299
"cell_type": "code",
246300
"execution_count": null,
@@ -265,7 +319,7 @@
265319
"name": "python",
266320
"nbconvert_exporter": "python",
267321
"pygments_lexer": "ipython3",
268-
"version": "3.7.8"
322+
"version": "3.8.6"
269323
}
270324
},
271325
"nbformat": 4,

mpl_interactions/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
version_info = (0, 11, 0)
1+
version_info = (0, 12, 0)
22
__version__ = ".".join(map(str, version_info))

mpl_interactions/generic.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .helpers import *
1616
from .utils import figure, ioff, nearest_idx
1717
from .controller import gogogo_controls
18+
from .xarray_helpers import get_hs_axes, get_hs_extent, get_hs_fmts
1819

1920
# functions that are methods
2021
__all__ = [
@@ -578,7 +579,14 @@ def hyperslicer(
578579
controls
579580
"""
580581

581-
arr = np.asarray(np.squeeze(arr))
582+
arr = np.squeeze(arr)
583+
584+
arr_type = "numpy"
585+
if "xarray.core.dataarray.DataArray" in str(arr.__class__):
586+
arr_type = "xarray"
587+
elif "dask.array.core.Array" in str(arr.__class__):
588+
arr_type = "dask"
589+
582590
if arr.ndim < 3 + is_color_image:
583591
raise ValueError(
584592
f"arr must be at least {3+is_color_image}D but it is {arr.ndim}D. mpl_interactions.imshow for 2D images."
@@ -599,11 +607,15 @@ def hyperslicer(
599607

600608
names = None
601609
axes = None
602-
if "names" in kwargs:
603-
names = kwargs.pop("names")
610+
if arr_type != "xarray":
611+
if "names" in kwargs:
612+
names = kwargs.pop("names")
613+
614+
elif "axes" in kwargs:
615+
axes = kwargs.pop("axes")
604616

605-
elif "axes" in kwargs:
606-
axes = kwargs.pop("axes")
617+
else:
618+
axes = get_hs_axes(arr, is_color_image=is_color_image)
607619

608620
# Just pass in an array - no kwargs
609621
for i in range(arr.ndim - im_dims):
@@ -662,6 +674,13 @@ def hyperslicer(
662674
slider_format_strings[name] = "{:.0f}"
663675
kwargs[name] = np.arange(arr.shape[i])
664676

677+
if arr_type == "xarray":
678+
slider_format_strings = get_hs_fmts(arr, is_color_image=is_color_image)
679+
extent = get_hs_extent(arr, is_color_image=is_color_image)
680+
else:
681+
if "extent" not in kwargs:
682+
extent = None
683+
665684
controls, params = gogogo_controls(
666685
kwargs, controls, display_controls, slider_format_strings, play_buttons, allow_dupes=True
667686
)

mpl_interactions/helpers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"gogogo_display",
4141
"create_mpl_controls_fig",
4242
"eval_xy",
43+
"choose_fmt_str",
4344
]
4445

4546

@@ -627,3 +628,31 @@ def gogogo_display(ipympl, use_ipywidgets, display, controls, fig):
627628
fig.show()
628629
controls[0].show()
629630
return controls
631+
632+
633+
def choose_fmt_str(dtype=None):
634+
"""
635+
Choose the appropriate string formatting for different dtypes.
636+
637+
Paramters
638+
---------
639+
640+
dtype : np.dtye
641+
dtype of array containing values to be formatted.
642+
643+
Returns
644+
-------
645+
646+
fmt : str
647+
Format string
648+
"""
649+
if np.issubdtype(dtype, "float"):
650+
fmt = r"{:0.2f}"
651+
652+
elif np.issubdtype(dtype, "int"):
653+
fmt = r"{:d}"
654+
655+
else:
656+
fmt = r"{:}"
657+
658+
return fmt

mpl_interactions/xarray_helpers.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import numpy as np
2+
from .helpers import choose_fmt_str
3+
4+
5+
def choose_datetime_nonsense(arr, timeunit="m"):
6+
"""
7+
Try to do something reasonable to datetimes and timedeltas.
8+
9+
Parameters
10+
----------
11+
12+
arr : np.array
13+
Array with values to be formatted.
14+
15+
Returns
16+
-------
17+
18+
out : np.array
19+
Array modified to format decently in a slider.
20+
21+
"""
22+
23+
if np.issubdtype(arr.dtype, "datetime64"):
24+
# print('datetime')
25+
out = arr.astype(f"datetime64[{timeunit}]")
26+
elif np.issubdtype(arr.dtype, "timedelta64"):
27+
out = arr.astype(f"timedelta64[{timeunit}]").astype(int)
28+
else:
29+
out = arr
30+
return out
31+
32+
33+
def get_hs_axes(xarr, is_color_image=False, timeunit="m"):
34+
"""
35+
Read the dims and coordinates from an xarray and construct the
36+
axes argument for hyperslicer. Called internally by hyperslicer.
37+
38+
Parameters
39+
----------
40+
41+
xarr : xarray.DataArray
42+
DataArray being viewed with hyperslicer
43+
44+
is_color_image : bool, default False
45+
Whether the individual images of the hyperstack are color images.
46+
47+
timeunit : str, default "m"
48+
Truncation level for datetime and timedelta axes.
49+
50+
Returns
51+
-------
52+
axes : list
53+
axes kwarg for hyperslicer
54+
55+
"""
56+
if not is_color_image:
57+
dims = xarr.dims[:-2]
58+
else:
59+
dims = xarr.dims[:-3]
60+
coords_list = [choose_datetime_nonsense(xarr.coords[d].values, timeunit=timeunit) for d in dims]
61+
# print(coords_list)
62+
axes = zip(dims, coords_list)
63+
return list(axes)
64+
65+
66+
def get_hs_extent(xarr, is_color_image=False):
67+
"""
68+
Read the "XY" coordinates of an xarray.DataArray to set extent of image for
69+
imshow.
70+
71+
Parameters
72+
----------
73+
74+
xarr : xarray.DataArray
75+
DataArray being viewed with hyperslicer
76+
77+
is_color_image : bool, default False
78+
Whether the individual images of the hyperstack are color images.
79+
80+
Returns
81+
-------
82+
extent : list
83+
Extent argument for imshow. [d0_min, d0_max, d1_min, d1_max]
84+
85+
"""
86+
87+
if not is_color_image:
88+
dims = xarr.dims[-2:]
89+
else:
90+
dims = xarr.dims[-3:-1]
91+
extent = []
92+
for d in dims:
93+
vals = xarr[d].values
94+
extent.append(vals.min())
95+
extent.append(vals.max())
96+
return extent
97+
98+
99+
def get_hs_fmts(xarr, units=None, is_color_image=False):
100+
"""
101+
Get appropriate slider format strings from xarray coordinates
102+
based the dtype of corresponding values.
103+
104+
Parameters
105+
----------
106+
107+
xarr : xarray.DataArray
108+
DataArray being viewed with hyperslicer
109+
110+
units : array-like
111+
Units to append to end of slider value. Must have the same length
112+
as number of non-image dimensions in xarray.
113+
114+
is_color_image : bool, default False
115+
Whether the individual images of the hyperstack are color images.
116+
117+
Returns
118+
-------
119+
fmt_strings : dict
120+
Slider format strings for hyperslicer (or other mpl-interactions?)
121+
"""
122+
if not is_color_image:
123+
dims = xarr.dims[:-2]
124+
else:
125+
dims = xarr.dims[:-3]
126+
fmt_strs = {}
127+
for i, d in enumerate(dims):
128+
fmt_strs[d] = choose_fmt_str(xarr[d].dtype)
129+
if units is not None and units[i] is not None:
130+
try:
131+
fmt_strs[d] += " {}".format(units[i])
132+
except:
133+
continue
134+
return fmt_strs

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"jupyter-sphinx",
6262
"sphinx-copybutton",
6363
"sphinx-autobuild",
64+
"xarray",
6465
],
6566
"test": [
6667
"pytest",
@@ -71,6 +72,7 @@
7172
"pandas",
7273
"requests",
7374
"scipy",
75+
"xarray",
7476
],
7577
"dev": [
7678
"pre-commit",

0 commit comments

Comments
 (0)