Description
Hello,
I would like to overload some jax operations to work with cases where the derivative of a function f is undefined/diverging, but the derivative of a composite function of f, let's say g(f), is defined.

Here is an example. In this case f is -inf and so exp(f) is 0. If I compute the derivative of f with jax.grad, the result is infinite, as expected. The derivative of exp(f) is then nan, because it is computed from the chain rule as d exp(f) = exp(f) df, which gives 0 times inf. However, thanks to mathematical simplifications, the correct d exp(f) is finite, and I would like to obtain this as the result. We have partially solved the problem by defining special versions of log and exp based on equinox classes, but is there a general way to handle this issue for arbitrary jax operations within quad?
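The simplification referred to above can be written out for the functions used in the example below. This is only a worked derivation for that specific case, with psii(theta, x) = theta * x written as ψ(θ, x) = θx; it is not a general rule.

```latex
\begin{aligned}
f(\theta, x) &= \log\bigl(\theta_0 x - \theta_1 x\bigr), \qquad
e^{f(\theta, x)} = \theta_0 x - \theta_1 x, \\
\frac{\partial e^{f}}{\partial \theta_0} &= x, \qquad
\frac{\partial e^{f}}{\partial \theta_1} = -x
\quad \text{(finite for all } \theta\text{)}, \\
\text{chain rule:}\quad
\frac{\partial e^{f}}{\partial \theta_0}
&= e^{f}\,\frac{\partial f}{\partial \theta_0}
= (\theta_0 x - \theta_1 x)\cdot\frac{x}{\theta_0 x - \theta_1 x}
\;\longrightarrow\; 0 \cdot \infty \quad \text{at } \theta_0 = \theta_1 .
\end{aligned}
```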
Example:
import jax.numpy as jnp
import jax
import equinox as eqx

def psii(theta, x):
    return theta * x

thetas = jnp.array([1., 1.])

def f(theta, x):
    # With theta[0] == theta[1] the argument of log is 0, so f is -inf.
    return jnp.log(psii(theta[0], x) - psii(theta[1], x))

expf = lambda thetas, x: jnp.exp(f(thetas, x))

x = 0.5
print("f:", f(thetas, x))                     # -inf
print("expf:", expf(thetas, x))               # 0
print("df:", jax.grad(f)(thetas, x))          # infinite
print("d expf:", jax.grad(expf)(thetas, x))   # nan (0 times inf)

# Workaround: a lazy log that stores its argument, so that exp can undo it exactly.
class MagicLog(eqx.Module):
    val: jax.Array

    def __init__(self, x):
        self.val = x

    def __jax_array__(self, dtype=None):
        # Only evaluate the log when the object is used as an array.
        return jnp.log(self.val)

def magiclog(x):
    return MagicLog(x)

def magicexp(x):
    # exp(log(val)) simplifies to val, so the log is never differentiated.
    if isinstance(x, MagicLog):
        return x.val
    else:
        return jnp.exp(x)

def magiclogpsi(theta, x):
    return magiclog(psii(theta[0], x) - psii(theta[1], x))

magicexpf = lambda thetas, x: magicexp(magiclogpsi(thetas, x))

print("magic expf:", magicexpf(thetas, x))
print("magic d expf:", jax.grad(magicexpf)(thetas, x))
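
As a sanity check on the expected finite result, the gradient can be compared against the hand-simplified expression exp(log(a - b)) = a - b. The sketch below is only a reference for this particular example; the names expf_naive and expf_simplified are introduced here for illustration and are not part of the code above or of any library.

```python
import jax
import jax.numpy as jnp


def psii(theta, x):
    return theta * x


def expf_naive(thetas, x):
    # Same composition as expf above: exp(log(a - b)).
    return jnp.exp(jnp.log(psii(thetas[0], x) - psii(thetas[1], x)))


def expf_simplified(thetas, x):
    # Hand-simplified form (illustrative): exp(log(a - b)) = a - b.
    return psii(thetas[0], x) - psii(thetas[1], x)


thetas = jnp.array([1.0, 1.0])
x = 0.5

# Chain rule through log and exp hits 0 / 0 at thetas[0] == thetas[1].
grad_naive = jax.grad(expf_naive)(thetas, x)
# The simplified expression is linear in thetas, so its gradient is (x, -x).
grad_simplified = jax.grad(expf_simplified)(thetas, x)

print("naive d expf:", grad_naive)
print("simplified d expf:", grad_simplified)
```

The MagicLog workaround should agree with the simplified gradient, since magicexp returns the stored argument without ever evaluating the log.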