How do I do tree_map over a list of NN params? #16992
niladridas asked this question in General
-
Depending on the network, the following command will give you a better idea of where the function is being applied on the pytree:
jax.tree_util.tree_map(lambda x: f'function applied to x of shape {x.shape}', example_pytrees)
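A runnable version of that command, as a minimal sketch: the real Haiku params are not available here, so it assumes a hypothetical nested dict with made-up module names (linear, linear_1) and shapes, mimicking the {module: {parameter: array}} layout Haiku uses. It shows that tree_map calls the function once per array, not once per element of the list.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Haiku params; names and shapes are assumptions.
params = {
    "linear": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))},
    "linear_1": {"w": jnp.zeros((4, 2)), "b": jnp.zeros((2,))},
}
example_pytrees = [params, params]

# tree_map descends through the list AND the nested dicts,
# so the lambda runs once for every individual array.
report = jax.tree_util.tree_map(
    lambda x: f"function applied to x of shape {x.shape}", example_pytrees
)
# A list of two dicts with the same layout as params, holding strings instead of arrays.
print(report)
```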
-
I have a NN made in Haiku as
net = hk.transform(net_fn)
I have a training batch, and then I do
params = net.init(jax.random.PRNGKey(42), next(iter(train_batches)))
to get the params. I make a pytree of params as:
example_pytrees = [params,params]
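To see why this list does not behave like a sequence of two parameter sets under tree_map, one can count the leaves JAX finds in it. This sketch uses hypothetical params (assumed names and shapes) in Haiku's nested-dict layout; passing is_leaf makes flattening stop at each params mapping instead of descending into it.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp

# Hypothetical params in a nested {module: {name: array}} layout (assumed names/shapes).
params = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
example_pytrees = [params, params]

# By default the list and the dicts are all containers, so flattening
# reaches the individual arrays.
array_leaves = jax.tree_util.tree_leaves(example_pytrees)
print(len(array_leaves))  # 4: (w, b) from each of the two params dicts

# Treating any mapping as a leaf stops at the params level instead.
param_leaves = jax.tree_util.tree_leaves(
    example_pytrees, is_leaf=lambda x: isinstance(x, Mapping)
)
print(len(param_leaves))  # 2: one whole params dict per list element
```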
I am trying to iterate over this pytree to produce outputs from the NN, whose apply function is now
apply_net = lambda x: net.apply(x, jax.random.PRNGKey(0), next(iter(train_batches)))
My question is:
apply_net(params)
works, but
jax.tree_util.tree_map(apply_net, example_pytrees)
fails with the error: "params argument does not appear valid. It should be a mapping but is of type <class 'jaxlib.xla_extension.ArrayImpl'>."
What should I do?
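The error arises because tree_map treats the list and the nested params dicts as containers and passes each individual array to apply_net, whereas net.apply expects the whole params mapping. Below is a minimal sketch of three possible workarounds. To stay self-contained it replaces net.apply with a hypothetical toy apply_fn that takes the same (params, rng, batch) arguments, and it uses made-up shapes, so treat it as an illustration rather than Haiku-specific code. The vmap option assumes every parameter set has identical structure and shapes.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp

# Hypothetical stand-in for net.apply(params, rng, batch): like Haiku's apply,
# it needs the full nested params mapping as its first argument.
def apply_fn(params, rng, batch):
    del rng  # unused by this toy model
    return batch @ params["linear"]["w"] + params["linear"]["b"]

batch = jnp.ones((5, 3))  # assumed batch: 5 examples, 3 features
rng = jax.random.PRNGKey(0)
apply_net = lambda p: apply_fn(p, rng, batch)

params_a = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
params_b = {"linear": {"w": 2 * jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
example_pytrees = [params_a, params_b]

# Option 1: example_pytrees is an ordinary Python list, so loop over it.
outs_loop = [apply_net(p) for p in example_pytrees]

# Option 2: keep tree_map, but stop it from descending into each params mapping.
outs_map = jax.tree_util.tree_map(
    apply_net, example_pytrees, is_leaf=lambda x: isinstance(x, Mapping)
)

# Option 3: stack the parameter sets leaf by leaf, then vectorise with vmap.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *example_pytrees)
outs_vmap = jax.vmap(apply_net)(stacked)  # one output per parameter set, stacked on axis 0
```

The list comprehension is the simplest fix; the vmap version evaluates all parameter sets in a single batched call, at the cost of stacking them first.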