Skip to content

Fixed #1069 Added numba cache dir for pytest #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 70 additions & 14 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
CACHE_WARNING += "without prior notice. Please proceed with caution!"
CACHE_CLEARED = True


def get_njit_funcs():
Expand Down Expand Up @@ -102,58 +103,78 @@ def _enable():
raise


def _clear():
def _clear(cache_dir=None):
"""
Clear numba cache

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory

Returns
-------
None
"""
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
global CACHE_CLEARED

if cache_dir is not None:
numba_cache_dir = str(cache_dir)
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]

CACHE_CLEARED = True


def clear():
def clear(cache_dir=None):
"""
Clear numba cache directory

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory. When `cache_dir` is `None`, then this
defaults to `site-packages/stumpy/__pycache__`.

Returns
-------
None
"""
warnings.warn(CACHE_WARNING)
_clear()
_clear(cache_dir)

return


def _get_cache():
def _get_cache(cache_dir=None):
"""
Retrieve a list of cached numba functions

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
out : list
A list of cached numba functions
"""
warnings.warn(CACHE_WARNING)
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
if cache_dir is not None:
numba_cache_dir = str(cache_dir)
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

return [
f"{numba_cache_dir}/{f.name}"
for f in pathlib.Path(numba_cache_dir).glob("*nb*")
if f.is_file()
]


def _recompile():
Expand Down Expand Up @@ -202,16 +223,24 @@ def _save():
-------
None
"""
global CACHE_CLEARED

if not CACHE_CLEARED: # pragma: no cover
msg = "Numba njit cached files are not cleared before saving/overwriting. "
msg = "You may need to call `cache.clear()` before calling `cache.save()`."
warnings.warn(msg)

_enable()
_recompile()

CACHE_CLEARED = False

return


def save():
"""
Save/overwrite all the cache data files of
all-so-far compiled njit functions.
Save/overwrite all of the cached njit functions.

Parameters
----------
Expand All @@ -220,13 +249,40 @@ def save():
Returns
-------
None

Notes
-----
The cache is never cleared before saving/overwriting and may be explicitly cleared
by calling `cache.clear()` before saving. It is best practice to call `cache.save()`
only after calling all of your `njit` functions. If `cache.save()` is called for the
first time (before any `njit` function is called) then only the `.nbi` files (i.e.,
the "cache index") for all `njit` functions are saved. As each `njit` function (and
sub-functions) is called then their corresponding `.nbc` file (i.e., "object code")
is saved. Each `.nbc` file will only be saved after its `njit` function is called
at least once. However, subsequent calls to `cache.save()` (after clearing the cache
via `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the
`.nbc` files as long as their `njit` function has been called at least once.

Examples
--------
>>> import stumpy
>>> from stumpy import cache
>>> import numpy as np
>>> cache.clear()
>>> mp = stumpy.stump(np.array([584., -11., 23., 79., 1001., 0., -19.]), m=3)
>>> cache.save()
"""
if numba.config.DISABLE_JIT:
msg = "Could not save/cache function because NUMBA JIT is disabled"
warnings.warn(msg)
else: # pragma: no cover
warnings.warn(CACHE_WARNING)

if numba.config.CACHE_DIR != "": # pragma: no cover
msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. "
msg += "The `stumpy` cache files may not be saved/cleared correctly!"
warnings.warn(msg)

_save()

return
8 changes: 6 additions & 2 deletions stumpy/fastmath.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import warnings

import numba
from numba import njit
Expand Down Expand Up @@ -55,12 +56,15 @@ def _set(module_name, func_name, flag):
func = getattr(module, func_name)
try:
func.targetoptions["fastmath"] = flag
func.recompile()
msg = "One or more fastmath flags have been set/reset. "
msg += "Please call `cache._recompile()` to ensure that all njit functions "
msg += "are properly recompiled."
warnings.warn(msg)
except AttributeError as e:
if numba.config.DISABLE_JIT and (
str(e) == "'function' object has no attribute 'targetoptions'"
or str(e) == "'function' object has no attribute 'recompile'"
):
warnings.warn("Fastmath flags could not be set as Numba JIT is disabled")
pass
else: # pragma: no cover
raise
Expand Down
28 changes: 20 additions & 8 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numba
import numpy as np

from stumpy import cache, stump
Expand All @@ -11,17 +12,28 @@ def test_cache_get_njit_funcs():
def test_cache_save_after_clear():
T = np.random.rand(10)
m = 3

cache_dir = "stumpy/__pycache__"

cache.clear(cache_dir)
stump(T, m)
cache.save() # Enable and save both `.nbi` and `.nbc` cache files

cache.save()
ref_cache = cache._get_cache()
ref_cache = cache._get_cache(cache_dir)

cache.clear()
# testing cache._clear()
assert len(cache._get_cache()) == 0
if numba.config.DISABLE_JIT:
assert len(ref_cache) == 0
else: # pragma: no cover
assert len(ref_cache) > 0

cache.save()
comp_cache = cache._get_cache()
cache.clear(cache_dir)
assert len(cache._get_cache(cache_dir)) == 0
# Note that `stump(T, m)` has already been called once above and any subsequent
# calls to `cache.save()` will automatically save both `.nbi` and `.nbc` cache files
cache.save() # Save both `.nbi` and `.nbc` cache files

comp_cache = cache._get_cache(cache_dir)

# testing cache._save() after cache._clear()
assert sorted(ref_cache) == sorted(comp_cache)

cache.clear(cache_dir)
8 changes: 7 additions & 1 deletion tests/test_fastmath.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import numpy as np

from stumpy import fastmath
from stumpy import cache, fastmath


def test_set():
Expand All @@ -11,11 +11,13 @@ def test_set():

# case1: flag=False
fastmath._set("fastmath", "_add_assoc", flag=False)
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case2: flag={'reassoc', 'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc", "nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
if numba.config.DISABLE_JIT:
assert np.isnan(out)
Expand All @@ -24,11 +26,13 @@ def test_set():

# case3: flag={'reassoc'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case4: flag={'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

Expand All @@ -39,7 +43,9 @@ def test_reset():
# https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#fastmath
# and then reset it to the default value, i.e. `True`
fastmath._set("fastmath", "_add_assoc", False)
cache._recompile()
fastmath._reset("fastmath", "_add_assoc")
cache._recompile()
if numba.config.DISABLE_JIT:
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
else: # pragma: no cover
Expand Down