Open
Description
Hello, thanks for this awesome repo! We ran into a slight issue when using distrax: it produces `nan` log-probabilities,
as reported in vwxyzjn/cleanrl#300. See the following reproduction script:
from typing import Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
# import pybullet_envs # noqa
import tensorflow_probability
from flax.training.train_state import TrainState
tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions
jax.config.update("jax_platform_name", "cpu")
import distrax
class Actor(nn.Module):
    """Two-hidden-layer MLP that outputs the parameters of a diagonal Gaussian.

    The input is passed through two ``Dense(n_units)`` + ReLU layers, then
    through two separate ``Dense(action_dim)`` heads producing the mean and
    the log standard deviation. The log standard deviation is clipped to
    ``[log_std_min, log_std_max]`` before being returned.
    """

    # Width of each output head. It is passed straight to ``nn.Dense`` as the
    # feature count, so it is a single integer (the script uses 6).
    action_dim: int
    # Width of both hidden layers.
    n_units: int = 256
    # Clipping bounds applied to the raw log-std head output.
    log_std_min: float = -20
    log_std_max: float = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Return ``(mean, log_std)`` for the input ``x``.

        Args:
            x: Input array; the script below passes a batch of observations
                with shape ``(3, 17)``.

        Returns:
            A tuple ``(mean, log_std)`` whose last dimension is
            ``action_dim``; ``log_std`` is already clipped.
        """
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        mean = nn.Dense(self.action_dim)(x)
        log_std = nn.Dense(self.action_dim)(x)
        # Clip so that exp(log_std) can neither underflow to 0 nor blow up.
        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std
# @jax.jit
def custom_log_prob(
    mean: jnp.ndarray,
    log_std: jnp.ndarray,
    subkey: jnp.ndarray,
    gaussian_action: jnp.ndarray,
):
    """Log-probability of a tanh-squashed diagonal-Gaussian sample.

    The pre-tanh sample is re-drawn from ``subkey`` inside the function, and
    the log-density is computed on that value, so ``arctanh`` is never needed.

    Args:
        mean: Gaussian mean, indexed ``(row, action_dim)``; the reductions
            below sum over axis 1.
        log_std: Log standard deviation, same shape as ``mean``.
        subkey: PRNG key used to draw the Gaussian noise. Annotated as a
            plain array because ``jax.random.KeyArray`` no longer exists in
            newer JAX releases and would fail when the ``def`` is evaluated.
        gaussian_action: NOTE: this argument is ignored. It is overwritten
            by the sample re-drawn from ``subkey``, so the result depends
            only on ``mean``, ``log_std`` and ``subkey``. It is kept for
            interface compatibility with existing callers.

    Returns:
        One log-probability per row (axis 1 summed out).
    """
    std = jnp.exp(log_std)
    # Re-draw the pre-tanh action; this replaces whatever the caller passed.
    noise = jax.random.normal(subkey, shape=mean.shape)
    gaussian_action = mean + std * noise

    # Element-wise log-density of N(mean, std**2), summed over action dims.
    standardized = (gaussian_action - mean) / std
    log_prob = -0.5 * standardized**2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    log_prob = log_prob.sum(axis=1)

    # Change-of-variables term for tanh, written in the stable form
    #   log(1 - tanh(u)**2) = 2 * (log(2) - u - softplus(-2u)).
    # https://github.com/vwxyzjn/cleanrl/pull/300#issuecomment-1326285592
    log_det_jacobian = 2.0 * (
        np.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action)
    )
    log_prob -= jnp.sum(log_det_jacobian, 1)
    return log_prob
if __name__ == "__main__":
    # Reproduction driver: build an actor, sample one tanh-squashed action
    # three ways (by hand, with TFP, with distrax) and print the
    # log-probability each implementation assigns to each sample.
    key = jax.random.PRNGKey(0)
    key, actor_key = jax.random.split(key, 2)

    # Three fixed observations of dimension 17.
    obs = jnp.array([[ -0.06284985, -0.0164921 , -0.10846169, 0.28114545,
                       -0.28463456, 0.4503281 , 0.27488193, -0.0666963 ,
                       0.6118138 , 0.34202537, -1.262452 , 0.7542422 ,
                       13.809639 , -0.6205632 , -4.0013294 , 5.3532414 ,
                       11.587792 ],
                     [ -0.15303956, 0.9534635 , -0.3092537 , -0.2033926 ,
                       0.03336933, 0.6362027 , 0.02348915, -0.32627296,
                       -0.29046476, 0.46484601, -0.42002085, -3.1616204 ,
                       2.247283 , 14.114895 , 2.6248324 , -1.9809983 ,
                       -12.693646 ],
                     [ -0.07995494, 0.09804074, -0.20460981, -0.13476144,
                       0.1701505 , 0.05989099, -0.06446445, -0.22749065,
                       0.39946172, 0.42318228, 2.5876977 , 3.8510017 ,
                       -8.23167 , -7.292657 , 7.64345 , -9.558817 ,
                       -1.9690503 ],
                     ])

    actor = Actor(action_dim=6)
    initial_params = actor.init(actor_key, obs)
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=initial_params,
        tx=optax.adam(learning_rate=3e-4),
    )
    key, subkey = jax.random.split(key, 2)

    mean, log_std = actor.apply(actor_state.params, obs)
    action_std = jnp.exp(log_std)

    # The same tanh-squashed diagonal Gaussian, expressed in both libraries.
    tfd_base = tfd.MultivariateNormalDiag(loc=mean, scale_diag=action_std)
    tfd_dist = tfd.TransformedDistribution(tfd_base, bijector=tfp.bijectors.Tanh())
    distrax_base = distrax.MultivariateNormalDiag(loc=mean, scale_diag=action_std)
    distrax_dist = distrax.Transformed(distrax_base, bijector=distrax.Block(distrax.Tanh(), 1))

    # action generation
    gaussian_action = mean + action_std * jax.random.normal(subkey, shape=mean.shape)
    action_custom = jnp.tanh(gaussian_action)
    # arctanh is infinite wherever tanh saturated to exactly +/-1.
    reverse_action_custom = jnp.arctanh(action_custom)
    action_tfp = tfd_dist.sample(seed=subkey)
    action_distrax = distrax_dist.sample(seed=subkey)

    summaries = (
        ("action_custom.sum()", action_custom),
        ("action_tfp.sum()", action_tfp),
        ("action_distrax.sum()", action_distrax),
        ("gaussian_action.sum()", gaussian_action),
        ("reverse_action_custom.sum()", reverse_action_custom),
    )
    for label, value in summaries:
        print(label, value.sum())

    # log_prob
    candidates = (
        ("action_custom", action_custom),
        ("action_tfp", action_tfp),
        ("action_distrax", action_distrax),
    )
    for name, action in candidates:
        log_prob_custom = custom_log_prob(mean, log_std, subkey, jnp.arctanh(action))
        log_prob_tfp = tfd_dist.log_prob(action)
        log_prob_distrax = distrax_dist.log_prob(action)
        print(name)
        print("┣━━ log_prob_custom.sum()", log_prob_custom.sum())
        print("┣━━ log_prob_tfp.sum()", log_prob_tfp.sum())
        print("┣━━ log_prob_distrax.sum()", log_prob_distrax.sum())
action_custom.sum() 2.8352258
action_tfp.sum() 5.978534
action_distrax.sum() 2.8352258
gaussian_action.sum() 34.332348
reverse_action_custom.sum() inf
action_custom
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan
action_tfp
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() 60.565056
┣━━ log_prob_distrax.sum() nan
action_distrax
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan
Metadata
Metadata
Assignees
Labels
No labels