Skip to content

Commit d6bfb4a

Browse files
authored
Fixed #1069 Added numba cache dir for pytest (#1070)
* Added numba cache dir for pytest
* Added cache._clear() to cache._save()
* Removed recompile from fastmath
* Added ref cache length check
* Improved coverage
* Fixed black formatting
* Fixed if to elif
* Made get_cache more verbose
* Added warning
* Fixed typo
* Refactored code
* Cleaned up from comments
* Added warning to clear before save
* Reset CACHE_CLEARED after cache._save() is called
* Cleaned up code
* Added detailed cache note
* Fixed black formatting
* Added example
* Updated test and added more comments
1 parent 82ecd51 commit d6bfb4a

File tree

4 files changed

+103
-25
lines changed

4 files changed

+103
-25
lines changed

stumpy/cache.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
1616
CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
1717
CACHE_WARNING += "without prior notice. Please proceed with caution!"
18+
CACHE_CLEARED = True
1819

1920

2021
def get_njit_funcs():
@@ -102,58 +103,78 @@ def _enable():
102103
raise
103104

104105

105-
def _clear():
106+
def _clear(cache_dir=None):
106107
"""
107108
Clear numba cache
108109
109110
Parameters
110111
----------
111-
None
112+
cache_dir : str, default None
113+
The path to the numba cache directory
112114
113115
Returns
114116
-------
115117
None
116118
"""
117-
site_pkg_dir = site.getsitepackages()[0]
118-
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
119+
global CACHE_CLEARED
120+
121+
if cache_dir is not None:
122+
numba_cache_dir = str(cache_dir)
123+
else: # pragma: no cover
124+
site_pkg_dir = site.getsitepackages()[0]
125+
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
126+
119127
[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
120128

129+
CACHE_CLEARED = True
130+
121131

122-
def clear():
132+
def clear(cache_dir=None):
123133
"""
124134
Clear numba cache directory
125135
126136
Parameters
127137
----------
128-
None
138+
cache_dir : str, default None
139+
The path to the numba cache directory. When `cache_dir` is `None`, then this
140+
defaults to `site-packages/stumpy/__pycache__`.
129141
130142
Returns
131143
-------
132144
None
133145
"""
134146
warnings.warn(CACHE_WARNING)
135-
_clear()
147+
_clear(cache_dir)
136148

137149
return
138150

139151

140-
def _get_cache():
152+
def _get_cache(cache_dir=None):
141153
"""
142154
Retrieve a list of cached numba functions
143155
144156
Parameters
145157
----------
146-
None
158+
cache_dir : str, default None
159+
The path to the numba cache directory
147160
148161
Returns
149162
-------
150163
out : list
151164
A list of cached numba functions
152165
"""
153166
warnings.warn(CACHE_WARNING)
154-
site_pkg_dir = site.getsitepackages()[0]
155-
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
156-
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
167+
if cache_dir is not None:
168+
numba_cache_dir = str(cache_dir)
169+
else: # pragma: no cover
170+
site_pkg_dir = site.getsitepackages()[0]
171+
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
172+
173+
return [
174+
f"{numba_cache_dir}/{f.name}"
175+
for f in pathlib.Path(numba_cache_dir).glob("*nb*")
176+
if f.is_file()
177+
]
157178

158179

159180
def _recompile():
@@ -202,16 +223,24 @@ def _save():
202223
-------
203224
None
204225
"""
226+
global CACHE_CLEARED
227+
228+
if not CACHE_CLEARED: # pragma: no cover
229+
msg = "Numba njit cached files are not cleared before saving/overwriting. "
230+
msg += "You may need to call `cache.clear()` before calling `cache.save()`."
231+
warnings.warn(msg)
232+
205233
_enable()
206234
_recompile()
207235

236+
CACHE_CLEARED = False
237+
208238
return
209239

210240

211241
def save():
212242
"""
213-
Save/overwrite all the cache data files of
214-
all-so-far compiled njit functions.
243+
Save/overwrite all of the cached njit functions.
215244
216245
Parameters
217246
----------
@@ -220,13 +249,40 @@ def save():
220249
Returns
221250
-------
222251
None
252+
253+
Notes
254+
-----
255+
The cache is never cleared before saving/overwriting and may be explicitly cleared
256+
by calling `cache.clear()` before saving. It is best practice to call `cache.save()`
257+
only after calling all of your `njit` functions. If `cache.save()` is called for the
258+
first time (before any `njit` function is called) then only the `.nbi` files (i.e.,
259+
the "cache index") for all `njit` functions are saved. As each `njit` function (and
260+
sub-functions) is called then their corresponding `.nbc` file (i.e., "object code")
261+
is saved. Each `.nbc` file will only be saved after its `njit` function is called
262+
at least once. However, subsequent calls to `cache.save()` (after clearing the cache
263+
via `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the
264+
`.nbc` files as long as their `njit` function has been called at least once.
265+
266+
Examples
267+
--------
268+
>>> import stumpy
269+
>>> from stumpy import cache
270+
>>> import numpy as np
271+
>>> cache.clear()
272+
>>> mp = stumpy.stump(np.array([584., -11., 23., 79., 1001., 0., -19.]), m=3)
273+
>>> cache.save()
223274
"""
224275
if numba.config.DISABLE_JIT:
225276
msg = "Could not save/cache function because NUMBA JIT is disabled"
226277
warnings.warn(msg)
227278
else: # pragma: no cover
228279
warnings.warn(CACHE_WARNING)
229280

281+
if numba.config.CACHE_DIR != "": # pragma: no cover
282+
msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. "
283+
msg += "The `stumpy` cache files may not be saved/cleared correctly!"
284+
warnings.warn(msg)
285+
230286
_save()
231287

232288
return

stumpy/fastmath.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
import warnings
23

34
import numba
45
from numba import njit
@@ -55,12 +56,15 @@ def _set(module_name, func_name, flag):
5556
func = getattr(module, func_name)
5657
try:
5758
func.targetoptions["fastmath"] = flag
58-
func.recompile()
59+
msg = "One or more fastmath flags have been set/reset. "
60+
msg += "Please call `cache._recompile()` to ensure that all njit functions "
61+
msg += "are properly recompiled."
62+
warnings.warn(msg)
5963
except AttributeError as e:
6064
if numba.config.DISABLE_JIT and (
6165
str(e) == "'function' object has no attribute 'targetoptions'"
62-
or str(e) == "'function' object has no attribute 'recompile'"
6366
):
67+
warnings.warn("Fastmath flags could not be set as Numba JIT is disabled")
6468
pass
6569
else: # pragma: no cover
6670
raise

tests/test_cache.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numba
12
import numpy as np
23

34
from stumpy import cache, stump
@@ -11,17 +12,28 @@ def test_cache_get_njit_funcs():
1112
def test_cache_save_after_clear():
1213
T = np.random.rand(10)
1314
m = 3
15+
16+
cache_dir = "stumpy/__pycache__"
17+
18+
cache.clear(cache_dir)
1419
stump(T, m)
20+
cache.save() # Enable and save both `.nbi` and `.nbc` cache files
1521

16-
cache.save()
17-
ref_cache = cache._get_cache()
22+
ref_cache = cache._get_cache(cache_dir)
1823

19-
cache.clear()
20-
# testing cache._clear()
21-
assert len(cache._get_cache()) == 0
24+
if numba.config.DISABLE_JIT:
25+
assert len(ref_cache) == 0
26+
else: # pragma: no cover
27+
assert len(ref_cache) > 0
2228

23-
cache.save()
24-
comp_cache = cache._get_cache()
29+
cache.clear(cache_dir)
30+
assert len(cache._get_cache(cache_dir)) == 0
31+
# Note that `stump(T, m)` has already been called once above and any subsequent
32+
# calls to `cache.save()` will automatically save both `.nbi` and `.nbc` cache files
33+
cache.save() # Save both `.nbi` and `.nbc` cache files
34+
35+
comp_cache = cache._get_cache(cache_dir)
2536

26-
# testing cache._save() after cache._clear()
2737
assert sorted(ref_cache) == sorted(comp_cache)
38+
39+
cache.clear(cache_dir)

tests/test_fastmath.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numba
22
import numpy as np
33

4-
from stumpy import fastmath
4+
from stumpy import cache, fastmath
55

66

77
def test_set():
@@ -11,11 +11,13 @@ def test_set():
1111

1212
# case1: flag=False
1313
fastmath._set("fastmath", "_add_assoc", flag=False)
14+
cache._recompile()
1415
out = fastmath._add_assoc(0, np.inf)
1516
assert np.isnan(out)
1617

1718
# case2: flag={'reassoc', 'nsz'}
1819
fastmath._set("fastmath", "_add_assoc", flag={"reassoc", "nsz"})
20+
cache._recompile()
1921
out = fastmath._add_assoc(0, np.inf)
2022
if numba.config.DISABLE_JIT:
2123
assert np.isnan(out)
@@ -24,11 +26,13 @@ def test_set():
2426

2527
# case3: flag={'reassoc'}
2628
fastmath._set("fastmath", "_add_assoc", flag={"reassoc"})
29+
cache._recompile()
2730
out = fastmath._add_assoc(0, np.inf)
2831
assert np.isnan(out)
2932

3033
# case4: flag={'nsz'}
3134
fastmath._set("fastmath", "_add_assoc", flag={"nsz"})
35+
cache._recompile()
3236
out = fastmath._add_assoc(0, np.inf)
3337
assert np.isnan(out)
3438

@@ -39,7 +43,9 @@ def test_reset():
3943
# https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#fastmath
4044
# and then reset it to the default value, i.e. `True`
4145
fastmath._set("fastmath", "_add_assoc", False)
46+
cache._recompile()
4247
fastmath._reset("fastmath", "_add_assoc")
48+
cache._recompile()
4349
if numba.config.DISABLE_JIT:
4450
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
4551
else: # pragma: no cover

0 commit comments

Comments
 (0)