Efficient calculation of gradient for binary tree traversal #26452
ChrisBoettner asked this question in Q&A
Hey everyone,
First of all, thanks for all the work! I've been learning JAX more closely lately, and it's a lot of fun. I am currently working on an implementation of the Julia AutoGP.jl package in Python using JAX.
A big part of this work is that I have to evaluate algebraic expressions that are defined via a binary tree (tree leaves are kernel operations, and nodes are additions or multiplications). These expressions evolve dynamically over time, so I had to find a way to represent them statically for JAX. In the end, I decided to encode the tree structure as an array that gets traversed in level-order using a stack. The final function looks like the following:
I am using an equinox bounded_while_loop, since the tree expression might be very long (corresponding to the maximum number of nodes of a binary tree of size d), but usually only the first ~10 entries are used, so looping over the entire array would be a waste.
Now comes my problem, though. This expression is reasonably fast to evaluate for two inputs, and can easily be vmap-ed (e.g. to calculate the cross_covariance and gram matrices). However, calculating the gradient slows things down a lot. For single inputs x and y, the kernel evaluation and the gradient calculation take about the same time. But when vmap-ed, the computations that involve the gradient are about 10x slower than direct evaluation of the kernel. I am struggling to figure out why. Do you maybe have any ideas?
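To make the setup concrete, here is a minimal sketch of this kind of array-encoded, stack-based evaluation. It is not the original function: the opcodes (`NOOP`, `RBF`, `LIN`, `ADD`, `MUL`), the postfix ordering (used here instead of level-order), the per-node `params` array, and the name `eval_tree` are all assumptions made for illustration. It also swaps the equinox bounded_while_loop for a plain `jax.lax.fori_loop` over the full padded array, so it does not exit early.

```python
import jax
import jax.numpy as jnp

# Hypothetical opcodes; NOOP pads the unused tail of the array.
NOOP, RBF, LIN, ADD, MUL = 0, 1, 2, 3, 4


def eval_tree(ops, params, x, y):
    """Evaluate a postfix-encoded kernel expression at scalar inputs x, y."""
    max_len = ops.shape[0]

    def push(stack, top, val):
        # Leaf: write the value at the top of the stack and grow it by one.
        return stack.at[top].set(val), top + 1

    def combine(stack, top, val):
        # Binary node: replace the two topmost entries by their combination.
        return stack.at[top - 2].set(val), top - 1

    def step(i, carry):
        stack, top = carry
        a, b = stack[top - 1], stack[top - 2]
        branches = [
            lambda s, t, p: (s, t),                                            # NOOP
            lambda s, t, p: push(s, t, jnp.exp(-0.5 * (x - y) ** 2 / p**2)),   # RBF
            lambda s, t, p: push(s, t, p * x * y),                             # LIN
            lambda s, t, p: combine(s, t, a + b),                              # ADD
            lambda s, t, p: combine(s, t, a * b),                              # MUL
        ]
        return jax.lax.switch(ops[i], branches, stack, top, params[i])

    init = (jnp.zeros(max_len), jnp.int32(0))
    # Static trip count, so reverse-mode differentiation is supported.
    stack, _ = jax.lax.fori_loop(0, max_len, step, init)
    return stack[0]


# RBF(lengthscale=1) + LIN(scale=2), padded with NOOPs to length 8.
ops = jnp.array([RBF, LIN, ADD, NOOP, NOOP, NOOP, NOOP, NOOP], dtype=jnp.int32)
params = jnp.array([1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

value = eval_tree(ops, params, 0.5, 1.5)
dvalue_dx = jax.grad(eval_tree, argnums=2)(ops, params, 0.5, 1.5)
```

Because the tree is passed as data (`ops`, `params`) rather than baked into the traced function, the expression can change without retracing, as long as the array length stays fixed.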
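The comparison described above can be reproduced with a small timing harness along these lines. This is only a sketch under assumptions: the RBF `kernel` is a stand-in for the tree evaluation, `theta` is a hypothetical lengthscale, and the gradient is taken of the summed Gram matrix with respect to `theta`, which may differ from the gradient actually used.

```python
import time

import jax
import jax.numpy as jnp


def kernel(theta, x, y):
    # Stand-in RBF kernel; theta is a hypothetical lengthscale parameter.
    return jnp.exp(-0.5 * (x - y) ** 2 / theta**2)


def gram(theta, xs, ys):
    # Nested vmap: outer map over xs, inner map over ys.
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(theta, x, y))(ys))(xs)


gram_jit = jax.jit(gram)
# Gradient of a scalar summary of the Gram matrix with respect to theta.
grad_jit = jax.jit(jax.grad(lambda theta, xs, ys: gram(theta, xs, ys).sum()))


def time_fn(fn, *args, repeats=10):
    """Average wall-clock time per call, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for asynchronous dispatch
    return (time.perf_counter() - start) / repeats


xs = jnp.linspace(0.0, 1.0, 64)
theta = jnp.float32(0.7)

t_eval = time_fn(gram_jit, theta, xs, xs)
t_grad = time_fn(grad_jit, theta, xs, xs)
print(f"eval: {t_eval:.2e} s, grad: {t_grad:.2e} s")
```

Timing after a warm-up call and calling `block_until_ready()` keeps compilation and asynchronous dispatch out of the measurement, so the two numbers compare run time only.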