Class methods vs. quaxified functions #33

@vadmbertr

Hi,

I am really happy to see a multiple-dispatch mechanism brought to JAX, thanks for that!

I ran a toy example to compare the timings of quaxified functions against plain class methods, and I was quite surprised by the results.
Is there something I have missed in the snippet below? Or is quax aimed more at large, already-implemented models/functions?

from __future__ import annotations

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike
import quax


# using equinox first
class ArrayIsh(eqx.Module):
    array: ArrayLike

    @staticmethod
    def __get_other_array(other: ArrayIsh | ArrayLike) -> ArrayLike:
        if isinstance(other, ArrayIsh):
            return other.array
        else:
            return other

    def __sub__(self, other: ArrayIsh | ArrayLike):
        return self.array - self.__get_other_array(other)

    @classmethod
    def from_array(cls, array: ArrayLike):
        return cls(array)


array_ish1 = ArrayIsh.from_array(jnp.full((10, 2), 5.))
array_ish2 = ArrayIsh.from_array(jnp.full((10, 2), 3.))
array_like = jnp.full((10, 2), 2.)

# subtract array attributes directly
%timeit (array_ish1.array - array_ish2.array).block_until_ready()
3.6 μs ± 48.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# subtract ArrayIsh objects using __sub__ method
%timeit (array_ish1 - array_ish2).block_until_ready()
5.1 μs ± 246 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# subtract ArrayIsh object and ArrayLike using __sub__ method
%timeit (array_ish1 - array_like).block_until_ready()
5.26 μs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# now using quax
class ArrayIsh(quax.ArrayValue):
    array: ArrayLike

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise ArrayIsh array.")

    @classmethod
    def from_array(cls, array: ArrayLike):
        return cls(array)


@quax.register(jax.lax.sub_p)
def _(x: ArrayIsh, y: ArrayIsh):
    return x.array - y.array

@quax.register(jax.lax.sub_p)
def _(x: ArrayIsh, y: ArrayLike):
    return x.array - y

@quax.quaxify
def sub(x, y):
    return x - y


array_ish1 = ArrayIsh.from_array(jnp.full((10, 2), 5.))
array_ish2 = ArrayIsh.from_array(jnp.full((10, 2), 3.))

# subtract array attributes directly
%timeit (array_ish1.array - array_ish2.array).block_until_ready()
3.64 μs ± 44.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# subtract ArrayIsh objects using the quaxified function sub
%timeit sub(array_ish1, array_ish2).block_until_ready()
440 μs ± 4.89 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# subtract ArrayIsh object and ArrayLike using the quaxified function sub
%timeit sub(array_ish1, array_like).block_until_ready()
425 μs ± 3.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
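
One way to check whether the gap comes from a cost paid on every call, rather than a one-off cost on the first call, is to time the first call separately from the later ones. Below is a minimal sketch using only the standard library; the helper name `time_call` and its `repeat`/`number` parameters are hypothetical and not part of quax, equinox, or JAX.

```python
import time
import timeit


def time_call(fn, repeat=5, number=1000):
    """Return (first_call_seconds, best_per_call_seconds) for a zero-argument callable.

    `time_call` is a hypothetical helper written for this illustration.
    """
    # Time the very first call on its own, so that any one-off cost
    # (for JAX code, e.g. tracing or compilation) is not averaged away.
    start = time.perf_counter()
    fn()
    first = time.perf_counter() - start

    # Then time `repeat` batches of `number` calls each and keep the
    # best per-call average, similar in spirit to what %timeit reports.
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    steady = min(totals) / number
    return first, steady


if __name__ == "__main__":
    # Stand-in workload; replace the lambda with the call being measured.
    first, steady = time_call(lambda: sum(range(100)))
    print(f"first call: {first * 1e6:.1f} us, steady state: {steady * 1e6:.2f} us")
```

With the objects defined above, the callable could be `lambda: sub(array_ish1, array_ish2).block_until_ready()`. If the steady-state figure stays in the hundreds of microseconds, the overhead is being paid on every call and not only on the first one.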

Again, thanks for this library, and thanks in advance for any feedback.

Vadim
