Replies: 2 comments 4 replies
-
Hi - I'm not precisely sure what might be causing the performance degradation, but one thing stands out: your code makes use of
-
Why is this line: `params_NN = opt_state`?
-
Hi,
I am trying to implement policy-gradient reinforcement learning in JAX by transforming the script provided in https://github.com/tsmatz/reinforcement-learning-tutorials/blob/master/02-policy-gradient.ipynb as follows:
Compared to the original PyTorch code, the performance is much worse. I suspected that it may be caused by the initialization, so I tried changing these lines using the available JAX initializer options:
```python
Winit = initializers.he_uniform(dtype=jnp.float64)
binit = initializers.normal(stddev=0.1, dtype=jnp.float64)
```
However, I ran into another problem: when I use

```python
binit = initializers.he_uniform(stddev=0.1, dtype=jnp.float64)
```

or anything other than `normal`, it gives me an error. I would be very much obliged for your help!
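For reference, a minimal sketch of the distinction (assuming `jax.nn.initializers`; shapes are illustrative): `normal` is parameterized directly by a standard deviation, while `he_uniform` is a variance-scaling initializer whose scale is derived from the fan-in of the weight shape, so it takes no `stddev` argument and rejects one with a `TypeError`:

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)

# normal() accepts stddev and works for any shape, including 1-D biases.
binit = initializers.normal(stddev=0.1)
b = binit(key, (8,), jnp.float32)

# he_uniform() derives its scale from fan-in, so it has no stddev
# parameter; initializers.he_uniform(stddev=0.1) raises a TypeError.
# It also needs an at-least-2-D shape to compute the fan-in.
Winit = initializers.he_uniform()
W = Winit(key, (4, 8), jnp.float32)
```

This is likely why swapping `normal` for `he_uniform` while keeping the `stddev=0.1` keyword fails: the error comes from the unexpected keyword, not from the initializer itself. (Using `dtype=jnp.float64` additionally requires enabling x64 mode in JAX, e.g. `jax.config.update("jax_enable_x64", True)`.)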