Hi,
I am really happy to see a multiple-dispatch mechanism brought to JAX, thanks for that!

I ran a small toy benchmark comparing the timings of quaxified functions against plain class methods, and I was quite surprised by the results.

Is there something I have missed in the snippet below? Or is quax mainly aimed at larger, already-implemented models and functions?
```python
from __future__ import annotations

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike

import quax


# Using Equinox first
class ArrayIsh(eqx.Module):
    array: ArrayLike

    @staticmethod
    def __get_other_array(other: ArrayIsh | ArrayLike) -> ArrayLike:
        # Unwrap ArrayIsh objects; pass plain arrays through unchanged
        if isinstance(other, ArrayIsh):
            return other.array
        else:
            return other

    def __sub__(self, other: ArrayIsh | ArrayLike):
        return self.array - self.__get_other_array(other)

    @classmethod
    def from_array(cls, array: ArrayLike):
        return cls(array)


array_ish1 = ArrayIsh.from_array(jnp.full((10, 2), 5.))
array_ish2 = ArrayIsh.from_array(jnp.full((10, 2), 3.))
array_like = jnp.full((10, 2), 2.)
```
```python
# Subtract the array attributes directly
%timeit (array_ish1.array - array_ish2.array).block_until_ready()
# 3.6 μs ± 48.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# Subtract two ArrayIsh objects using the __sub__ method
%timeit (array_ish1 - array_ish2).block_until_ready()
# 5.1 μs ± 246 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# Subtract an ArrayLike from an ArrayIsh object using the __sub__ method
%timeit (array_ish1 - array_like).block_until_ready()
# 5.26 μs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
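`%timeit` is an IPython magic, so the timings above only run in a notebook or IPython shell. A rough stand-in for a plain Python script, assuming the standard-library `timeit` module is acceptable, is sketched below. `bench` is a hypothetical helper name, and unlike `%timeit` it uses a fixed loop count instead of choosing one automatically.

```python
import statistics
import timeit

import jax.numpy as jnp


def bench(fn, number=1_000, repeat=7):
    """Approximate %timeit: mean and std. dev. of the per-call time over `repeat` runs."""
    fn()  # warm-up call so one-time compilation is not included in the timing
    totals = timeit.repeat(fn, number=number, repeat=repeat)
    per_call = [t / number for t in totals]  # seconds per single call, one entry per run
    return statistics.mean(per_call), statistics.stdev(per_call)


x = jnp.full((10, 2), 5.0)
y = jnp.full((10, 2), 3.0)

mean, std = bench(lambda: (x - y).block_until_ready())
print(f"{mean * 1e6:.2f} µs ± {std * 1e9:.1f} ns per call")
```

Any of the lambdas timed above, for example `lambda: (array_ish1 - array_like).block_until_ready()`, can be passed to `bench` in the same way.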
```python
# Now using quax
class ArrayIsh(quax.ArrayValue):
    array: ArrayLike

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise ArrayIsh array.")

    @classmethod
    def from_array(cls, array: ArrayLike):
        return cls(array)


@quax.register(jax.lax.sub_p)
def _(x: ArrayIsh, y: ArrayIsh):
    return x.array - y.array


@quax.register(jax.lax.sub_p)
def _(x: ArrayIsh, y: ArrayLike):
    return x.array - y


@quax.quaxify
def sub(x, y):
    return x - y


array_ish1 = ArrayIsh.from_array(jnp.full((10, 2), 5.))
array_ish2 = ArrayIsh.from_array(jnp.full((10, 2), 3.))
```
```python
# Subtract the array attributes directly
%timeit (array_ish1.array - array_ish2.array).block_until_ready()
# 3.64 μs ± 44.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# Subtract two ArrayIsh objects using the quaxified function sub
%timeit sub(array_ish1, array_ish2).block_until_ready()
# 440 μs ± 4.89 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# Subtract an ArrayLike from an ArrayIsh object using the quaxified function sub
%timeit sub(array_ish1, array_like).block_until_ready()
# 425 μs ± 3.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
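One possible factor, offered as an assumption rather than a diagnosis: `sub` is called here without `jax.jit`, so any Python-level work `quax.quaxify` does may be repeated on every call. The plain-JAX sketch below (no quax involved; `trace_count`, `traced_sub` and `jitted_sub` are hypothetical names) only illustrates that `jax.jit` runs a function's Python body while tracing and then reuses the cached compilation for later calls with the same input shapes and dtypes.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when Python actually executes the function body


def traced_sub(x, y):
    global trace_count
    trace_count += 1  # side effect: happens during tracing, not on cached calls
    return x - y


jitted_sub = jax.jit(traced_sub)

x = jnp.full((10, 2), 5.0)
y = jnp.full((10, 2), 3.0)

for _ in range(100):
    out = jitted_sub(x, y).block_until_ready()

print(trace_count)  # 1: traced once, then the cached compilation is reused

# A new input shape triggers a fresh trace of the Python body
jitted_sub(jnp.ones((3,)), jnp.ones((3,))).block_until_ready()
print(trace_count)  # 2
```

Under that assumption, the comparable experiment for the snippet above would be timing `jax.jit(sub)` alongside `sub`; this sketch does not show whether that closes the gap.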
Again, thanks for this library, and thanks in advance for any feedback.

Vadim