JAX + PETSc in-place edits #29655
Cattaneo123 asked this question in Q&A
This may be a bit confusing, so please let me know if I can clarify anything. I am trying to get petsc4py and JAX to work together. So far this has worked rather well, but there is one small hitch: I can only get zero-copy sharing to work in one direction, not the other. If I create a PETSc vector backed by a JAX array and modify it with PETSc's `setValue`, the JAX array's memory is updated as expected. However, if I try the opposite, using JAX's `.at[...].set()`, JAX makes a copy. Is there a way to prevent this? Here is a small example.
I have also tried assigning the output of set_in_place to a different variable y, in case the in_place_update needed to be referenced under a different name, but that doesn't appear to help.
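The minimal sketch below is not the original example; it assumes a recent JAX release, and the array size and values are arbitrary. It illustrates why binding the result to a new name does not help in eager (non-`jit`) code: JAX arrays are immutable, so `.at[...].set()` always returns a new array and leaves `x` untouched, and the two live arrays occupy different buffers.

```python
import jax.numpy as jnp

# JAX arrays are immutable: .at[...].set(...) is a functional update
x = jnp.zeros(4, dtype=jnp.float32)
y = x.at[0].set(1.0)  # returns a new array; x itself is not modified

print(x)  # [0. 0. 0. 0.]
print(y)  # [1. 0. 0. 0.]

# Both arrays are alive, so they must live in distinct device buffers
print(x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer())  # True
```

Any external view of `x`'s memory (such as a PETSc `Vec` created from it) therefore still sees the old values after this update.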
Are there any workarounds? I suspect the XLA compiler sees that something is still holding a reference to the array x and therefore does not perform the update in place, but I'm not certain, and even if I were, I wouldn't know how to solve the problem.
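The JAX documentation for `jax.numpy.ndarray.at` notes that inside jit-compiled code, if the input of `x.at[idx].set(y)` is not reused, the compiler will optimize the update to happen in place. For a function argument, though, the buffer belongs to the caller, so one possible workaround is buffer donation via `donate_argnums`, sketched below. The helper name `set_donated` is hypothetical. Whether the buffer is actually reused depends on the backend, and I am not sure donation is honored for an array whose memory is also shared with a PETSc `Vec` (for example through DLPack), so treat this as something to test rather than a confirmed fix.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: donate_argnums=0 tells JAX/XLA that the caller gives up
# ownership of x, so the output may reuse x's buffer instead of allocating a new one.
@partial(jax.jit, donate_argnums=0)
def set_donated(x, idx, val):
    return x.at[idx].set(val)


x = jnp.zeros(4, dtype=jnp.float32)
ptr_before = x.unsafe_buffer_pointer()

# Rebind the name: if donation succeeds, the old x is deleted and must not be used again.
x = set_donated(x, 0, 1.0)

print(x)
# True only if the backend honored the donation and reused the original buffer
print("same buffer:", x.unsafe_buffer_pointer() == ptr_before)
```

Comparing `unsafe_buffer_pointer()` before and after is a quick way to check whether reuse happened. On backends without donation support, JAX emits a warning that some donated buffers were not usable and falls back to a copy.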