Gradient of primitive and confusion #19514
ericmjonas asked this question in General (unanswered)
Hello! I'm filing this as a "discussion" because I'm sure it's an error in my understanding and not actually a bug. I'm trying to fit a function's gradient with machine learning, where the function contains calls to a custom primitive p(x) that is backed by a pile of C++/CUDA.
In pseudocode I have:
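(A rough sketch; p, my_func, loss, and target_grad stand in for the real code.)

```python
import jax
import jax.numpy as jnp

# p(x) is the custom primitive backed by C++/CUDA.
def my_func(theta, x):
    return jnp.sum(p(theta * x))

# Fit my_func's gradient: the loss itself calls jax.grad(my_func).
def loss(theta, x, target_grad):
    g = jax.grad(my_func, argnums=1)(theta, x)   # d my_func / d x
    return jnp.sum((g - target_grad) ** 2)

# Training differentiates the loss, i.e. differentiates *through*
# the gradient (and hence through the VJP of p).
dloss_dtheta = jax.grad(loss)(theta, x, target_grad)
```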
Technically I have two primitives, a fwd and a bwd, that I then set up with defvjp. As far as I can tell, this should only ever require the VJP of my_func. However, I am getting the error:

Differentiation rule for 'my_bwd_p' not implemented

which is confusing me, as I really don't think we should need anything beyond first derivatives for my_func.
I have constructed a full example for my own primitive that just implements sin, below. I've tried everything I can think of, including liberal application of jax.lax.stop_gradient. The real code I care about for my_func's fwd and bwd primitives is incredibly complicated, and the idea of implementing a higher-order gradient (that is, the derivative rule for bwd) is sort of soul-crushing. But I don't think it should be necessary?

The full example is attached below, and it is also available as a Colab notebook here: https://colab.research.google.com/drive/1T0HcQlELiUJVw6OhptNox6Ed_LsaCYOm?usp=sharing
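A minimal version of that example, assuming the fwd/bwd primitive split described above (all names here, like my_sin_fwd_p and my_sin_bwd_p, are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import core

# Two primitives standing in for the C++/CUDA kernels: one for the
# forward computation and one for its VJP.
my_sin_fwd_p = core.Primitive("my_sin_fwd")
my_sin_fwd_p.def_impl(lambda x: jnp.sin(x))
my_sin_fwd_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

my_sin_bwd_p = core.Primitive("my_sin_bwd")
my_sin_bwd_p.def_impl(lambda x, g: jnp.cos(x) * g)
my_sin_bwd_p.def_abstract_eval(lambda x, g: core.ShapedArray(x.shape, x.dtype))

@jax.custom_vjp
def my_sin(x):
    return my_sin_fwd_p.bind(x)

def my_sin_fwd(x):
    return my_sin_fwd_p.bind(x), x        # save x as the residual

def my_sin_bwd(x, g):
    return (my_sin_bwd_p.bind(x, g),)     # cotangent w.r.t. x

my_sin.defvjp(my_sin_fwd, my_sin_bwd)

# First-order grad is fine: only the VJP rule is needed.
print(jax.grad(my_sin)(1.0))              # ~cos(1.0)

# But a loss that contains the gradient requires grad-of-grad:
def loss(x):
    return jax.grad(my_sin)(x) ** 2

jax.grad(loss)(1.0)
# NotImplementedError: Differentiation rule for 'my_sin_bwd' not implemented
```

The outer jax.grad ends up differentiating the my_sin_bwd_p call that the inner gradient inserted into the trace, which is where the missing rule is requested.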
Replies: 1 comment · 4 replies

- Can I ask why you're defining primitives at all? A big reason that …