Figure out SciPy use-case #12

@seberg

@lucascolley mentioned that SciPy currently has a light-weight dispatching mechanism, though it is limited to three specific types (numpy, cupy, torch). Implemented across:

I think that everything for this use-case is (practically) in place already, with maybe slight additions and some notes on how to achieve things.
(Plus, one thing I realized is that how we document backends is still very much an open question and likely needs per-library customization.)

The basic pattern here would probably be (untested):

```python
import abc


def is_jax(xp):
    # Stand-in for SciPy's internal `is_jax(xp)` helper.
    return getattr(xp, "__name__", "").startswith("jax")


class JaxLike(abc.ABC):
    # We could just match "~jax.numpy:ndarray" below as well.
    # This is (maybe ridiculously so) closer to the original
    # by going via `__array_namespace__`.  (N.B. we look it up
    # on the class.)
    @classmethod
    def __subclasshook__(cls, other):
        # ABCMeta calls this from `issubclass(other, JaxLike)`; a
        # `__subclasscheck__` classmethod on the ABC itself is never used.
        get_xp = getattr(other, "__array_namespace__", None)
        if get_xp is None:
            return NotImplemented

        # Ideally, we could just do this, but JAX doesn't define
        # this as a class method, and in spatch we prefer to match
        # types.  (Otherwise, we would just have to match anything
        # with `__array_function__` and use `should_run`...)
        try:
            xp = get_xp()
        except Exception:
            pass
        else:
            return is_jax(xp)

        # Just guess that if the method came from a jax module,
        # we can use the JAX version.
        module = getattr(get_xp, "__module__", None) or ""
        return module.startswith("jax.")


class JaxBackend:
    name = "jax_array_api"
    require_opt_in = False  # pure type dispatcher
    # Note that the "@" indicates that JaxLike is an abstract class
    primary_types = [f"@{JaxLike.__module__}:JaxLike"]
    secondary_types = []  # none, we don't allow NumPy coercion.
    functions = {
        # trivial to generate, you can do so even on import.
        "scipy.submodule:name": {
            "function": "jax.submodule:name",
            # more options, we don't need right now.
        }
    }


from spatch.backend_system import BackendSystem

_bs = BackendSystem(
    None,  # don't load entry-points for now (no 3rd party backends)
    "_SCIPY_INTERNAL_BACKENDS",  # spatch env-var prefix
    # could also allow subclasses for ndarray with "~numpy:ndarray".
    # Doesn't really matter, since we'll fall back to it anyway:
    default_primary_types=["numpy:ndarray"],
    # Register SciPy internal backends explicitly:
    backends=[JaxBackend],
)


@_bs.dispatchable(["argname"])
def scipy_function(argname, other_arg):
    # SciPy (numpy/fallback) version.
    ...
```
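The `functions` table in the `JaxBackend` sketch is described as trivial to generate. A minimal sketch of that, assuming the JAX counterpart has the same name under `jax.scipy.<submodule>` (true for e.g. `erf` and `gammaln` in `jax.scipy.special`, but not in general, so a real list would contain only functions known to match). `make_functions` is a hypothetical helper, and only the `"function"` key from the example above is filled in.

```python
def make_functions(submodule, names):
    """Hypothetical helper: build a `functions` table that maps each
    SciPy function to the same-named function under `jax.scipy`."""
    return {
        f"scipy.{submodule}:{name}": {"function": f"jax.scipy.{submodule}:{name}"}
        for name in names
    }


functions = make_functions("special", ["erf", "gammaln"])
# {'scipy.special:erf': {'function': 'jax.scipy.special:erf'},
#  'scipy.special:gammaln': {'function': 'jax.scipy.special:gammaln'}}
```

Since this only builds strings, it is cheap enough to run at import time.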

The `JaxLike` class is a bit much: spatch insists on (or strongly prefers) type-based semantics, but the Array API/JAX don't allow us to match "JAX array-API-like" from the type alone (which could certainly be changed there). This will work in practice, though (or we can just list the JAX types explicitly...).
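The trick that lets an abstract class take part in type-based matching is the `abc` module's `__subclasshook__`. Here is a self-contained sketch of the same mechanism with no JAX dependency. The class names and the `fakejax` module prefix are made up for illustration, and how spatch itself evaluates `@`-prefixed types is not shown here.

```python
import abc


class HasNamespaceFromModule(abc.ABC):
    """Toy ABC: a type "matches" if its `__array_namespace__` method was
    defined in a module whose name starts with PREFIX."""

    PREFIX = "fakejax."

    @classmethod
    def __subclasshook__(cls, other):
        # Reached via ABCMeta.__subclasscheck__, i.e. from issubclass().
        get_xp = getattr(other, "__array_namespace__", None)
        if get_xp is None:
            return NotImplemented  # let ABCMeta fall back to its defaults
        module = getattr(get_xp, "__module__", None) or ""
        return module.startswith(cls.PREFIX)


def _fake_namespace(self, api_version=None):
    return None


# Pretend this method was defined inside a (made-up) "fakejax" package.
_fake_namespace.__module__ = "fakejax.numpy"


class FakeJaxArray:
    __array_namespace__ = _fake_namespace


class OtherArray:
    def __array_namespace__(self, api_version=None):
        return None


print(issubclass(FakeJaxArray, HasNamespaceFromModule))  # True
print(issubclass(OtherArray, HasNamespaceFromModule))    # False
print(issubclass(int, HasNamespaceFromModule))           # False
```

The check only looks at the class, never at an instance, which is what makes it usable by a dispatcher that matches on types. Note that ABCs cache these results per class.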

Now, you are also using NumPy's `__array_function__` pattern of an `..._signature` function that extracts the "relevant arguments", while the above passes a list of argument names instead.
The dedicated-function version will be a lot faster for now (unless we move things to C, the difference may be 1/3 or more!).
I think we need to allow a callable there anyway, if only for complex cases (e.g. `numpy.concatenate` considers all entries of its first argument as "relevant args", so it needs to unpack that sequence).
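The difference between the two ways of naming relevant arguments can be sketched in plain Python. `relevant_by_name` and `concatenate_signature` are hypothetical helpers written for this example (not spatch or NumPy API), and the toy `concatenate` only provides a signature.

```python
import inspect


def concatenate(arrays, axis=0, out=None):
    """Stand-in for a dispatchable function (body irrelevant here)."""
    raise NotImplementedError


def relevant_by_name(func, argnames, args, kwargs):
    # Name-list style: bind the call to the signature, then pick out the
    # named arguments.  The generic `bind` step runs on every call.
    bound = inspect.signature(func).bind(*args, **kwargs)
    return tuple(bound.arguments[name] for name in argnames
                 if name in bound.arguments)


def concatenate_signature(arrays, axis=None, out=None):
    # Dispatcher style (as with NumPy's `__array_function__`): same
    # parameters as `concatenate`, returns the relevant arguments
    # directly and can unpack the `arrays` sequence.
    relevant = tuple(arrays)
    if out is not None:
        relevant += (out,)
    return relevant


a, b = object(), object()
by_name = relevant_by_name(concatenate, ["arrays"], ([a, b],), {})
by_signature = concatenate_signature([a, b])
print(by_name == ([a, b],))    # True: the list itself, not its entries
print(by_signature == (a, b))  # True: the entries, unpacked
```

With a plain list of names, the whole `arrays` list is the relevant argument, so something would still have to look inside it. The signature function unpacks it directly and, in this sketch, also skips the generic `bind` step.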

So @lucascolley, if you are curious to test things out and need, e.g., the version where you can write `@_bs.dispatchable(..._signature)`, let's just add it here quickly (it is very easy, of course).
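To show why accepting either form is easy, here is a toy decorator that takes either a list of argument names or a signature callable. `toy_dispatchable` is hypothetical and is not how spatch's `dispatchable` is implemented; it also simplifies backend selection to an exact `type(arg)` lookup in a dict.

```python
import functools
import inspect


def toy_dispatchable(relevant, backends):
    """Hypothetical decorator (not spatch's API).

    `relevant` is either a list of argument names or a callable with the
    decorated function's signature that returns the relevant arguments.
    `backends` maps an exact type to an alternative implementation.
    """
    def decorator(func):
        if callable(relevant):
            extract = relevant  # signature-function form
        else:
            sig = inspect.signature(func)
            names = list(relevant)

            def extract(*args, **kwargs):
                # Name-list form: bind, then pick the named arguments.
                bound = sig.bind(*args, **kwargs)
                return tuple(bound.arguments[n] for n in names
                             if n in bound.arguments)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for arg in extract(*args, **kwargs):
                impl = backends.get(type(arg))
                if impl is not None:
                    return impl(*args, **kwargs)
            return func(*args, **kwargs)  # default (NumPy) version
        return wrapper
    return decorator


class FakeArray:
    pass


def fake_total(x, y):
    return "fake backend"


@toy_dispatchable(["x"], backends={FakeArray: fake_total})
def total(x, y):
    return "default"


def _total_signature(x, y):
    return (x,)


@toy_dispatchable(_total_signature, backends={FakeArray: fake_total})
def total2(x, y):
    return "default"


print(total(1, 2))             # default
print(total(FakeArray(), 2))   # fake backend
print(total2(FakeArray(), 2))  # fake backend
```

Both forms end up as a single `extract` callable, so the dispatch loop does not care which one the user passed.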
