Skip to content

Commit b43387c

Browse files
committed
Merge branch 'main' into dask-new
[skip ci]
2 parents c4bd40b + 7b21d32 commit b43387c

File tree

3 files changed

+102
-51
lines changed

3 files changed

+102
-51
lines changed

scipy/_lib/_uarray/_uarray_dispatch.cxx

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,32 @@
1515

1616
namespace {
1717

18+
template <typename T>
19+
class immortal {
20+
alignas(T) std::byte storage[sizeof(T)];
21+
22+
public:
23+
template <typename... Args>
24+
immortal(Args&&... args) {
25+
// Construct new T in storage
26+
new(&storage) T(std::forward<Args>(args)...);
27+
}
28+
~immortal() {
29+
// Intentionally don't call destructor
30+
}
31+
32+
T* get() { return reinterpret_cast<T*>(&storage); }
33+
const T* get() const { return reinterpret_cast<const T*>(&storage); }
34+
const T* get_const() const { return reinterpret_cast<const T*>(&storage); }
35+
36+
const T* operator ->() const { return get(); }
37+
T* operator ->() { return get(); }
38+
39+
T& operator*() { return *get(); }
40+
const T& operator*() const { return *get(); }
41+
42+
};
43+
1844
/** Handle to a python object that automatically DECREFs */
1945
class py_ref {
2046
explicit py_ref(PyObject * object): obj_(object) {}
@@ -129,8 +155,8 @@ using global_state_t = std::unordered_map<std::string, global_backends>;
129155
using local_state_t = std::unordered_map<std::string, local_backends>;
130156

131157
static py_ref BackendNotImplementedError;
132-
static global_state_t global_domain_map;
133-
thread_local global_state_t * current_global_state = &global_domain_map;
158+
static immortal<global_state_t> global_domain_map;
159+
thread_local global_state_t * current_global_state = global_domain_map.get();
134160
thread_local global_state_t thread_local_domain_map;
135161
thread_local local_state_t local_domain_map;
136162

@@ -140,30 +166,30 @@ Using these with PyObject_GetAttr is faster than PyObject_GetAttrString which
140166
has to create a new python string internally.
141167
*/
142168
struct {
143-
py_ref ua_convert;
144-
py_ref ua_domain;
145-
py_ref ua_function;
169+
immortal<py_ref> ua_convert;
170+
immortal<py_ref> ua_domain;
171+
immortal<py_ref> ua_function;
146172

147173
bool init() {
148-
ua_convert = py_ref::steal(PyUnicode_InternFromString("__ua_convert__"));
149-
if (!ua_convert)
174+
*ua_convert = py_ref::steal(PyUnicode_InternFromString("__ua_convert__"));
175+
if (!*ua_convert)
150176
return false;
151177

152-
ua_domain = py_ref::steal(PyUnicode_InternFromString("__ua_domain__"));
153-
if (!ua_domain)
178+
*ua_domain = py_ref::steal(PyUnicode_InternFromString("__ua_domain__"));
179+
if (!*ua_domain)
154180
return false;
155181

156-
ua_function = py_ref::steal(PyUnicode_InternFromString("__ua_function__"));
157-
if (!ua_function)
182+
*ua_function = py_ref::steal(PyUnicode_InternFromString("__ua_function__"));
183+
if (!*ua_function)
158184
return false;
159185

160186
return true;
161187
}
162188

163189
void clear() {
164-
ua_convert.reset();
165-
ua_domain.reset();
166-
ua_function.reset();
190+
ua_convert->reset();
191+
ua_domain->reset();
192+
ua_function->reset();
167193
}
168194
} identifiers;
169195

@@ -202,7 +228,7 @@ std::string domain_to_string(PyObject * domain) {
202228

203229
Py_ssize_t backend_get_num_domains(PyObject * backend) {
204230
auto domain =
205-
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain.get()));
231+
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain->get()));
206232
if (!domain)
207233
return -1;
208234

