
How to specify in_axes for a custom pytree node? #16127


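You want in_axes=(C(0, 1),), where C is your custom pytree class. in_axes is matched against the arguments as a pytree prefix, so a C instance whose children are 0 and 1 maps the argument's first child along axis 0 and its second child along axis 1.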

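  • in_axes should be a tuple with one entry per positional argument of the function (here a single argument, hence the one-element tuple).
  • The static (auxiliary) data of the C instance used in in_axes must be the same as that of the argument you pass, because the two tree structures are compared, aux data included.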
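A minimal sketch of the answer above. The question's own class is not shown, so this assumes a hypothetical C with two array children x and y, registered with jax.tree_util.register_pytree_node_class, plus a static name string kept as aux data. It maps x along axis 0 and y along axis 1, and checks that an in_axes C carrying different static data is rejected.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    """Hypothetical stand-in for the question's custom pytree node."""

    def __init__(self, x, y, name="c"):
        # No validation here: JAX may rebuild C with placeholder leaves
        # (tracers, sentinels, or the plain ints used in in_axes).
        self.x = x
        self.y = y
        self.name = name  # static metadata, stored as aux data

    def tree_flatten(self):
        children = (self.x, self.y)  # leaves that vmap can map over
        aux_data = self.name         # static part, compared structurally
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, name=aux_data)


def f(c):
    # Per call, c.x is one row of x and c.y is one column of y, both shape (4,).
    return c.x + c.y


c = C(jnp.ones((3, 4)), jnp.arange(12.0).reshape(4, 3))

# One in_axes entry per positional argument; that entry is itself a C whose
# leaves give the mapped axis of each child (x along 0, y along 1).
out = jax.vmap(f, in_axes=(C(0, 1),))(c)
print(out.shape)  # (3, 4)

# The in_axes C must carry the same static data as the argument.
mismatch_error = None
try:
    jax.vmap(f, in_axes=(C(0, 1, name="other"),))(c)
except ValueError as err:
    mismatch_error = err
print(mismatch_error is not None)
```

As with any in_axes leaf, a child can be left unmapped by putting None in its slot, e.g. C(0, None).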

Answer selected by IrishWhiskey
Category: Q&A