This is more of a question, but I didn't see a discussions tab. Feel free to move/close if not appropriate :)
I'm trying to pass some quax values through a JVP computation. In principle, it seems to work fine, but I'm not able to find a way to take the JVP of a subset of arguments only. Here is a small example.
Suppose we have a simple array wrapper class:
from functools import partial
import jax
import jax.numpy as jnp
import quax

class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        return self.array.aval

    def materialise(self):
        raise NotImplementedError
And I've defined dot_general on it:
@quax.register(jax.lax.dot_general_p)
def dot_general(a: ArrayWrapper, b: ArrayWrapper, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a.array, b.array, **params))

@quax.register(jax.lax.dot_general_p)
def dot_general(a, b: ArrayWrapper, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a, b.array, **params))

@quax.register(jax.lax.dot_general_p)
def dot_general(a: ArrayWrapper, b, **params):
    return ArrayWrapper(jax.lax.dot_general_p.bind(a.array, b, **params))
Now, we can evaluate jnp.dot:
x = ArrayWrapper(1 + jnp.arange(3, dtype=float))
quax_dot = quax.quaxify(jnp.dot)
quax_dot(x, x)
# Output: ArrayWrapper(14.0)
Then, I would like to take the JVP of jnp.dot. I've found this is possible without any further changes by wrapping the jax.jvp call in quaxify:
quax.quaxify(lambda p, t: jax.jvp(jnp.dot, p, t))((x, x), (x, x))
# Output: (ArrayWrapper(14.0), ArrayWrapper(28.0))
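The 28.0 follows from the product rule for the dot product: for primals (a, b) and tangents (da, db), the JVP is da.b + a.db. A quick check on plain arrays (no quax involved; the variable names are only for illustration):

```python
import jax
import jax.numpy as jnp

a = 1 + jnp.arange(3, dtype=float)

# JVP with respect to both arguments, tangents equal to the primals.
primal_out, tangent_out = jax.jvp(jnp.dot, (a, a), (a, a))

# Product rule, one term per argument.
term_first = jnp.dot(a, a)   # da . b
term_second = jnp.dot(a, a)  # a . db
```

Here tangent_out equals term_first + term_second. Keeping only term_second, which is what a JVP with respect to the second argument alone amounts to, leaves 14.0.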
Now, suppose I only want the JVP with respect to the second argument. Normally, this is possible by a simple partial application:
(lambda p, t: jax.jvp(partial(jnp.dot, x.array), p, t))((x.array,), (x.array,))
# Output: (14.0, 14.0)
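The same idea extends to any subset of positional arguments by closing over the ones held fixed. A minimal sketch on plain arrays; jvp_wrt is a hypothetical helper name, not part of JAX or quax, and it does not address the ArrayWrapper case:

```python
import jax
import jax.numpy as jnp


def jvp_wrt(fn, argnums, args, tangents):
    """JVP of fn with respect to args[i] for i in argnums only (hypothetical helper)."""
    argnums = tuple(argnums)

    def restricted(*diff_args):
        # Rebuild the full argument list, substituting the differentiated ones.
        full = list(args)
        for i, value in zip(argnums, diff_args):
            full[i] = value
        return fn(*full)

    primals = tuple(args[i] for i in argnums)
    return jax.jvp(restricted, primals, tuple(tangents))


a = 1 + jnp.arange(3, dtype=float)

# Differentiate with respect to the second argument only.
primal_second, tangent_second = jvp_wrt(jnp.dot, (1,), (a, a), (a,))

# Differentiate with respect to both arguments.
primal_both, tangent_both = jvp_wrt(jnp.dot, (0, 1), (a, a), (a, a))
```

The fixed arguments are captured by the closure rather than passed to jax.jvp, so they need no tangents at all.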
However, I can't seem to find a way to accomplish this while preserving my ArrayWrapper types. The naive partial application:
quax.quaxify(lambda p, t: jax.jvp(partial(jnp.dot, x), p, t))((x,), (x,))
doesn't work of course, since the partially applied x does not pass through quaxify. However, using a nested quaxify also doesn't work correctly:
quax.quaxify(lambda p, t: jax.jvp(partial(quax.quaxify(jnp.dot), x), p, t))((x,), (x,))
# Output: (ArrayWrapper(ArrayWrapper(14.0)), ArrayWrapper(ArrayWrapper(14.0)))
I've tried a number of different ways of nesting quaxify and using filter_spec, but I can't seem to find a way to get this to work as expected. To be precise, what I would expect as the result is (ArrayWrapper(14.0), ArrayWrapper(14.0)). Of course, to obtain this I could simply set the tangents of the other arguments to zero:
quax.quaxify(lambda p, t: jax.jvp(jnp.dot, p, t))((x, x), (ArrayWrapper(jnp.zeros_like(x.array)), x))
# Output: (ArrayWrapper(14.0), ArrayWrapper(14.0))
But this seems like more of a workaround, and is not possible if some of the arguments are not differentiable.
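For reference, plain JAX (without quax) does accept a zero tangent for an integer argument, provided the tangent has dtype float0. A minimal sketch, using a made-up function scaled_sq_norm; whether the same trick carries through quaxify with an ArrayWrapper is not something this sketch checks:

```python
import jax
import jax.numpy as jnp
import numpy as np


def scaled_sq_norm(n, v):
    # Hypothetical example function: n is an integer, v is a float vector.
    return n * jnp.sum(v ** 2)


n = jnp.array(3)                    # integer primal, not differentiable
v = 1 + jnp.arange(3, dtype=float)  # float primal

# Integer (and bool) primals take a tangent of dtype float0, built with NumPy.
n_tangent = np.zeros(n.shape, dtype=jax.dtypes.float0)

primal_out, tangent_out = jax.jvp(scaled_sq_norm, (n, v), (n_tangent, v))
# primal:  3 * (1 + 4 + 9)     = 42
# tangent: 3 * 2 * (1 + 4 + 9) = 84   (the tangent of v is v itself)
```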
Am I missing a proper way to do this, or am I asking for something that is not possible for a good reason?