@lucascolley mentioned that SciPy currently has a lightweight dispatching mechanism, though it is limited to three specific types (numpy, cupy, torch). It is implemented across:
- https://github.com/scipy/scipy/blob/main/scipy/signal/_signal_api.py
- https://github.com/scipy/scipy/blob/main/scipy/signal/_support_alternative_backends.py
- https://github.com/scipy/scipy/blob/main/scipy/signal/_delegators.py
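For contrast with a general backend system, a fixed-set dispatcher can be sketched in a few lines. This is a hypothetical illustration, not SciPy's actual code. `delegate`, the `implementations` mapping and the stand-in implementations are made-up names. The backend is chosen from the top-level module of the first argument's type, and anything outside the fixed mapping goes to the fallback.

```python
from typing import Any, Callable, Dict


def delegate(implementations: Dict[str, Callable[..., Any]],
             fallback: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical sketch: pick an implementation by array type.

    The lookup key is the top-level module of the first argument's type
    (e.g. "cupy" or "torch"); unknown types use the fallback.
    """
    def dispatched(x: Any, *args: Any, **kwargs: Any) -> Any:
        top_module = type(x).__module__.split(".")[0]
        impl = implementations.get(top_module, fallback)
        return impl(x, *args, **kwargs)
    return dispatched


# Stand-in implementations, so neither CuPy nor PyTorch is needed here.
def numpy_impl(x):
    return "numpy path"


def cupy_impl(x):
    return "cupy path"


filt = delegate({"cupy": cupy_impl}, fallback=numpy_impl)


class FakeCupyArray:
    # Pretend this class was defined inside the cupy package.
    __module__ = "cupy._core.core"


print(filt(FakeCupyArray()))  # cupy path
print(filt([1, 2, 3]))        # numpy path
```

Adding a new library to such a scheme means editing the mapping inside the library itself, which is the limitation a pluggable backend system avoids.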
I think that everything for this use case is (practically) in place, with maybe slight additions and some pointers on how to achieve things.
(Plus, one thing I realized is that the way we document backends may still be very much open and likely needs per-library customization.)
The basic pattern here would probably be (untested):
```python
import abc

# Assumption: `is_jax` is meant as a namespace check like this one.
from array_api_compat import is_jax_namespace as is_jax
from spatch.backend_system import BackendSystem


class JaxLike(abc.ABC):
    # We could just match "~jax.numpy:ndarray" below as well.
    # This is (maybe ridiculously so) closer to the original
    # by going via `__array_namespace__`. (N.B. we look it up
    # on the class.)
    # `issubclass()` on an ABC consults `__subclasshook__`; a
    # `__subclasscheck__` defined on the class itself is ignored
    # (it would have to live on the metaclass).
    @classmethod
    def __subclasshook__(cls, other):
        get_xp = getattr(other, "__array_namespace__", None)
        if get_xp is None:
            return NotImplemented
        # Ideally, we could just do this, but JAX doesn't define
        # this as a class method, while spatch prefers to match
        # types. (Otherwise, we would just have to match anything
        # with `__array_function__` and use `should_run`...)
        try:
            xp = get_xp()
        except Exception:
            pass
        else:
            return is_jax(xp)
        # Just guess that if the method came from a jax module,
        # we can use the JAX version.
        module = getattr(get_xp, "__module__", None) or ""
        return module == "jax" or module.startswith("jax.")


class JaxBackend:
    name = "jax_array_api"
    require_opt_in = False  # pure type dispatcher
    # Note that the "@" indicates that JaxLike is an abstract class.
    primary_types = [f"@{JaxLike.__module__}:JaxLike"]
    secondary_types = []  # none, we don't allow NumPy coercion.
    functions = {
        # Trivial to generate; you can do so even on import.
        "scipy.submodule:name": {
            "function": "jax.submodule:name",
            # More options we don't need right now.
        },
    }


_bs = BackendSystem(
    None,  # don't load entry-points for now (no 3rd party backends)
    "_SCIPY_INTERNAL_BACKENDS",  # spatch env-var prefix
    # Could also allow subclasses of ndarray with "~numpy:ndarray".
    # Doesn't really matter, since we'll fall back to it anyway:
    default_primary_types=["numpy:ndarray"],
    # Register SciPy internal backends explicitly:
    backends=[JaxBackend],
)


@_bs.dispatchable(["argname"])
def scipy_function(argname, other_arg):
    # SciPy (numpy/fallback) version.
    ...
```
The `JaxLike` class is a bit much, because spatch insists on (or strongly prefers) type-based semantics, while the Array API/JAX don't let us match "JAX array API-like" based on the type alone (which could certainly be changed there). Still, this will work in practice (or we could just list the JAX types explicitly...).
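The difficulty with type-based matching can be reproduced without JAX. In this stdlib-only sketch, `ArrayLikeByModule`, `FakeJaxArray` and the module name `"jax._src.fake"` are all made up. The stand-in's `__array_namespace__` is an instance method, so calling it on the class fails, and the check has to fall back to the module the method was defined in.

```python
import abc


class ArrayLikeByModule(abc.ABC):
    """Hypothetical ABC: matches classes whose `__array_namespace__`
    was defined in a module under the `package` prefix."""

    package = "jax"

    @classmethod
    def __subclasshook__(cls, other):
        get_xp = getattr(other, "__array_namespace__", None)
        if get_xp is None:
            return NotImplemented  # defer to normal subclass rules
        module = getattr(get_xp, "__module__", None) or ""
        return module == cls.package or module.startswith(cls.package + ".")


def _array_namespace(self, api_version=None):
    # An instance method: it needs an array instance to be called.
    return None


_array_namespace.__module__ = "jax._src.fake"  # pretend it came from JAX


class FakeJaxArray:
    __array_namespace__ = _array_namespace


class OtherArray:
    def __array_namespace__(self, api_version=None):
        return None


try:
    FakeJaxArray.__array_namespace__()  # no instance, so no namespace
except TypeError:
    print("cannot call __array_namespace__ on the class")

print(issubclass(FakeJaxArray, ArrayLikeByModule))  # True
print(issubclass(OtherArray, ArrayLikeByModule))    # False
```

Only the class is ever inspected, which is why a heuristic like the module name is needed unless the types are listed explicitly.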
Now, you are also using NumPy's `__array_function__` pattern of having an `..._signature` function that extracts the "relevant arguments", while the above passes a list of argument names instead.
The dedicated-function version will be a lot faster for now (unless we move things to C, the difference may be a third or more!).
I think we need to allow a callable there anyway, if only for complex cases (e.g. `numpy.concatenate` considers all entries of its first argument to be "relevant args", so it needs to unpack that sequence).
So @lucascolley, if you are curious to test things out and need e.g. the version where you can do `@_bs.dispatchable(..._signature)`, let's just add it here quickly (it is very easy, of course).
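A toy decorator that accepts either form (a list of names or a signature callable) might look like the following. This is not spatch's API. `dispatchable`, the `backends` mapping keyed by exact type, `MyArray` and `_my_impl` are all assumptions made for illustration.

```python
import functools
import inspect
from typing import Any, Callable, Dict, Sequence, Union


def dispatchable(relevant: Union[Sequence[str], Callable[..., Any]],
                 backends: Dict[type, Callable[..., Any]]):
    """Hypothetical decorator: dispatch on the types of relevant args.

    `relevant` is either a list of argument names or a callable with the
    decorated function's signature that returns the relevant arguments.
    """
    def decorator(func):
        if callable(relevant):
            extract = relevant  # `__array_function__`-style extractor
        else:
            sig = inspect.signature(func)
            names = tuple(relevant)

            def extract(*args, **kwargs):
                bound = sig.bind(*args, **kwargs)
                return [bound.arguments[n] for n in names
                        if n in bound.arguments]

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for arg in extract(*args, **kwargs):
                impl = backends.get(type(arg))
                if impl is not None:
                    return impl(*args, **kwargs)
            return func(*args, **kwargs)  # default (fallback) version
        return wrapper
    return decorator


class MyArray(list):
    pass


def _my_impl(argname, other_arg=None):
    return "MyArray backend"


def _scipy_function_signature(argname, other_arg=None):
    return (argname,)


@dispatchable(["argname"], backends={MyArray: _my_impl})
def scipy_function(argname, other_arg=None):
    return "default"


@dispatchable(_scipy_function_signature, backends={MyArray: _my_impl})
def scipy_function2(argname, other_arg=None):
    return "default"


print(scipy_function(MyArray([1])))   # MyArray backend
print(scipy_function([1]))            # default
print(scipy_function2(MyArray([1])))  # MyArray backend
```

Normalizing both forms into one `extract` callable keeps the dispatch loop identical; only the extraction cost differs.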