vmap for a class instance or member functions #18560
Replies: 2 comments
-
Hi - thanks for the question! JAX transformations like … It's hard to say more from the stripped-down example that you give, but I'd try re-expressing your function in a way that avoids implicit mutation: for example, you could register …
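Here is a minimal sketch of one way to avoid implicit mutation. It assumes the class can be registered as a pytree with `jax.tree_util.register_pytree_node_class`. `MyClass`, `paras` and `update_paras` are hypothetical stand-ins for the code in the question. The method returns a new instance instead of modifying `self`, so `vmap` can batch over it like any other pure function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MyClass:
    """Hypothetical stand-in for the class in the question."""

    def __init__(self, paras):
        self.paras = paras  # array-valued attribute, treated as a pytree leaf

    def update_paras(self, x):
        # Pure update: build and return a new instance instead of mutating self.
        return MyClass(self.paras + x)

    def tree_flatten(self):
        # (children that JAX traces/batches, static auxiliary data)
        return (self.paras,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


obj = MyClass(jnp.zeros(3))
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 updates

# Share the same starting object (in_axes=None) and map over the updates (axis 0).
batched = jax.vmap(MyClass.update_paras, in_axes=(None, 0))(obj, xs)
print(batched.paras.shape)  # (4, 3)
```

Because the update is a pure function of its inputs, the original `obj` is left untouched, and the batched result is just another `MyClass` whose leaf has a leading batch axis.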
-
If you want to do "classes with JAX", I'd recommend giving Equinox a try. +1 to Jake's comment about avoiding mutation, as this is often a footgun when used with JAX transformations. (And something like Equinox will do its best to guide you along the happy path.)
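To show why mutation is a footgun, here is a minimal sketch of what happens when a method assigns to `self` under `vmap`. The `Stateful` class is hypothetical. Inside `vmap` the method receives a JAX tracer rather than concrete per-example values, so the attribute ends up holding that tracer.

```python
import jax
import jax.numpy as jnp


class Stateful:
    """Hypothetical class that mutates itself inside a method."""

    def __init__(self):
        self.last = None

    def update(self, x):
        self.last = x  # side effect: stores whatever object JAX passes in
        return 2.0 * x


s = Stateful()
out = jax.vmap(s.update)(jnp.arange(3.0))

print(out)  # [0. 2. 4.]
# The returned values are fine, but the attribute now holds a leaked tracer,
# not the per-example values one might expect.
print(type(s.last).__name__)
```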
-
Hi all, I am curious about the implementation of the `vmap` function. I want to use it for auto-batching. My case is like:

As can be seen, `my_class.update_paras` is a member function that modifies some attributes of the class. Is `vmap` a multi-threaded function? Would it cause synchronization problems? What is the safe way to use `vmap`? Thanks!
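On the threading question: `vmap` is not a multi-threading primitive. It is a program transformation. It traces the Python function once with batched placeholder values (tracers) and turns the operations into vectorized ones, so no concurrent threads and no synchronization are involved. The real hazard is Python side effects, which run at trace time rather than once per batch element. Here is a minimal sketch illustrating this, where the function `f` and the list `calls` are hypothetical.

```python
import jax
import jax.numpy as jnp

calls = []  # records each time the Python body of f actually executes


def f(x):
    calls.append(x.shape)  # Python side effect: runs during tracing only
    return jnp.sin(x) * 2.0


xs = jnp.ones((1000, 3))
ys = jax.vmap(f)(xs)

print(ys.shape)    # (1000, 3)
print(len(calls))  # 1: the body ran once, not 1000 times
print(calls[0])    # (3,): inside vmap, x looks like a single unbatched example
```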