Skip to content

nan in MultivariateNormalDiag log prob #216

Open
@vwxyzjn

Description

@vwxyzjn

Hello, thanks for this awesome repo! We have run into a small issue when using distrax: it produces NaN log-probabilities, as reported in vwxyzjn/cleanrl#300. See the following reproduction script:

from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

# import pybullet_envs  # noqa
import tensorflow_probability
from flax.training.train_state import TrainState

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions
jax.config.update("jax_platform_name", "cpu")
import distrax


class Actor(nn.Module):
    """MLP that maps observations to a mean and a clipped log standard deviation.

    The network is two ``Dense`` + ReLU hidden layers followed by two separate
    ``Dense`` output heads. In this script the two outputs are used as the
    ``loc`` and (after ``exp``) the ``scale_diag`` of a diagonal Gaussian.
    """

    # Output width of both heads. It is passed straight to ``nn.Dense``, so it
    # is a single integer (the script constructs ``Actor(action_dim=6)``).
    action_dim: int
    # Width of each of the two hidden layers.
    n_units: int = 256
    # Lower / upper bound applied to the predicted log standard deviation.
    log_std_min: float = -20
    log_std_max: float = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Run the network on ``x`` and return ``(mean, log_std)``.

        Args:
            x: Input array fed to the first ``Dense`` layer
                (presumably ``(batch, obs_dim)`` -- confirm against callers).

        Returns:
            A tuple ``(mean, log_std)``; ``log_std`` is clipped to
            ``[log_std_min, log_std_max]``.
        """
        # NOTE: submodules are created in the same order as before (hidden,
        # hidden, mean head, log_std head); reordering them would change which
        # auto-generated parameter name belongs to which layer.
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        mean = nn.Dense(self.action_dim)(x)
        log_std = nn.Dense(self.action_dim)(x)
        # Clipping keeps exp(log_std) inside a bounded range.
        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std


# @jax.jit
def custom_log_prob(
    mean: jnp.ndarray,
    log_std: jnp.ndarray,
    subkey: jnp.ndarray,
    gaussian_action: jnp.ndarray,
) -> jnp.ndarray:
    """Log-density of ``tanh(gaussian_action)`` under a tanh-squashed diagonal Gaussian.

    The result is the diagonal-Gaussian log-density of ``gaussian_action``
    minus the log-determinant of the ``tanh`` Jacobian, both summed over the
    last axis.

    Args:
        mean: Mean of the diagonal Gaussian.
        log_std: Log standard deviation of the diagonal Gaussian; must
            broadcast against ``mean``.
        subkey: Unused. Kept only so existing positional calls keep working.
            Previously it was used to draw a fresh sample that silently
            replaced ``gaussian_action``.
        gaussian_action: Pre-``tanh`` action at which the density is
            evaluated. Non-finite entries (e.g. ``arctanh(1.0) == inf``)
            propagate to a non-finite result.

    Returns:
        Array with the last axis of the inputs reduced away.
    """
    # The density is a deterministic function of the supplied action, so no
    # randomness is needed here.
    del subkey

    std = jnp.exp(log_std)
    standardized = (gaussian_action - mean) / std
    log_prob = -0.5 * standardized**2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    # Reduce over the last axis; for 2-D inputs this is the same as axis=1.
    log_prob = log_prob.sum(axis=-1)

    # log(1 - tanh(u)^2) written as 2 * (log 2 - u - softplus(-2u)).
    # https://github.com/vwxyzjn/cleanrl/pull/300#issuecomment-1326285592
    log_det_jacobian = 2.0 * (np.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action))
    return log_prob - log_det_jacobian.sum(axis=-1)


if __name__ == "__main__":
    # Reproduction driver: build an actor, sample tanh-squashed actions three
    # ways (hand-written, TFP, distrax) and print the log-probability that each
    # implementation assigns to each sampled action.
    key = jax.random.PRNGKey(0)
    key, actor_key = jax.random.split(key, 2)
    # Fixed observations, inlined so the script needs no data file.
    obs = jnp.array([[ -0.06284985,  -0.0164921 ,  -0.10846169,   0.28114545,
         -0.28463456,   0.4503281 ,   0.27488193,  -0.0666963 ,
          0.6118138 ,   0.34202537,  -1.262452  ,   0.7542422 ,
         13.809639  ,  -0.6205632 ,  -4.0013294 ,   5.3532414 ,
         11.587792  ],
       [ -0.15303956,   0.9534635 ,  -0.3092537 ,  -0.2033926 ,
          0.03336933,   0.6362027 ,   0.02348915,  -0.32627296,
         -0.29046476,   0.46484601,  -0.42002085,  -3.1616204 ,
          2.247283  ,  14.114895  ,   2.6248324 ,  -1.9809983 ,
        -12.693646  ],
       [ -0.07995494,   0.09804074,  -0.20460981,  -0.13476144,
          0.1701505 ,   0.05989099,  -0.06446445,  -0.22749065,
          0.39946172,   0.42318228,   2.5876977 ,   3.8510017 ,
         -8.23167   ,  -7.292657  ,   7.64345   ,  -9.558817  ,
         -1.9690503 ],
    ])
    actor = Actor(action_dim=6)
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=3e-4),
    )

    key, subkey = jax.random.split(key, 2)
    mean, log_std = actor.apply(actor_state.params, obs)
    # Computed once and shared by the manual sampler and both library distributions.
    action_std = jnp.exp(log_std)
    tfd_dist = tfd.TransformedDistribution(
        tfd.MultivariateNormalDiag(loc=mean, scale_diag=action_std), bijector=tfp.bijectors.Tanh()
    )
    distrax_dist = distrax.Transformed(
        distrax.MultivariateNormalDiag(loc=mean, scale_diag=action_std), bijector=distrax.Block(distrax.Tanh(), 1)
    )

    # action generation
    gaussian_action = mean + action_std * jax.random.normal(subkey, shape=mean.shape)
    action_custom = jnp.tanh(gaussian_action)
    # Inverting tanh yields inf wherever the squashed action rounded to exactly +/-1.
    reverse_action_custom = jnp.arctanh(action_custom)
    action_tfp = tfd_dist.sample(seed=subkey)
    action_distrax = distrax_dist.sample(seed=subkey)

    print("action_custom.sum()", action_custom.sum())
    print("action_tfp.sum()", action_tfp.sum())
    print("action_distrax.sum()", action_distrax.sum())
    print("gaussian_action.sum()", gaussian_action.sum())
    print("reverse_action_custom.sum()", reverse_action_custom.sum())

    # log_prob: score every sampled action under every implementation.
    sampled_actions = (
        ("action_custom", action_custom),
        ("action_tfp", action_tfp),
        ("action_distrax", action_distrax),
    )
    for name, action in sampled_actions:
        log_prob_custom = custom_log_prob(mean, log_std, subkey, jnp.arctanh(action))
        log_prob_tfp = tfd_dist.log_prob(action)
        log_prob_distrax = distrax_dist.log_prob(action)
        print(name)
        print("┣━━ log_prob_custom.sum()", log_prob_custom.sum())
        print("┣━━ log_prob_tfp.sum()", log_prob_tfp.sum())
        print("┣━━ log_prob_distrax.sum()", log_prob_distrax.sum())
action_custom.sum() 2.8352258
action_tfp.sum() 5.978534
action_distrax.sum() 2.8352258
gaussian_action.sum() 34.332348
reverse_action_custom.sum() inf
action_custom
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan
action_tfp
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() 60.565056
┣━━ log_prob_distrax.sum() nan
action_distrax
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions