Different PyTree for Primal and Tangent space #28723
Replies: 2 comments
-
Great question! I don't have all the relevant context here, so perhaps @mattjj, @froystig, or @dougalm will chime in with more details. The short answer is that the general thing you're asking about isn't currently supported in JAX, although it has always been something we hope to support one day. The more specific side point about using different dtypes for the primal and tangent is approximately supported, as discussed in this issue comment, but it's a pretty experimental feature. Hopefully we can eventually support the more general use case!
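One place where primal and tangent dtypes already differ in JAX is integer inputs, whose tangents and gradients use the trivial `float0` dtype. The sketch below shows only that behavior; it is an assumption that this is related to the mechanism in the linked comment, which may well describe something different. The function `f` and its arguments are hypothetical names.

```python
import jax
import jax.numpy as jnp

def f(x, n):
    # x is a float primal, n is an integer primal
    return x * n

x = jnp.asarray(2.0)
n = jnp.asarray(3, dtype=jnp.int32)

# allow_int=True permits differentiating with respect to the integer argument;
# the resulting gradient carries the trivial float0 dtype instead of int32.
g_n = jax.grad(f, argnums=1, allow_int=True)(x, n)
print(g_n.dtype)  # dtype of the gradient with respect to the integer input
```

Here the primal `n` is `int32` while its gradient has dtype `jax.dtypes.float0`, so the two spaces differ in dtype but still share the same shape and PyTree structure.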
-
Thanks a lot for answering. I was not aware of the dtype feature, thank you for pointing that out! IIUC the API is a bit too rough for our current use case, but it's definitely something useful for another project.
-
Hi,
Say I have a PyTree that represents some data structure.
In some cases, it might be beneficial to have a different data structure in tangent space. For example, the Primal might be a more complex buffer structure and the Tangent a lossy emulation, because we want to save memory by storing a reduced representation of the gradient.
As a requirement, that would mean that the only way to interact with such a data structure would be through functions with custom_jvp rules.
As a very simplified example:
A potential way to construct such an object could be
This fails with
Thus, it seems this is not supported in current JAX. The add() function doesn't work, as `t` is a Primal and not a Tangent.
Is there some way to achieve this in current JAX, or is this maybe planned for the future?
Thanks,
Lennart
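The constraint behind this question can be sketched as follows. This is a minimal illustration, not the original example from the post: the function `total` and the dictionary keys `data` and `scale` are hypothetical names. It shows that a `jax.custom_jvp` rule receives tangents with the same PyTree structure as the primals, and that `jax.jvp` rejects a tangent whose structure differs.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def total(buf):
    # buf is a PyTree (here a dict) acting as the primal data structure
    return jnp.sum(buf["data"]) * buf["scale"]

@total.defjvp
def total_jvp(primals, tangents):
    (buf,), (t_buf,) = primals, tangents
    out = total(buf)
    # t_buf mirrors the structure of buf: one tangent leaf per primal leaf
    t_out = (jnp.sum(t_buf["data"]) * buf["scale"]
             + jnp.sum(buf["data"]) * t_buf["scale"])
    return out, t_out

buf = {"data": jnp.arange(3.0), "scale": jnp.asarray(2.0)}
t_buf = {"data": jnp.ones(3), "scale": jnp.asarray(0.0)}

# Matching structures: this works.
out, t_out = jax.jvp(total, (buf,), (t_buf,))

# A "reduced" tangent with a different structure is rejected.
reduced_tangent = {"data": jnp.ones(3)}
try:
    jax.jvp(total, (buf,), (reduced_tangent,))
    mismatch_rejected = False
except (TypeError, ValueError) as err:
    mismatch_rejected = True
    print("rejected:", type(err).__name__)
```

Under these assumptions, a memory-saving tangent representation has to be squeezed into the primal's tree structure (for example by padding or reusing leaves), which is the limitation the question is about.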