Skip to content

Multidimensional histogram #5400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 79 additions & 54 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})


try:
import dask.array as dsa

has_dask = True
except ImportError:
has_dask = False


def _any_dask_array(*args):
    """Report whether at least one positional argument is a dask array.

    Always returns False when the module-level ``has_dask`` flag is False,
    i.e. when ``dask.array`` could not be imported.
    """
    if not has_dask:
        return False
    for candidate in args:
        if isinstance(candidate, dsa.core.Array):
            return True
    return False


def _first_of_type(args, kind):
"""Return either first object of type 'kind' or raise if not found."""
for arg in args:
Expand Down Expand Up @@ -1743,25 +1758,30 @@ def hist(
dim : str or tuple of strings, optional
Dimensions over which the histogram is computed. The default is to
compute the histogram of the flattened array. i.e. over all dimensions.
bins : int or array_like or a list of ints or arrays, or list of DataArrays, optional
bins : int, str, numpy array or DataArray, or a list of ints, strs, arrays and/or DataArrays, optional
If a list, there should be one entry for each item in ``args``.
The bin specification:

* If int, the number of bins for all arguments in ``args``.
* If array_like, the bin edges for all arguments in ``args``.
* If a list of ints, the number of bins for every argument in ``args``.
* If a list of arrays, the bin edges for each argument in ``args``
(required format for Dask inputs).
* If a list of DataArrays, the bins for each argument in ``args``
The DataArrays can be multidimensional, but must contain X
and must not have any dimensions shared with the `dim` argument.
If supplied these will be present as coordinates on the output.
The bin specifications for each entry:

* If int, the number of bins.
* If str, the method used to automatically calculate the optimal bin
width, as defined by `np.histogram_bin_edges`.
* If a numpy array, the bin edges. Must be 1D.
* If a DataArray, the bin edges. The DataArray can be multidimensional,
but must contain the output bins dimension (named as `[var]_bins`
for a given input variable named `var`), and must not have any
dimensions shared with the `dim` argument. If supplied this DataArray
will be present as a coordinate on the output.
* If a list of ints, strs, arrays and/or DataArrays; the bin specification
as above for every argument in ``args``.
* If not supplied (or any elements of the list are `None`) then bins
will be automatically calculated by `np.histogram_bin_edges`.

When bin edges are specified, all but the last (righthand-most) bin include
the left edge and exclude the right edge. The last bin includes both edges.

A ``TypeError`` will be raised if ``args`` contains dask arrays and
``bins`` are not specified explicitly as a list of arrays.
A ``TypeError`` will also be raised if ``args`` contains dask arrays and
``bins`` are not specified explicitly via arrays or DataArrays, because
other bin specifications trigger loading of the entire input data.
weights : array_like, optional
An array of weights, of the same shape as `a`. Each value in
`a` only contributes its associated weight towards the bin count
Expand Down Expand Up @@ -1790,7 +1810,7 @@ def hist(

The returned dataarray will have one additional coordinate for each
dataarray supplied, named as `[var]_bins`, which contains the positions
of the centres of each bin.
of the centres of each bin, varying along a new dimension of the same name.

All other coordinates will be retained, unless they depend on a dimension
which has been reduced along, in which case they will be dropped.
Expand Down Expand Up @@ -1842,55 +1862,60 @@ def hist(
new_bin_dims = [da.name + "_bins" for da in dataarrays]
output_dims = [broadcast_dims] + new_bin_dims

# TODO deal with arrays or lists of arrays
# TODO just check if array-like after explicitly checking if other options
# TODO might need to check if numpy arrays explicitly
# Check validity of given bins, or create using np.histogram_bin_edges
def _create_bin_coords(dataarrays, bins, range, new_bin_dims):
    """Build one bin-edge coordinate per input via ``np.histogram_bin_edges``.

    The four arguments are iterated in lockstep; each resulting coordinate
    is named after (and laid out along) its entry in ``new_bin_dims`` and
    carries the ``attrs`` of the matching input.
    """
    coords = []
    for arr, spec, rng, dim in zip(dataarrays, bins, range, new_bin_dims):
        edges = np.histogram_bin_edges(arr, spec, rng)
        coords.append(DataArray(edges, dims=dim, name=dim, attrs=arr.attrs))
    return coords
def _check_and_format_bins_into_coords(b, da, r, bin_dim):
# Check validity of given bins, or create using np.histogram_bin_edges
# Package into a coordinate DataArray before returning
_edges = np.histogram_bin_edges
if isinstance(b, (int, str)) or b is None:
if is_duck_dask_array(da.data):
raise TypeError(f"Choice of bins as {b} would trigger loading "
f"of entire input array to histogram")
b = "auto" if b is None else b
return DataArray(_edges(da.values, b, r), dims=bin_dim, name=bin_dim,
attrs=da.attrs)
elif isinstance(b, np.ndarray):
if b.ndim > 1:
raise ValueError("bins specified as numpy arrays can only be 1-dimensional")
return DataArray(b, dims=bin_dim, name=bin_dim, attrs=da.attrs)
elif isinstance(b, DataArray):
print(b.dims)
print(broadcast_dims)
if bin_dim not in b.dims:
raise ValueError(
"A bins dataarray does not contain the "
"corresponding output bins dimension - "
f"has dims {b.dims} but not {bin_dim}."
)
if not (set(b.dims) - set([bin_dim])).issubset(broadcast_dims):
raise ValueError(
"A bins dataarray has dimensions present that "
"will not be broadcast on the output: "
f"{da.dims} vs {tuple(*broadcast_dims)}"
)
return b
else:
raise TypeError(f"Type {type(b)} is not a valid argument to bins")

# TODO change to allow mixed types of bins arguments
if isinstance(bins, int):
bins = _create_bin_coords(dataarrays, [bins] * n_args, range, new_bin_dims)
elif bins is None:
bins = _create_bin_coords(dataarrays, ["auto"] * n_args, range, new_bin_dims)
elif isinstance(bins, Iterable):
# TODO check ranges are of valid type
ranges = [None] * n_args
if isinstance(bins, list):
if len(bins) != n_args:
raise TypeError(
"If bins is an Iterable then it must have same length "
"If `bins` is a list then it must have same length "
"as number of input dataarrays passed, but instead has "
f"length {len(bins)}"
)
if all(isinstance(obj, DataArray) for obj in bins):
for da, new_bin_dim in zip(bins, new_bin_dims):
if not set(da.dims).issubset(broadcast_dims):
raise ValueError(
"A bins dataarray has dimensions present that "
"will not be broadcast on the output: "
f"{da.dims} vs {tuple(*broadcast_dims)}"
)
if new_bin_dim not in da.dims:
raise ValueError(
"A bins dataarray does not contain the "
"corresponding output bins dimension - "
f"has dims {da.dims} but not {new_bin_dim}."
)
elif all(isinstance(b, (int, str)) for b in bins):
bins = _create_bin_coords(dataarrays, bins, range, new_bin_dims)
else:
TypeError(
"One or more of the elements in bins is not a valid argument."
f"Instead found types {[type(b) for b in bins]}"
f"length {len(bins)}. To manually specify bin edges for "
"a single input pass them as a numpy array instead."
)
else:
raise TypeError(f"Type {type(bins)} is not a valid argument to bins")
bins = [bins] * n_args
bins = [_check_and_format_bins_into_coords(b, da, r, d) for b, da, r, d in
zip(bins, dataarrays, ranges, new_bin_dims)]

# Align / broadcast all inputs (including weights and bins)
# TODO if this was merely alignment could blockwise handle all the broadcasting?
arrs = broadcast(dataarrays)
weights = weights.broadcast_like(arrs[0])
# TODO surround with try except?
aligned_bins = [b.broadcast_like(arrs[0]) for b in bins]
# TODO bins now already has the output dims included, is that correct?

Expand Down
11 changes: 10 additions & 1 deletion xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,7 +1930,7 @@ def test_polyval(use_dask, use_datetime):
xr.testing.assert_allclose(da, da_pv.T)


class TestHistInputTypes:
class TestHistInputTypeChecks:
def test_invalid_data(self):
with pytest.raises(TypeError, match="Only xr.DataArray is supported"):
hist("string")
Expand All @@ -1951,6 +1951,10 @@ def test_wrong_number_of_bins(self):
with pytest.raises(TypeError, match="must have same length"):
hist(xr.DataArray([], name="a"), bins=[2, 3])

def test_non_1d_numpy_array(self):
    """A numpy array with more than one dimension is rejected as ``bins``."""
    two_dim_bins = np.array([[2, 3]])
    with pytest.raises(ValueError, match="can only be 1-d"):
        hist(xr.DataArray([], name="a"), bins=two_dim_bins)

def test_invalid_bins(self):
with pytest.raises(TypeError, match="not a valid argument"):
hist(xr.DataArray([], name="a"), bins=2.7)
Expand All @@ -1966,3 +1970,8 @@ def test_bin_dataarrays_without_reduce_dim(self):
bins = xr.DataArray(1)
with pytest.raises(ValueError, match="does not contain"):
hist(data, dim="x", bins=[bins])

def test_prevent_trigger_loading_dask(self):
    """Integer ``bins`` with chunked input must raise instead of loading."""
    # ``.chunk`` requires dask; skip rather than error where it is absent.
    pytest.importorskip("dask")
    with pytest.raises(TypeError, match="would trigger loading"):
        hist(xr.DataArray([0], name="a").chunk(1), bins=2)