Overloading jax operations to avoid undefined derivatives #35

@alleSini99

Hello,
I would like to overload some JAX operations to handle cases where the derivative of a function f is undefined or diverges, but the derivative of a composite function of f, say g(f), is well defined.
Here is an example. In this case f is -inf, so exp(f) is 0. Computing the derivative of f with jax.grad gives inf, as expected, so the derivative of exp(f) comes out as nan: the chain rule evaluates d exp(f) = exp(f) df, which is 0 times inf. However, after simplifying analytically (here exp(log(u)) = u), the correct d exp(f) is finite, and that is the result I would like to obtain. We have partially solved the problem by defining special versions of log and exp based on equinox classes, but is there a general way to handle this issue for arbitrary JAX operations within quad?
Example:

import jax.numpy as jnp
import jax
import equinox as eqx

def psii(theta, x):
    return theta * x

thetas = jnp.array([1., 1.])

# f = log(theta[0] * x - theta[1] * x); with equal thetas the argument is 0
def f(theta, x):
    return jnp.log(psii(theta[0], x) - psii(theta[1], x))

expf = lambda thetas, x: jnp.exp(f(thetas, x))

x = 0.5
print("f:", f(thetas, x))                          # log(0)
print("expf:", expf(thetas, x))                    # exp(log(0))
print("df:", jax.grad(f)(thetas, x))               # diverges
print("d expf:", jax.grad(expf)(thetas, x))        # nan from the chain rule

# Lazy log: stores the argument and only evaluates log(val) when the object
# is converted to an array.
class MagicLog(eqx.Module):
    val : jax.Array

    def __init__(self, x):
        self.val = x

    def __jax_array__(self, dtype=None):
        return jnp.log(self.val)

def magiclog(x):
    return MagicLog(x)

def magicexp(x):
    # exp(log(val)) simplifies to val, so neither operation is evaluated
    if isinstance(x, MagicLog):
        return x.val
    else:
        return jnp.exp(x)

def magiclogpsi(theta, x):
    return magiclog(psii(theta[0], x) - psii(theta[1], x))

magicexpf = lambda thetas, x: magicexp(magiclogpsi(thetas, x))

print("magic expf:", magicexpf(thetas, x))
print("magic d expf:", jax.grad(magicexpf)(thetas, x))
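
For comparison, the same simplification can also be sketched with `jax.custom_jvp`, which attaches a hand-written derivative rule to a function. This is only a minimal sketch for the single composite exp(log(u)), not a general answer to the question above: the name `exp_log` is hypothetical, and the rule d exp(log(u)) = du is written by hand rather than derived automatically.

```python
import jax
import jax.numpy as jnp

# Hypothetical fused function: the value is computed the naive way,
# but the derivative rule below is supplied by hand.
@jax.custom_jvp
def exp_log(u):
    return jnp.exp(jnp.log(u))

@exp_log.defjvp
def _exp_log_jvp(primals, tangents):
    (u,) = primals
    (du,) = tangents
    # Simplified rule: d exp(log(u)) = du, so no 0 * inf product appears.
    return exp_log(u), du

def psii(theta, x):
    return theta * x

def naive_expf(thetas, x):
    # Plain composition, differentiated by the ordinary chain rule.
    return jnp.exp(jnp.log(psii(thetas[0], x) - psii(thetas[1], x)))

def fused_expf(thetas, x):
    # Same value, but differentiated with the hand-written rule.
    return exp_log(psii(thetas[0], x) - psii(thetas[1], x))

thetas = jnp.array([1.0, 1.0])
x = 0.5

naive = jax.grad(naive_expf)(thetas, x)
fused = jax.grad(fused_expf)(thetas, x)
print("naive d expf:", naive)
print("custom_jvp d expf:", fused)
```

Compared with the MagicLog wrapper, this keeps ordinary arrays flowing through the program, at the cost of having to define one fused function per composite whose derivative needs special handling.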

@PhilipVinc
