Skip to content

Commit 5c12e17

Browse files
authored
Merge pull request #44 from dougiesquire/refactor_bin_alignment_with_np
Refactor approach for making final bin right-edge inclusive
2 parents 3e7e1c1 + 4fdd716 commit 5c12e17

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

xhistogram/core.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from functools import reduce
99
from collections.abc import Iterable
1010
from numpy import (
11-
digitize,
11+
searchsorted,
1212
bincount,
1313
reshape,
1414
ravel_multi_index,
@@ -154,24 +154,26 @@ def _bincount_2d_vectorized(
154154
nbins = [len(b) for b in bins]
155155
hist_shapes = [nb + 1 for nb in nbins]
156156

157-
# a marginally faster implementation would be to use searchsorted,
158-
# like numpy histogram itself does
159-
# https://github.com/numpy/numpy/blob/9c98662ee2f7daca3f9fae9d5144a9a8d3cabe8c/numpy/lib/histograms.py#L864-L882
160-
# for now we stick with `digitize` because it's easy to understand how it works
161-
162-
# Add small increment to the last bin edge to make the final bin right-edge inclusive
163-
# Note, this is the approach taken by sklearn, e.g.
164-
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/calibration.py#L592
165-
# but a better approach would be to use something like _search_sorted_inclusive() in
166-
# numpy histogram. This is an additional motivation for moving to searchsorted
167-
bins = [np.concatenate((b[:-1], b[-1:] + 1e-8)) for b in bins]
168-
169-
# the maximum possible value of of digitize is nbins
170-
# for right=False:
157+
# The maximum possible value of searchsorted is nbins
158+
# For _searchsorted_inclusive:
171159
# - 0 corresponds to a < b[0]
172-
# - i corresponds to bins[i-1] <= a < b[i]
173-
# - nbins corresponds to a a >= b[1]
174-
each_bin_indices = [digitize(a, b) for a, b in zip(args, bins)]
160+
# - i corresponds to b[i-1] <= a < b[i]
161+
# - nbins-1 corresponds to b[-2] <= a <= b[-1]
162+
# - nbins corresponds to a > b[-1]
163+
def _searchsorted_inclusive(a, b):
164+
"""
165+
Like `searchsorted`, but where the last bin is also right-edge inclusive.
166+
"""
167+
# Similar to implementation in np.histogramdd
168+
# see https://github.com/numpy/numpy/blob/9c98662ee2f7daca3f9fae9d5144a9a8d3cabe8c/numpy/lib/histograms.py#L1056
169+
# This assumes the bins (b) are sorted
170+
bin_indices = searchsorted(b, a, side="right")
171+
on_edge = a == b[-1]
172+
# Shift these points one bin to the left.
173+
bin_indices[on_edge] -= 1
174+
return bin_indices
175+
176+
each_bin_indices = [_searchsorted_inclusive(a, b) for a, b in zip(args, bins)]
175177
# product of the bins gives the joint distribution
176178
if N_inputs > 1:
177179
bin_indices = ravel_multi_index(each_bin_indices, hist_shapes)
@@ -327,7 +329,7 @@ def histogram(
327329
328330
See Also
329331
--------
330-
numpy.histogram, numpy.bincount, numpy.digitize
332+
numpy.histogram, numpy.bincount, numpy.searchsorted
331333
"""
332334

333335
a0 = args[0]

xhistogram/test/test_core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23

34
from itertools import combinations
45
import dask.array as dsa
@@ -328,3 +329,23 @@ def test_ensure_correctly_formatted_range(in_out):
328329
else:
329330
with pytest.raises(ValueError):
330331
_ensure_correctly_formatted_range(range_in, n)
332+
333+
334+
@pytest.mark.parametrize("block_size", [None, 1, 2])
@pytest.mark.parametrize("use_dask", [False, True])
def test_histogram_results_datetime(use_dask, block_size):
    """Check that histogram counts of datetime data agree with np.histogram."""
    edge_dates = ("1999-01-01", "2000-01-01", "2001-01-01")
    bins = np.array([np.datetime64(day) for day in edge_dates])

    # Five consecutive days starting 2000-06-01: every sample lies between
    # the second and third edges, i.e. in the second bin (index 1).
    data = pd.date_range(start="2000-06-01", periods=5)
    if use_dask:
        data = dsa.asarray(data, chunks=(5,))

    expected = np.histogram(data, bins=bins)[0]
    h = histogram(data, bins=bins, block_size=block_size)[0]
    np.testing.assert_allclose(h, expected)

0 commit comments

Comments
 (0)