Using quax with jax.jvp #65

@nardi

This is more of a question, but I didn't see a discussions tab. Feel free to move/close if not appropriate :)

I'm trying to pass some quax values through a JVP computation. In principle it works fine, but I can't find a way to take the JVP with respect to only a subset of the arguments. Here is a small example.

Suppose we have a simple array wrapper class:

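from functools import partial

import jax
import jax.numpy as jnp
import quax

class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        # Report the abstract value (shape/dtype) of the wrapped array.
        return self.array.aval

    def materialise(self):
        # Deliberately unsupported: the wrapper should never be silently unwrapped.
        raise NotImplementedError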

And I've defined dot_general on it:

@quax.register(jax.lax.dot_general_p)
def dot_general(a: ArrayWrapper, b: ArrayWrapper, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a.array, b.array, **params))

@quax.register(jax.lax.dot_general_p)
def dot_general(a, b: ArrayWrapper, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a, b.array, **params))

@quax.register(jax.lax.dot_general_p)
def dot_general(a: ArrayWrapper, b, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a.array, b, **params))

Now, we can evaluate jnp.dot:

x = ArrayWrapper(1 + jnp.arange(3, dtype=float))
quax_dot = quax.quaxify(jnp.dot)

quax_dot(x, x)
# Output: ArrayWrapper(14.0)

Then, I would like to take the JVP of jnp.dot. I've found this is possible without any further changes by wrapping the jax.jvp call in quaxify:

quax.quaxify(lambda p, t: jax.jvp(jnp.dot, p, t))((x, x), (x, x))
# Output: (ArrayWrapper(14.0), ArrayWrapper(28.0))

Now, suppose I only want the JVP with respect to the second argument. Normally, this is possible by a simple partial application:

(lambda p, t: jax.jvp(partial(jnp.dot, x.array), p, t))((x.array,), (x.array,))
# Output: (14.0, 14.0)

However, I can't seem to find a way to accomplish this while preserving my ArrayWrapper types. The naive partial application:

quax.quaxify(lambda p, t: jax.jvp(partial(jnp.dot, x), p, t))((x,), (x,))

doesn't work, of course, since the partially applied x is closed over and so does not pass through quaxify. However, using a nested quaxify doesn't work correctly either:

quax.quaxify(lambda p, t: jax.jvp(partial(quax.quaxify(jnp.dot), x), p, t))((x,), (x,))
# Output: (ArrayWrapper(ArrayWrapper(14.0)), ArrayWrapper(ArrayWrapper(14.0)))

I've tried a number of different ways of nesting quaxify and using filter_spec, but I can't find a way to get this to work as expected. To be precise, the result I would expect is (ArrayWrapper(14.0), ArrayWrapper(14.0)). Of course, to obtain this I could simply set the tangents of the other arguments to zero:

quax.quaxify(lambda p, t: jax.jvp(jnp.dot, p, t))((x, x), (ArrayWrapper(jnp.zeros_like(x.array)), x))
# Output: (ArrayWrapper(14.0), ArrayWrapper(14.0))

But this seems like more of a workaround, and is not possible if some of the arguments are not differentiable.
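
For reference, here is a minimal plain-JAX sketch (no quax involved, so it says nothing about how ArrayWrapper behaves) of the equivalence this workaround relies on: closing over the first argument and giving it a zero tangent should produce the same JVP. The variable names are only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = 1 + jnp.arange(3, dtype=float)

# JVP with respect to the second argument only, by closing over the first.
primal_a, tangent_a = jax.jvp(partial(jnp.dot, x), (x,), (x,))

# Same computation with both arguments passed in, and a zero tangent
# for the first argument so that it contributes nothing to the result.
primal_b, tangent_b = jax.jvp(jnp.dot, (x, x), (jnp.zeros_like(x), x))
```

Both calls give a primal and tangent of 14.0 here, which is the pair I would like to get back wrapped in ArrayWrapper.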

Am I missing the proper way to do this, or am I asking for something that is impossible for a good reason?