@@ -225,7 +251,7 @@ enum class LoopReturn { Continue, Break, Error };
225251
template <typename Func>
226252
LoopReturn backend_for_each_domain(PyObject * backend, Func f) {
227253
auto domain =
228-
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain.get()));
254+
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain->get()));
229255
if (!domain)
230256
return LoopReturn::Error;
231257

@@ -537,7 +563,7 @@ struct BackendState {
537563

538564
/** Clean up global python references when the module is finalized. */
539565
void globals_free(void * /* self */) {
540-
global_domain_map.clear();
566+
global_domain_map->clear();
541567
BackendNotImplementedError.reset();
542568
identifiers.clear();
543569
}
@@ -550,7 +576,7 @@ void globals_free(void * /* self */) {
550576
* cleanup.
551577
*/
552578
int globals_traverse(PyObject * self, visitproc visit, void * arg) {
553-
for (const auto & kv : global_domain_map) {
579+
for (const auto & kv : *global_domain_map) {
554580
const auto & globals = kv.second;
555581
PyObject * backend = globals.global.backend.get();
556582
Py_VISIT(backend);
@@ -563,7 +589,7 @@ int globals_traverse(PyObject * self, visitproc visit, void * arg) {
563589
}
564590

565591
int globals_clear(PyObject * /* self */) {
566-
global_domain_map.clear();
592+
global_domain_map->clear();
567593
return 0;
568594
}
569595

@@ -1170,7 +1196,7 @@ py_ref Function::canonicalize_kwargs(PyObject * kwargs) {
11701196

11711197
py_func_args Function::replace_dispatchables(
11721198
PyObject * backend, PyObject * args, PyObject * kwargs, PyObject * coerce) {
1173-
auto has_ua_convert = PyObject_HasAttr(backend, identifiers.ua_convert.get());
1199+
auto has_ua_convert = PyObject_HasAttr(backend, identifiers.ua_convert->get());
11741200
if (!has_ua_convert) {
11751201
return {py_ref::ref(args), py_ref::ref(kwargs)};
11761202
}
@@ -1182,7 +1208,7 @@ py_func_args Function::replace_dispatchables(
11821208

11831209
PyObject * convert_args[] = {backend, dispatchables.get(), coerce};
11841210
auto res = py_ref::steal(Q_PyObject_VectorcallMethod(
1185-
identifiers.ua_convert.get(), convert_args,
1211+
identifiers.ua_convert->get(), convert_args,
11861212
array_size(convert_args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET, nullptr));
11871213
if (!res) {
11881214
return {};
@@ -1287,7 +1313,7 @@ PyObject * Function::call(PyObject * args_, PyObject * kwargs_) {
12871313
backend, reinterpret_cast<PyObject *>(this), new_args.args.get(),
12881314
new_args.kwargs.get()};
12891315
result = py_ref::steal(Q_PyObject_VectorcallMethod(
1290-
identifiers.ua_function.get(), args,
1316+
identifiers.ua_function->get(), args,
12911317
array_size(args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET, nullptr));
12921318

12931319
// raise BackendNotImplemeted is equivalent to return NotImplemented
@@ -1499,7 +1525,7 @@ PyObject * get_state(PyObject * /* self */, PyObject * /* args */) {
14991525

15001526
output->locals = local_domain_map;
15011527
output->use_thread_local_globals =
1502-
(current_global_state != &global_domain_map);
1528+
(current_global_state != global_domain_map.get());
15031529
output->globals = *current_global_state;
15041530

15051531
return ref.release();
@@ -1523,7 +1549,7 @@ PyObject * set_state(PyObject * /* self */, PyObject * args) {
15231549
bool use_thread_local_globals =
15241550
(!reset_allowed) || state->use_thread_local_globals;
15251551
current_global_state =
1526-
use_thread_local_globals ? &thread_local_domain_map : &global_domain_map;
1552+
use_thread_local_globals ? &thread_local_domain_map : global_domain_map.get();
15271553

15281554
if (use_thread_local_globals)
15291555
thread_local_domain_map = state->globals;
@@ -1554,7 +1580,7 @@ PyObject * determine_backend(PyObject * /*self*/, PyObject * args) {
15541580
auto result = for_each_backend_in_domain(
15551581
domain, [&](PyObject * backend, bool coerce_backend) {
15561582
auto has_ua_convert =
1557-
PyObject_HasAttr(backend, identifiers.ua_convert.get());
1583+
PyObject_HasAttr(backend, identifiers.ua_convert->get());
15581584

15591585
if (!has_ua_convert) {
15601586
// If no __ua_convert__, assume it won't accept the type
@@ -1566,7 +1592,7 @@ PyObject * determine_backend(PyObject * /*self*/, PyObject * args) {
15661592
(coerce && coerce_backend) ? Py_True : Py_False};
15671593

15681594
auto res = py_ref::steal(Q_PyObject_VectorcallMethod(
1569-
identifiers.ua_convert.get(), convert_args,
1595+
identifiers.ua_convert->get(), convert_args,
15701596
array_size(convert_args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET,
15711597
nullptr));
15721598
if (!res) {

scipy/ndimage/tests/test_measurements.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def test_label_output_dtype(xp):
364364
assert output.dtype == t
365365

366366

367-
@skip_xp_backends('dask.array', reason='Dask does not raise')
367+
@xfail_xp_backends('dask.array', reason='Dask does not raise')
368368
@xfail_xp_backends('jax.numpy', reason='JAX does not raise')
369369
def test_label_output_wrong_size(xp):
370370
data = xp.ones([5])
@@ -1158,11 +1158,9 @@ def test_maximum_position06(xp):
11581158
assert output[1] == (1, 1)
11591159

11601160
@xfail_xp_backends("dask.array", reason="crash in dask.array searchsorted")
1161+
@xfail_xp_backends("torch", reason="output[1] is wrong on pytorch")
11611162
def test_maximum_position07(xp):
11621163
# Test float labels
1163-
if is_torch(xp):
1164-
pytest.xfail("output[1] is wrong on pytorch")
1165-
11661164
labels = xp.asarray([1.0, 2.5, 0.0, 4.5])
11671165
for type in types:
11681166
dtype = getattr(xp, type)

scipy/stats/tests/test_entropy.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77
from scipy import stats
88
from scipy.stats import norm, expon # type: ignore[attr-defined]
9-
from scipy._lib._array_api import array_namespace, is_array_api_strict, is_jax
9+
from scipy._lib._array_api import array_namespace
1010
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
1111
xp_assert_less)
1212

13+
skip_xp_backends = pytest.mark.skip_xp_backends
14+
1315
@pytest.mark.skip_xp_backends("dask.array", reason="boolean index assignment")
1416
class TestEntropy:
1517
def test_entropy_positive(self, xp):
@@ -226,13 +228,21 @@ def test_input_validation(self, xp):
226228
with pytest.raises(ValueError, match=message):
227229
stats.differential_entropy(x, method='ekki-ekki')
228230

229-
@pytest.mark.parametrize('method', ['vasicek', 'van es',
230-
'ebrahimi', 'correa'])
231+
@pytest.mark.parametrize('method', [
232+
'vasicek',
233+
'van es',
234+
pytest.param(
235+
'ebrahimi',
236+
marks=skip_xp_backends("jax.numpy",
237+
reason="JAX doesn't support item assignment")
238+
),
239+
pytest.param(
240+
'correa',
241+
marks=skip_xp_backends("array_api_strict",
242+
reason="Needs fancy indexing.")
243+
)
244+
])
231245
def test_consistency(self, method, xp):
232-
if is_jax(xp) and method == 'ebrahimi':
233-
pytest.xfail("Needs array assignment.")
234-
elif is_array_api_strict(xp) and method == 'correa':
235-
pytest.xfail("Needs fancy indexing.")
236246
# test that method is a consistent estimator
237247
n = 10000 if method == 'correa' else 1000000
238248
rvs = stats.norm.rvs(size=n, random_state=0)
@@ -260,17 +270,25 @@ def test_consistency(self, method, xp):
260270
rmse_std_cases = {norm: norm_rmse_std_cases,
261271
expon: expon_rmse_std_cases}
262272

263-
@pytest.mark.parametrize('method', ['vasicek', 'van es', 'ebrahimi', 'correa'])
273+
@pytest.mark.parametrize('method', [
274+
'vasicek',
275+
'van es',
276+
pytest.param(
277+
'ebrahimi',
278+
marks=skip_xp_backends("jax.numpy",
279+
reason="JAX doesn't support item assignment")
280+
),
281+
pytest.param(
282+
'correa',
283+
marks=skip_xp_backends("array_api_strict",
284+
reason="Needs fancy indexing.")
285+
)
286+
])
264287
@pytest.mark.parametrize('dist', [norm, expon])
265288
def test_rmse_std(self, method, dist, xp):
266289
# test that RMSE and standard deviation of estimators matches values
267290
# given in differential_entropy reference [6]. Incidentally, also
268291
# tests vectorization.
269-
if is_jax(xp) and method == 'ebrahimi':
270-
pytest.xfail("Needs array assignment.")
271-
elif is_array_api_strict(xp) and method == 'correa':
272-
pytest.xfail("Needs fancy indexing.")
273-
274292
reps, n, m = 10000, 50, 7
275293
expected = self.rmse_std_cases[dist][method]
276294
rmse_expected, std_expected = xp.asarray(expected[0]), xp.asarray(expected[1])
@@ -284,12 +302,15 @@ def test_rmse_std(self, method, dist, xp):
284302
xp_test = array_namespace(res)
285303
xp_assert_close(xp_test.std(res, correction=0), std_expected, atol=0.002)
286304

287-
@pytest.mark.parametrize('n, method', [(8, 'van es'),
288-
(12, 'ebrahimi'),
289-
(1001, 'vasicek')])
305+
@pytest.mark.parametrize('n, method', [
306+
(8, 'van es'),
307+
pytest.param(
308+
12, 'ebrahimi',
309+
marks=skip_xp_backends("jax.numpy", reason="Needs array assignment")
310+
),
311+
(1001, 'vasicek')
312+
])
290313
def test_method_auto(self, n, method, xp):
291-
if is_jax(xp) and method == 'ebrahimi':
292-
pytest.xfail("Needs array assignment.")
293314
rvs = stats.norm.rvs(size=(n,), random_state=0)
294315
rvs = xp.asarray(rvs.tolist())
295316
res1 = stats.differential_entropy(rvs)
@@ -298,14 +319,20 @@ def test_method_auto(self, n, method, xp):
298319

299320
@pytest.mark.skip_xp_backends('jax.numpy',
300321
reason="JAX doesn't support item assignment")
301-
@pytest.mark.parametrize('method', ["vasicek", "van es", "correa", "ebrahimi"])
322+
@pytest.mark.parametrize('method', [
323+
"vasicek",
324+
"van es",
325+
pytest.param(
326+
"correa",
327+
marks=skip_xp_backends("array_api_strict", reason="Needs fancy indexing.")
328+
),
329+
"ebrahimi"
330+
])
302331
@pytest.mark.parametrize('dtype', [None, 'float32', 'float64'])
303332
def test_dtypes_gh21192(self, xp, method, dtype):
304333
# gh-21192 noted a change in the output of method='ebrahimi'
305334
# with integer input. Check that the output is consistent regardless
306335
# of input dtype.
307-
if is_array_api_strict(xp) and method == 'correa':
308-
pytest.xfail("Needs fancy indexing.")
309336
x = [1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 8, 9, 10, 11]
310337
dtype_in = getattr(xp, str(dtype), None)
311338
dtype_out = getattr(xp, str(dtype), xp.asarray(1.).dtype)

0 commit comments

Comments
 (0)