Loss varies with params but gradient is NULL? Bug? #8411

Answered by jakevdp
marcosrdac asked this question in General

Thanks for the clear repro. This looks like an issue with reverse-mode autodiff of lax.conv. Compare the forward-mode and reverse-mode results:

params = {'distribution': random.normal(random.PRNGKey(0), (1, 3)),
          'function': random.normal(random.PRNGKey(1), (2,))}

print(jacfwd(convolve)(params))
# {'distribution':
#    DeviceArray([[[[ 2.2125056 , -0.11617047,  0.        ]],
#                  [[ 0.        ,  2.2125056 , -0.11617047]],
#                  [[ 0.        ,  0.        ,  2.2125056 ]]]], dtype=float32),
#  'function':
#    DeviceArray([[[-0.4826233 ,  1.8160859 ],
#                  [ 0.33988902, -0.4826233 ],
#                  [ 0.        ,  …
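For anyone who wants to run this kind of check themselves, here is a minimal sketch that compares the two modes leaf by leaf. The `convolve` below is a hypothetical stand-in, not the function from the question: it assumes a "same"-mode 1-D convolution via `jnp.convolve`, chosen only so the output shape is (1, 3) as in the printout above. The name `max_abs_diff` is also just illustrative.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, random


def convolve(params):
    # Hypothetical stand-in for the question's function (an assumption):
    # convolve the single row of 'distribution' with the 1-D kernel
    # 'function', keeping the input length, then restore the leading axis.
    row = params['distribution'][0]
    kernel = params['function']
    return jnp.convolve(row, kernel, mode='same')[None, :]


params = {'distribution': random.normal(random.PRNGKey(0), (1, 3)),
          'function': random.normal(random.PRNGKey(1), (2,))}

# Forward-mode and reverse-mode Jacobians; each is a dict shaped like params.
jac_fwd = jacfwd(convolve)(params)
jac_rev = jacrev(convolve)(params)

# Largest absolute disagreement between the two modes, per parameter.
max_abs_diff = jax.tree_util.tree_map(
    lambda a, b: float(jnp.max(jnp.abs(a - b))), jac_fwd, jac_rev)
print(max_abs_diff)
```

The two modes compute the same Jacobian mathematically, so any difference well above float32 rounding error points at one of the differentiation rules rather than at the user's function.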

Answer selected by marcosrdac