|
8 | 8 | from functools import reduce
|
9 | 9 | from collections.abc import Iterable
|
10 | 10 | from numpy import (
|
11 |
| - digitize, |
| 11 | + searchsorted, |
12 | 12 | bincount,
|
13 | 13 | reshape,
|
14 | 14 | ravel_multi_index,
|
@@ -154,24 +154,26 @@ def _bincount_2d_vectorized(
|
154 | 154 | nbins = [len(b) for b in bins]
|
155 | 155 | hist_shapes = [nb + 1 for nb in nbins]
|
156 | 156 |
|
157 |
| - # a marginally faster implementation would be to use searchsorted, |
158 |
| - # like numpy histogram itself does |
159 |
| - # https://github.com/numpy/numpy/blob/9c98662ee2f7daca3f9fae9d5144a9a8d3cabe8c/numpy/lib/histograms.py#L864-L882 |
160 |
| - # for now we stick with `digitize` because it's easy to understand how it works |
161 |
| - |
162 |
| - # Add small increment to the last bin edge to make the final bin right-edge inclusive |
163 |
| - # Note, this is the approach taken by sklearn, e.g. |
164 |
| - # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/calibration.py#L592 |
165 |
| - # but a better approach would be to use something like _search_sorted_inclusive() in |
166 |
| - # numpy histogram. This is an additional motivation for moving to searchsorted |
167 |
| - bins = [np.concatenate((b[:-1], b[-1:] + 1e-8)) for b in bins] |
168 |
| - |
169 |
| - # the maximum possible value of of digitize is nbins |
170 |
| - # for right=False: |
| 157 | + # The maximum possible value returned by searchsorted is nbins |
| 158 | + # For _searchsorted_inclusive: |
171 | 159 | # - 0 corresponds to a < b[0]
|
172 |
| - # - i corresponds to bins[i-1] <= a < b[i] |
173 |
| - # - nbins corresponds to a a >= b[1] |
174 |
| - each_bin_indices = [digitize(a, b) for a, b in zip(args, bins)] |
| 160 | + # - i corresponds to b[i-1] <= a < b[i] |
| 161 | + # - nbins-1 corresponds to b[-2] <= a <= b[-1] |
| 162 | + # - nbins corresponds to a > b[-1] |
| 163 | + def _searchsorted_inclusive(a, b): |
| 164 | + """ |
| 165 | + Like `searchsorted`, but where the last bin is also right-edge inclusive. |
| 166 | + """ |
| 167 | + # Similar to implementation in np.histogramdd |
| 168 | + # see https://github.com/numpy/numpy/blob/9c98662ee2f7daca3f9fae9d5144a9a8d3cabe8c/numpy/lib/histograms.py#L1056 |
| 169 | + # This assumes the bins (b) are sorted |
| 170 | + bin_indices = searchsorted(b, a, side="right") |
| 171 | + on_edge = a == b[-1] |
| 172 | + # Shift these points one bin to the left. |
| 173 | + bin_indices[on_edge] -= 1 |
| 174 | + return bin_indices |
| 175 | + |
| 176 | + each_bin_indices = [_searchsorted_inclusive(a, b) for a, b in zip(args, bins)] |
175 | 177 | # product of the bins gives the joint distribution
|
176 | 178 | if N_inputs > 1:
|
177 | 179 | bin_indices = ravel_multi_index(each_bin_indices, hist_shapes)
|
@@ -327,7 +329,7 @@ def histogram(
|
327 | 329 |
|
328 | 330 | See Also
|
329 | 331 | --------
|
330 |
| - numpy.histogram, numpy.bincount, numpy.digitize |
| 332 | + numpy.histogram, numpy.bincount, numpy.searchsorted |
331 | 333 | """
|
332 | 334 |
|
333 | 335 | a0 = args[0]
|
|
0 commit comments