I'm not sure what you mean by "detect" here, but one way to see which parameters in your computation are unused is to print the jaxpr and look for input parameters that are never referenced in its body:

```python
print(jax.make_jaxpr(f_tr.apply)(params, None, x))
```

These are among the parameters that the compiler will eliminate as dead code before execution. Does that answer your question?
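To check this programmatically rather than by reading the printed jaxpr, something along the lines of the rough sketch below may work. `find_unused_leaves` is a hypothetical helper, not part of JAX. It assumes a JAX version that provides `jax.tree_util.tree_flatten_with_path`, and it relies on `make_jaxpr` binding one input variable per flattened leaf of the arguments. It only inspects the top-level jaxpr, so it is conservative: a leaf forwarded into a nested call (an inner `jax.jit`, `lax.scan`, etc.) counts as "used" even if the nested body ignores it. For the same reason, pass the un-jitted function.

```python
import jax
import jax.numpy as jnp


def find_unused_leaves(fn, *args):
    """Hypothetical helper: return key paths of leaves in `args` that `fn` never reads.

    Conservative sketch: only the top-level jaxpr is inspected, so a leaf passed
    into a nested call primitive is treated as used.
    """
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr

    # Identities of every variable that some equation reads or the jaxpr returns.
    used = {id(v) for eqn in jaxpr.eqns for v in eqn.invars}
    used |= {id(v) for v in jaxpr.outvars}

    # Assumption: make_jaxpr creates one input variable per flattened leaf of
    # `args`, in tree_flatten order, so the two sequences line up.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(args)
    assert len(leaves_with_paths) == len(jaxpr.invars)

    return [
        jax.tree_util.keystr(path)
        for (path, _), var in zip(leaves_with_paths, jaxpr.invars)
        if id(var) not in used
    ]


if __name__ == "__main__":
    # Toy parameter tree with one leaf the function never touches.
    params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2), "unused": jnp.ones(5)}

    def f(params, x):
        return x @ params["w"] + params["b"]

    print(find_unused_leaves(f, params, jnp.ones((4, 3))))
```

For the Haiku case this would be called as `find_unused_leaves(f_tr.apply, params, None, x)`. The returned strings are key paths into the argument tuple, where the leading `[0]` refers to the first positional argument (`params`).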
When I pass a large PyTree of neural network parameters to a jitted function, it may contain leaves that are not used anywhere inside that function. Sometimes this indicates a bug in the code. Is there any way to detect this?
For example, Haiku creates parameters for all `get_parameter` calls, so `params` will contain weights that end up unused (after jit optimization of `f_tr.apply`).
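A minimal plain-JAX sketch of the situation, with a hypothetical `stale` leaf standing in for an unused Haiku weight: the jitted call runs without any error or warning, which is why such leaves can go unnoticed.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree: "w" and "b" are used, "stale" never is.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2), "stale": jnp.ones(5)}


@jax.jit
def apply(params, x):
    # Only params["w"] and params["b"] are read; params["stale"] is ignored.
    return x @ params["w"] + params["b"]


out = apply(params, jnp.ones((4, 3)))  # no warning about the unused leaf
print(out.shape)
```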