Skip to content

Commit 1dc1c37

Browse files
Made set_workers a context manager, like it should be. os.get_cpu_count->os.cpu_count
1 parent c152832 commit 1dc1c37

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from numpy.core import (take, sqrt, prod)
3232
import contextvars
33+
import contextlib
3334
import operator
3435
import os
3536

@@ -90,14 +91,20 @@ def get_workers():
9091
return _workers_global_settings.get().workers
9192

9293

94+
@contextlib.contextmanager
9395
def set_workers(n_workers):
9496
"Set the value of workers used by default, returns the previous value"
9597
nw = operator.index(n_workers)
96-
wd = _workers_global_settings.get()
97-
saved_nw = wd.workers
98-
wd.workers = nw
99-
_workers_global_settings.set(wd)
100-
return saved_nw
98+
token = None
99+
try:
100+
new_wd = _workers_data(nw)
101+
token = _workers_global_settings.set(new_wd)
102+
yield
103+
finally:
104+
if token:
105+
_workers_global_settings.reset(token)
106+
else:
107+
raise ValueError
101108

102109

103110
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
@@ -154,7 +161,7 @@ def _workers_to_num_threads(w):
154161
if (_w == 0):
155162
raise ValueError("Number of workers must not be zero")
156163
if (_w < 0):
157-
ub = os.get_cpu_count()
164+
ub = os.cpu_count()
158165
_w += ub + 1
159166
if _w <= 0:
160167
raise ValueError("workers value out of range; got {}, must not be"

0 commit comments

Comments
 (0)