Could you give us an example that illustrates what you need? I am not sure what you are really trying to do, but Equinox handles pytree operations very well.
-
You can use pytrees directly within JIT-compiled functions; you do not need to flatten and unflatten them manually. Flattening and unflattening are internal mechanisms that JAX applies at the boundaries of transformations. Beyond that, it's hard to answer your question because the details are quite vague. I'd suggest putting together a short example that illustrates the kind of operation you have in mind.
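For illustration, here is a minimal sketch that assumes a plain dict of arrays as the pytree. The names `params`, `grads`, `sgd_step` and the 0.1 learning rate are made up for this example. The jitted function receives the whole pytree, updates it leaf by leaf with `jax.tree_util.tree_map`, and returns a new pytree with the same structure. No manual flattening is involved.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, for illustration only

# Dicts, lists and tuples of arrays are pytrees out of the box.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((3, 2), 0.5), "b": jnp.ones(2)}

@jax.jit
def sgd_step(params, grads):
    # tree_map walks both pytrees in lockstep and applies the update per leaf;
    # JAX flattens/unflattens the arguments and result itself at the jit boundary.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

new_params = sgd_step(params, grads)  # a new dict; `params` is left unchanged
```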
-
Hi, thank you for your fast responses, and sorry for my late one. I've tried to condense my case as much as possible while keeping the core parts where I think I might be going wrong. This is a heavily simplified version of my pytree code:
This is an approximation of the transform I want to reach:
And this is an approximation of the transform I have now:
-
I've taken a liking to JAX as an opportunity to bridge plain Python and ML while keeping performance.
My specific use case is a regular Python class registered as a pytree with a decorator, with utility functions that make it easy to format and input data.
I've written a simple training loop around this class, and I'd like to fit it using pytrees. However, it is unclear to me what the workflow should be.
Ideally, I'd like to pass a pytree in and get back a modified copy of that same pytree, while still being able to call the pytree's own methods inside the function (all within JAX's purity constraints).
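For reference, here is a minimal sketch of that pattern. It uses a hypothetical `LinearModel` class registered with `jax.tree_util.register_pytree_node_class`; the poster's actual class isn't shown. The jitted function takes the object, calls its methods, and returns a new instance without any manual flattening. `jax.grad` returns gradients with the same pytree structure, so they are another `LinearModel`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class LinearModel:
    """Hypothetical example class: the arrays w and b are the pytree leaves."""

    def __init__(self, w, b):
        # Keep __init__ free of validation: JAX may rebuild the object with
        # non-array placeholders during internal unflattening.
        self.w = w
        self.b = b

    def predict(self, x):
        return x @ self.w + self.b

    def loss(self, x, y):
        return jnp.mean((self.predict(x) - y) ** 2)

    def tree_flatten(self):
        children = (self.w, self.b)  # traced array leaves
        aux_data = None              # no static metadata in this example
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def train_step(model, x, y):
    # Methods can be called on the model inside jit.
    grads = jax.grad(lambda m: m.loss(x, y))(model)
    # Build and return a new model instead of mutating the old one.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)

x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))
model = LinearModel(jnp.ones((3, 2)), jnp.zeros(2))
new_model = train_step(model, x, y)
```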
After some trial and error, I found that I can flatten the pytree inside the function and unflatten it with the updated arrays at the end. Although this works, the first call takes about 20 s per sample (including compilation), and subsequent calls take about 0.1 s.
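A slow first call followed by fast later calls is consistent with how `jax.jit` works. The first call traces and compiles the function. Later calls reuse the cached executable as long as the input shapes, dtypes and pytree structure stay the same. The sketch below makes this visible with a hypothetical `n_traces` counter. It works because a Python side effect inside a jitted function only runs while JAX is tracing.

```python
import jax
import jax.numpy as jnp

n_traces = 0  # hypothetical counter, incremented only during tracing

@jax.jit
def step(params, x):
    global n_traces
    n_traces += 1  # runs at trace time, not on cached calls
    return jax.tree_util.tree_map(lambda p: p * jnp.mean(x), params)

params = {"w": jnp.ones(3)}
x = jnp.ones(2)
for _ in range(5):
    # Same structure, shapes and dtypes every iteration: traced/compiled once.
    params = step(params, x)
traces_after_loop = n_traces

_ = step(params, jnp.ones(4))  # a new input shape forces a retrace and recompile
```

If the first call is still very slow, it may be worth checking what the traced function contains. For example, Python loops inside a jitted function are unrolled during tracing.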
It seems to me that I'm on shaky ground, so I'd like to ask what the intended workflow actually is.
I know we can use datasets, but I wanted to check whether it is possible to use pytrees inside jitted functions, so that the pytree class's methods can change how training behaves (even if it has to recompile whenever the pytree class changes).
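On the recompilation point, one way to get this behavior is to keep configuration that changes the training logic as static `aux_data` in `tree_flatten`. The sketch below shows this with a hypothetical `Scaler` class. Changing the array values reuses the compiled function. Changing `aux_data` produces a different tree structure, which triggers a retrace. The aux data must be hashable and support equality; a string is used here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

scaler_traces = 0  # hypothetical counter, incremented only during tracing

@register_pytree_node_class
class Scaler:
    """Hypothetical class: `factor` is an array leaf, `mode` is static config."""

    def __init__(self, factor, mode):
        self.factor = factor
        self.mode = mode  # plain Python string, stored as aux_data

    def apply(self, x):
        # Ordinary Python branching is fine here because `mode` is static.
        if self.mode == "scale":
            return x * self.factor
        return x + self.factor

    def tree_flatten(self):
        return (self.factor,), self.mode

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

@jax.jit
def run(scaler, x):
    global scaler_traces
    scaler_traces += 1
    return scaler.apply(x)

x = jnp.arange(3.0)
out_a = run(Scaler(jnp.array(2.0), "scale"), x)  # first call: trace + compile
out_b = run(Scaler(jnp.array(5.0), "scale"), x)  # new value, same structure: cached
out_c = run(Scaler(jnp.array(5.0), "shift"), x)  # different aux_data: retrace
```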