Loss varies with params but gradient is NULL? Bug? #8411
-
Hello everyone. I have finally reduced the worst problem I face with JAX (which I otherwise love) to a small example. The model below maps several 1D distributions (one per row) plus a 1D Green's function to the full responses, obtained by convolving each row with the function.

```python
from jax import numpy as jnp, lax, random, value_and_grad

# defining forward problem: convolution
def convolve(params):
    '''Convolves rows of distribution (2D) with a function (1D).'''
    return lax.conv(params['distribution'][None, None, :, :],
                    params['function'][None, None, None, ::-1],
                    window_strides=(1, 1), padding='SAME')[1, 1, :, :]

# defining ground truth
params_true = {
    'distribution': random.normal(random.PRNGKey(0), (20, 100)),
    'function': random.normal(random.PRNGKey(1), (10,)),
}
response_true = convolve(params_true)
```

Now I want to recover the ground truth by optimizing some initial parameters (distribution and function):

```python
# making loss and gradient function
def mse_loss_fun(params):
    return jnp.mean((convolve(params) - response_true)**2)

loss_and_grad = value_and_grad(mse_loss_fun)

params_init = {
    'distribution': random.normal(random.PRNGKey(2), (20, 100)),
    'function': random.normal(random.PRNGKey(3), (10,)),
}
loss, grad = loss_and_grad(params_init)
```

Everything good, right? But when I look at the loss gradients...

```python
print(loss)                                           # non-zero
print(jnp.all(jnp.isclose(grad['distribution'], 0)))  # True
print(jnp.all(jnp.isclose(grad['function'], 0)))      # True
```

Does anyone know why the gradients are null? When I change the parameters (e.g. by changing the initializer PRNGKeys) the loss clearly varies, so it does not make sense to me to get exact zeros here. Since this is such a simple model, could it be some kind of bug? I have hit this problem at least twice before, and each time I had to give up and go back to TensorFlow (which I do not currently like). Thank you for your attention!
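For reference, this is the kind of finite-difference check I would use to see whether the loss really responds to a small change in a single parameter (the step size `eps` is arbitrary and this snippet is only a sketch, not part of the script above):

```python
# Sketch: perturb a single entry of 'function' and compare the
# finite-difference slope against the autodiff gradient for that entry.
eps = 1e-3
params_pert = {
    'distribution': params_init['distribution'],
    'function': params_init['function'].at[0].add(eps),
}
fd_slope = (mse_loss_fun(params_pert) - mse_loss_fun(params_init)) / eps
print(fd_slope)             # finite-difference estimate for function[0]
print(grad['function'][0])  # autodiff gradient for the same entry
```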
Replies: 1 comment 1 reply
-
Thanks for the clear repro. It looks like it's some sort of issue with reverse-mode autodiff of `lax.conv`. Look at the difference between forward-mode and reverse-mode:

```python
from jax import jacfwd, jacrev

params = {'distribution': random.normal(random.PRNGKey(0), (1, 3)),
          'function': random.normal(random.PRNGKey(1), (2,))}

print(jacfwd(convolve)(params))
# {'distribution':
#  DeviceArray([[[[ 2.2125056 , -0.11617047,  0.        ]],
#                [[ 0.        ,  2.2125056 , -0.11617047]],
#                [[ 0.        ,  0.        ,  2.2125056 ]]]], dtype=float32),
#  'function':
#  DeviceArray([[[-0.4826233 ,  1.8160859 ],
#                [ 0.33988902, -0.4826233 ],
#                [ 0.        ,  0.33988902]]], dtype=float32)}

print(jacrev(convolve)(params))
# {'distribution':
#  DeviceArray([[[[0., 0., 0.]],
#                [[0., 0., 0.]],
#                [[0., 0., 0.]]]], dtype=float32),
#  'function':
#  DeviceArray([[[0., 0.],
#                [0., 0.],
#                [0., 0.]]], dtype=float32)}
```
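In the meantime, one possible workaround (just a sketch, I haven't benchmarked it on your exact setup) is to differentiate the scalar loss in forward mode instead; for a scalar output, `jacfwd` returns a pytree of derivatives with the same shapes as the parameters:

```python
# Sketch: forward-mode gradient of the scalar loss. Forward mode costs roughly
# one pass per input element, so it can be slow when params are large.
fwd_grad = jacfwd(mse_loss_fun)(params_init)
print(jnp.all(jnp.isclose(fwd_grad['distribution'], 0)))
print(jnp.all(jnp.isclose(fwd_grad['function'], 0)))
```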