How to specify in_axes for custom pytree node? #16127
IrishWhiskey asked this question in Q&A.
Let's say I have a function that takes a custom object as input, and I want to transform it with pmap. How should `in_axes` be specified? If I run … I get … Am I doing anything wrong? How can I fix it?
Answered by patrick-kidger on May 26, 2023 (1 comment, 11 replies).
Answer selected by IrishWhiskey.
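You want `in_axes=(C(0, 1),)`. `in_axes` should be a tuple whose length equals the number of input arguments. Your function takes a single argument, so the tuple has one element: a `C` whose fields give the axis to map over for the corresponding fields of the input.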
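A minimal sketch of this, assuming a hypothetical `C` with two array fields `a` and `b`, registered with `jax.tree_util.register_pytree_node_class` (the question's actual class was not preserved). Field `a` is mapped along axis 0 and field `b` along axis 1. The `jax.pmap` documentation defers to `jax.vmap` for the meaning of `in_axes`, so the same spec is passed to both here. The number of mapped slices is set to `jax.local_device_count()` because pmap needs one slice per device.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


# Hypothetical stand-in for the question's custom object: a class with two
# array fields, registered as a pytree node so JAX can flatten it.
@register_pytree_node_class
class C:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # Children are the two fields; there is no static auxiliary data.
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def f(c):
    # Computation on one mapped slice of each field.
    return c.a.sum() + c.b.sum()


n = jax.local_device_count()           # pmap needs one slice per device
a = jnp.arange(n * 2.0).reshape(n, 2)  # field `a`: mapped axis is 0
b = jnp.arange(4.0 * n).reshape(4, n)  # field `b`: mapped axis is 1

# One entry per positional argument: f takes a single C, so in_axes is a
# 1-tuple whose element is a C holding the mapped axis for each field.
in_axes = (C(0, 1),)

out_pmap = jax.pmap(f, in_axes=in_axes)(C(a, b))
out_vmap = jax.vmap(f, in_axes=in_axes)(C(a, b))  # same spec works for vmap
print(out_pmap.shape, out_vmap.shape)             # both have shape (n,)
```

Because `C(0, 1)` has the same pytree structure as the argument `C(a, b)`, JAX can match each axis entry to the corresponding field. The outer tuple is what lines the spec up with the function's positional arguments.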