Skip to content

Commit 1818c7b

Browse files
committed
Convert _path.is_sorted_and_has_non_nan to pybind11
1 parent bab748c commit 1818c7b

File tree

2 files changed

+19
-36
lines changed

2 files changed

+19
-36
lines changed

src/_path.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,17 +1223,15 @@ bool convert_to_string(PathIterator &path,
12231223
}
12241224

12251225
template<class T>
1226-
bool is_sorted_and_has_non_nan(PyArrayObject *array)
1226+
bool is_sorted_and_has_non_nan(py::array_t<T> array)
12271227
{
1228-
char* ptr = PyArray_BYTES(array);
1229-
npy_intp size = PyArray_DIM(array, 0),
1230-
stride = PyArray_STRIDE(array, 0);
1228+
auto size = array.shape(0);
12311229
using limits = std::numeric_limits<T>;
12321230
T last = limits::has_infinity ? -limits::infinity() : limits::min();
12331231
bool found_non_nan = false;
12341232

1235-
for (npy_intp i = 0; i < size; ++i, ptr += stride) {
1236-
T current = *(T*)ptr;
1233+
for (auto i = 0; i < size; ++i) {
1234+
T current = *array.data(i);
12371235
// The following tests !isnan(current), but also works for integral
12381236
// types. (The isnan(IntegralType) overload is absent on MSVC.)
12391237
if (current == current) {

src/_path_wrapper.cpp

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -350,41 +350,26 @@ Py_is_sorted_and_has_non_nan(py::object obj)
350350
{
351351
bool result;
352352

353-
PyArrayObject *array = (PyArrayObject *)PyArray_CheckFromAny(
354-
obj.ptr(), NULL, 1, 1, NPY_ARRAY_NOTSWAPPED, NULL);
355-
356-
if (array == NULL) {
357-
throw py::error_already_set();
353+
py::array array = py::array::ensure(obj);
354+
if (array.ndim() != 1) {
355+
throw std::invalid_argument("array must be 1D");
358356
}
359357

358+
auto dtype = array.dtype();
360359
/* Handle just the most common types here, otherwise coerce to double */
361-
switch (PyArray_TYPE(array)) {
362-
case NPY_INT:
363-
result = is_sorted_and_has_non_nan<npy_int>(array);
364-
break;
365-
case NPY_LONG:
366-
result = is_sorted_and_has_non_nan<npy_long>(array);
367-
break;
368-
case NPY_LONGLONG:
369-
result = is_sorted_and_has_non_nan<npy_longlong>(array);
370-
break;
371-
case NPY_FLOAT:
372-
result = is_sorted_and_has_non_nan<npy_float>(array);
373-
break;
374-
case NPY_DOUBLE:
375-
result = is_sorted_and_has_non_nan<npy_double>(array);
376-
break;
377-
default:
378-
Py_DECREF(array);
379-
array = (PyArrayObject *)PyArray_FromObject(obj.ptr(), NPY_DOUBLE, 1, 1);
380-
if (array == NULL) {
381-
throw py::error_already_set();
382-
}
383-
result = is_sorted_and_has_non_nan<npy_double>(array);
360+
if (dtype.equal(py::dtype::of<std::int32_t>())) {
361+
result = is_sorted_and_has_non_nan<int32_t>(array);
362+
} else if (dtype.equal(py::dtype::of<std::int64_t>())) {
363+
result = is_sorted_and_has_non_nan<int64_t>(array);
364+
} else if (dtype.equal(py::dtype::of<float>())) {
365+
result = is_sorted_and_has_non_nan<float>(array);
366+
} else if (dtype.equal(py::dtype::of<double>())) {
367+
result = is_sorted_and_has_non_nan<double>(array);
368+
} else {
369+
array = py::array_t<double>::ensure(obj);
370+
result = is_sorted_and_has_non_nan<double>(array);
384371
}
385372

386-
Py_DECREF(array);
387-
388373
return result;
389374
}
390375

0 commit comments

Comments
 (0)