Improving JIT and tree_map compatibility: separating state and structure #513
theo-brown started this conversation in Ideas
Replies: 1 comment 1 reply
-
Hi @theo-brown. I 100% agree with your proposal here. To be frank, the reason we didn't do this is that NNX was still evolving as we migrated to it, and likely this detail passed us by. If you'd be willing to lead the API change, then I'd most happily support and integrate when ready.
-
Hi! This is a suggestion crossed with a question.
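In Flax NNX, the 'correct' way of doing things is to separate the state from the structure, e.g. by using `nnx.split` to break a module into an `nnx.GraphDef` (the structure) and an `nnx.State` (the parameters and other variables).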
This enables better JIT compatibility, as functions can be JITted once based on the structure of the model (e.g. a GP with a constant mean and a Matern52 kernel), and then reused even when the parameters change (e.g. a lengthscale updated by fitting the marginal likelihood to some new data). It might also fix the problem I encountered in #516.
So the question/suggestion is: why is this not done in GPJax? It would be fairly straightforward to modify some of the APIs (especially the fit functions) so that the `nnx.GraphDef` and `nnx.State` are handled as separate arguments. In fact, I've ended up doing this in my project that uses GPJax, so I'd be happy to port some of the changes over.

This may not make a difference to the runtime of GP inference itself, since presumably the matrix inversions etc. are already JITted. However, it would make it much easier to integrate GPJax GPs into more complex routines where GPs must be passed between JITted functions (e.g. in more complex BayesOpt workflows).