Unbatch a tree
#8518
Replies: 2 comments
-
Hey @limbryan,
-
Hey @cgarciae, thanks for the quick reply and the suggestions.
Hope this helps clear things up.
-
Hi,
I am trying to unbatch a pytree into a list of pytrees, one per element of the batch, each with the same structure as the original (specifically, a Flax NN parameter structure). My attempt works, but it is very slow and does not scale with the batch size. You can find the function I currently use below:
I would appreciate it if you could point me in the right direction, or to an implementation that does this much more efficiently. I am probably just missing some JAX functionality that makes this easy. Thanks!
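For reference, here is a minimal sketch of one common way to unbatch a pytree with `jax.tree_util`. It assumes every leaf is an array sharing the same leading batch dimension; `unbatch_tree` and the example parameter dict are hypothetical names for illustration, not the original poster's function or a Flax API. The tree is flattened once, each leaf is split along axis 0, and the original structure is rebuilt once per batch index, so the tree-traversal overhead is paid once rather than once per example.

```python
import jax
import jax.numpy as jnp


def unbatch_tree(batched_tree):
    """Hypothetical helper: split a pytree whose leaves share a leading
    batch axis into a list of pytrees with the same structure.

    Assumes every leaf is an array with the same size along axis 0.
    """
    # Flatten once: a flat list of leaves plus a description of the structure.
    leaves, treedef = jax.tree_util.tree_flatten(batched_tree)
    # Iterating a JAX array yields its slices along axis 0.
    per_leaf_slices = [list(leaf) for leaf in leaves]
    # zip groups the i-th slice of every leaf; rebuild the tree from each group.
    return [
        jax.tree_util.tree_unflatten(treedef, slices)
        for slices in zip(*per_leaf_slices)
    ]


# Hypothetical batched parameters for a batch of 4 small "models".
batched_params = {
    "dense": {
        "kernel": jnp.ones((4, 3, 2)),
        "bias": jnp.arange(8.0).reshape(4, 2),
    }
}
per_example = unbatch_tree(batched_params)
print(len(per_example))                          # 4
print(per_example[0]["dense"]["kernel"].shape)   # (3, 2)
```

If the per-example trees are only needed to apply the same function to each one, `jax.vmap` over the batched tree usually avoids unbatching altogether and tends to be faster.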