Object that is a valid JAX type #22029
-
I'm trying to create a class to track my model's losses, but I can't use a plain Python object inside flax.train_state.TrainState. After going through the JAX documentation, I added the methods below to my class.
After doing this, I found that when my data members (tp, fp, etc.) are jnp arrays, updating them takes much longer than when I keep them as Python scalars, even when the data member holds an object of the same size. All updates to the tracker's data members, such as fp, are made from a function wrapped in jax.pmap, which compiles it automatically (as jit does). Is there a way to fix this slowdown, or is there an alternative way to build loss and metric trackers instead of the approach above?
-
Hi, indeed the pytree registry is the best way to create a custom type that is compatible with JAX transformations. Regarding your question, I don't entirely understand what you mean by "the time taken to update these data members increases a lot". In what context do these updates happen? Can you show an example of the issue?
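As an illustration of the pytree-registry approach, here is a minimal sketch of a metric tracker registered with jax.tree_util.register_pytree_node_class. It is not the original poster's code (which isn't shown); the class name ConfusionTracker, its zeros/update helpers, and the fn field are hypothetical, with tp/fp borrowed from the question. The update method returns a new tracker rather than mutating self, since JAX arrays are immutable and attribute mutation inside jit/pmap is not traced.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ConfusionTracker:
    """Hypothetical tracker of true positives, false positives and false negatives."""

    def __init__(self, tp, fp, fn):
        self.tp = tp
        self.fp = fp
        self.fn = fn

    def tree_flatten(self):
        # The array fields are the pytree children, so JAX can trace them.
        # No static (aux) data is needed here.
        return (self.tp, self.fp, self.fn), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @classmethod
    def zeros(cls):
        z = jnp.zeros((), jnp.int32)
        return cls(z, z, z)

    def update(self, preds, labels):
        # Functional update: build and return a new tracker instead of mutating self.
        preds = preds.astype(bool)
        labels = labels.astype(bool)
        return ConfusionTracker(
            self.tp + jnp.sum(preds & labels),
            self.fp + jnp.sum(preds & ~labels),
            self.fn + jnp.sum(~preds & labels),
        )


@jax.jit
def step(tracker, preds, labels):
    # The tracker is an ordinary pytree argument and return value of the jitted function.
    return tracker.update(preds, labels)


tracker = ConfusionTracker.zeros()
preds = jnp.array([1, 1, 0, 0])
labels = jnp.array([1, 0, 1, 0])
tracker = step(tracker, preds, labels)
print(int(tracker.tp), int(tracker.fp), int(tracker.fn))  # 1 1 1
```

Because the tracker is a registered pytree, the same object can also be stored as a field of a Flax TrainState or passed through jax.pmap without extra handling.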
Initializing those data members as np.array instead of jnp.array solved my issue, though I still don't understand why.
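One plausible explanation, which is an assumption and was not confirmed in this thread: when a jnp array is updated outside a compiled function, each operation is dispatched to the device separately and produces a new device array, so many small updates add up to noticeable overhead, whereas NumPy arrays and Python scalars are updated directly on the host. If that is the cause, an alternative is to accumulate the counts inside the pmapped step itself and only read them back to the host when logging. The sketch below assumes counts are kept as a hypothetical [tp, fp, fn] vector per device and summed across devices with jax.lax.psum; the axis name "devices" is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_dev = jax.local_device_count()


def step(counts, preds, labels):
    # Per-device batch counts in the (assumed) order [tp, fp, fn].
    preds = preds.astype(bool)
    labels = labels.astype(bool)
    batch = jnp.stack([
        jnp.sum(preds & labels),
        jnp.sum(preds & ~labels),
        jnp.sum(~preds & labels),
    ])
    # Sum across devices so every replica holds the global running totals.
    return counts + jax.lax.psum(batch, axis_name="devices")


p_step = jax.pmap(step, axis_name="devices")

# Leading axis = one slice per local device, as pmap expects.
counts = jnp.zeros((n_dev, 3), jnp.int32)
preds = np.tile(np.array([[1, 1, 0, 0]]), (n_dev, 1))
labels = np.tile(np.array([[1, 0, 1, 0]]), (n_dev, 1))

for _ in range(10):
    # The whole update runs inside one compiled call per step.
    counts = p_step(counts, preds, labels)

# Transfer to host once, e.g. when logging.
tp, fp, fn = np.asarray(counts[0])
print(tp, fp, fn)
```

Keeping the accumulation on-device and converting to NumPy only at logging time avoids paying a host/device round trip on every update.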