Skip to content

[Refactor] refactor noisy linear #3082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
358 changes: 358 additions & 0 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import math
import os

import pytest
Expand Down Expand Up @@ -891,3 +892,360 @@ def test_consistent_dropout_primer(self):
# Script entry point: arguments argparse does not recognise are forwarded to pytest.
# NOTE(review): further test definitions follow this guard in the file; by
# convention the guard sits at the very end of the module.
if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)


@pytest.mark.parametrize("device", get_default_devices())
class TestNoisyLinear:
"""Tests for NoisyLinear layer based on NoisyNet paper specifications."""

def test_noisy_linear_initialization(self, device):
    """Check attributes, shapes and value ranges of a freshly built NoisyLinear.

    Verifies that the mu/sigma parameters exist with the expected shapes, that
    every sigma entry is strictly positive, and that every mu entry lies in
    ``[-1/sqrt(in_features), 1/sqrt(in_features)]``.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    n_in, n_out = 10, 5
    noisy = NoisyLinear(n_in, n_out, device=device)

    # Parameter name -> expected shape.
    expected_shapes = {
        "weight_mu": (n_out, n_in),
        "weight_sigma": (n_out, n_in),
        "bias_mu": (n_out,),
        "bias_sigma": (n_out,),
    }

    # Existence first, then shapes, so a missing attribute is reported as such.
    for attr in expected_shapes:
        assert hasattr(noisy, attr)
    for attr, shape in expected_shapes.items():
        assert getattr(noisy, attr).shape == shape

    # Sigma values must be strictly positive.
    for sigma in (noisy.weight_sigma, noisy.bias_sigma):
        assert (sigma > 0).all()

    # Mu values must stay within the symmetric initialisation bound.
    bound = 1 / math.sqrt(n_in)
    for mu in (noisy.weight_mu, noisy.bias_mu):
        assert (mu >= -bound).all()
        assert (mu <= bound).all()

def test_noisy_linear_training_vs_eval(self, device):
    """Compare NoisyLinear outputs around a noise reset in train and eval mode.

    In training mode a noise reset must change the output. In eval mode the
    output must be unaffected by a reset, and it must differ from the noisy
    training output.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)
    inputs = torch.randn(3, 10, device=device)

    def forward_reset_forward():
        # Output before the reset, then output after drawing fresh noise.
        before = noisy(inputs)
        noisy.reset_noise()
        return before, noisy(inputs)

    noisy.train()
    train_before, train_after = forward_reset_forward()

    noisy.eval()
    eval_before, eval_after = forward_reset_forward()

    # Fresh noise changes the training-mode output.
    assert not torch.allclose(train_before, train_after, atol=1e-6)

    # Eval-mode output does not depend on the sampled noise.
    torch.testing.assert_close(eval_before, eval_after)

    # Noisy and noise-free outputs differ.
    assert not torch.allclose(train_before, eval_before, atol=1e-6)

def test_noise_consistency_within_episode(self, device):
    """Check that repeated forwards without ``reset_noise()`` give the same output."""
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)
    noisy.train()
    inputs = torch.randn(3, 10, device=device)

    # Reference output, followed by three more passes with the noise untouched.
    reference = noisy(inputs)
    repeats = [noisy(inputs) for _ in range(3)]

    # Every repeat must reproduce the reference output.
    for repeat in repeats:
        assert torch.allclose(reference, repeat, atol=1e-6)

def test_noise_change_after_reset(self, device):
    """Check that each ``reset_noise()`` call leads to a different output."""
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)
    noisy.train()
    inputs = torch.randn(3, 10, device=device)

    # One output with the initial noise, then one after each of two resets.
    outputs = [noisy(inputs)]
    for _ in range(2):
        noisy.reset_noise()
        outputs.append(noisy(inputs))

    # All three outputs must be pairwise different.
    for i, j in ((0, 1), (0, 2), (1, 2)):
        assert not torch.allclose(outputs[i], outputs[j], atol=1e-6)

def test_factorized_gaussian_noise(self, device):
    """Check the empirical mean and spread of the sampled weight noise.

    Draws 1000 noise realisations, measured as ``layer.weight - layer.weight_mu``,
    and checks that their mean is close to zero and that their standard
    deviation is within a factor of two of ``std_init / sqrt(10)``.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)
    noisy.train()

    def draw_weight_noise():
        # Resample, then recover the perturbation actually added to the mean weights.
        noisy.reset_noise()
        return (noisy.weight - noisy.weight_mu).flatten()

    samples = torch.stack([draw_weight_noise() for _ in range(1000)])

    # Mean should be approximately zero.
    assert abs(samples.mean()) < 0.1

    # Spread should be of the order of the reference scale.
    empirical_std = samples.std()
    reference_std = noisy.std_init / math.sqrt(10)  # 10 == in_features of the layer
    assert 0.5 * reference_std < empirical_std
    assert empirical_std < 2.0 * reference_std

def test_weight_property_behavior(self, device):
    """Check ``weight``/``bias`` against the mu parameters in train and eval mode.

    In training mode, after a noise reset, both must differ from the
    corresponding mu parameter. In eval mode both must match it.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)

    # Training mode: the effective parameters carry noise.
    noisy.train()
    noisy.reset_noise()
    noisy_weight, noisy_bias = noisy.weight, noisy.bias
    for value, mean in ((noisy_weight, noisy.weight_mu), (noisy_bias, noisy.bias_mu)):
        assert not torch.allclose(value, mean, atol=1e-6)

    # Eval mode: the effective parameters equal the means.
    noisy.eval()
    clean_weight, clean_bias = noisy.weight, noisy.bias
    for value, mean in ((clean_weight, noisy.weight_mu), (clean_bias, noisy.bias_mu)):
        assert torch.allclose(value, mean, atol=1e-6)

def test_noisy_linear_in_network(self, device):
    """Check a NoisyLinear used as the last layer of an ``nn.Sequential``.

    A noise reset must change the training-mode output of the whole network,
    while two eval-mode forwards must agree.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)

    # Layers are created in the same order as they appear in the network.
    trunk = nn.Linear(10, 20)
    activation = nn.ReLU()
    head = NoisyLinear(20, 5, device=device)
    model = nn.Sequential(trunk, activation, head).to(device)

    inputs = torch.randn(3, 10, device=device)

    # Training mode: forward, resample the noise of the last layer, forward again.
    model.train()
    train_before = model(inputs)
    model[-1].reset_noise()
    train_after = model(inputs)

    # Eval mode: two plain forwards.
    model.eval()
    eval_first = model(inputs)
    eval_second = model(inputs)

    assert not torch.allclose(train_before, train_after, atol=1e-6)
    assert torch.allclose(eval_first, eval_second, atol=1e-6)

def test_noise_reset_function(self, device):
    """Check that the ``reset_noise`` utility resamples every NoisyLinear layer.

    The noisy weights of each NoisyLinear layer and the network output are
    captured before and after the reset. All of them must change, so the test
    fails if the utility leaves any layer's noise untouched.
    """
    from torchrl.modules.models.exploration import NoisyLinear, reset_noise

    torch.manual_seed(0)

    # Network with two NoisyLinear layers.
    network = nn.Sequential(
        NoisyLinear(10, 20, device=device),
        nn.ReLU(),
        NoisyLinear(20, 5, device=device),
    ).to(device)

    network.train()
    x = torch.randn(3, 10, device=device)

    noisy_layers = [m for m in network.modules() if isinstance(m, NoisyLinear)]
    assert len(noisy_layers) == 2

    # Snapshot the noisy weights and the output before the reset.
    weights_before = [layer.weight.detach().clone() for layer in noisy_layers]
    y_before = network(x)

    # NOTE(review): the helper is applied per module through Module.apply so
    # that the test does not depend on whether it recurses into children
    # itself -- confirm against the reset_noise implementation.
    network.apply(reset_noise)

    y_after = network(x)

    # Every noisy layer must have drawn new noise ...
    for layer, weight_before in zip(noisy_layers, weights_before):
        assert not torch.allclose(layer.weight, weight_before, atol=1e-6)

    # ... and the network output must reflect it.
    assert not torch.allclose(y_before, y_after, atol=1e-6)

def test_noisy_linear_gradients(self, device):
    """Check that a backward pass populates gradients on all NoisyLinear parameters.

    All four mu/sigma parameters must receive a gradient, and the gradients of
    ``weight_mu`` and ``weight_sigma`` must not be all-zero.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)
    noisy.train()

    inputs = torch.randn(3, 10, device=device, requires_grad=True)
    noisy(inputs).sum().backward()

    # A gradient must exist for each parameter.
    for name in ("weight_mu", "weight_sigma", "bias_mu", "bias_sigma"):
        assert getattr(noisy, name).grad is not None

    # The weight gradients must carry a non-zero signal.
    for name in ("weight_mu", "weight_sigma"):
        grad = getattr(noisy, name).grad
        assert not torch.allclose(grad, torch.zeros_like(grad))

def test_noisy_linear_parameter_learning(self, device):
    """Check that the sigma parameters move after a few optimisation steps.

    Runs ten Adam steps on an MSE regression problem with random data,
    resampling the noise before each forward, then compares the sigma
    parameters with their initial values.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    noisy = NoisyLinear(10, 5, device=device)
    noisy.train()

    # Snapshot of the sigma parameters before any update.
    sigma_names = ("weight_sigma", "bias_sigma")
    sigma_before = {name: getattr(noisy, name).clone() for name in sigma_names}

    optim = torch.optim.Adam(noisy.parameters(), lr=0.01)
    inputs = torch.randn(100, 10, device=device)
    targets = torch.randn(100, 5, device=device)

    for _ in range(10):
        optim.zero_grad()
        noisy.reset_noise()  # fresh noise for every step
        prediction = noisy(inputs)
        torch.nn.functional.mse_loss(prediction, targets).backward()
        optim.step()

    # Both sigma parameters must differ from their initial values.
    for name in sigma_names:
        assert not torch.allclose(getattr(noisy, name), sigma_before[name], atol=1e-6)

def test_noisy_linear_std_init_effect(self, device):
    """Check that a larger ``std_init`` yields a larger output variance across resets.

    Two layers, with ``std_init`` 0.01 and 1.0, are evaluated on the same input
    for ten noise draws each. The variance over draws, averaged over output
    entries, must be larger for the layer with the larger ``std_init``.
    """
    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)

    quiet = NoisyLinear(10, 5, std_init=0.01, device=device)
    loud = NoisyLinear(10, 5, std_init=1.0, device=device)
    both = (quiet, loud)

    for layer in both:
        layer.train()

    inputs = torch.randn(3, 10, device=device)

    # Initial resample of both layers before the sampling loop.
    for layer in both:
        layer.reset_noise()

    quiet_outputs, loud_outputs = [], []
    for _ in range(10):
        for layer in both:
            layer.reset_noise()
        quiet_outputs.append(quiet(inputs))
        loud_outputs.append(loud(inputs))

    def mean_variance_over_draws(outputs):
        # Variance along the draw dimension, averaged over all output entries.
        return torch.stack(outputs).var(dim=0).mean()

    quiet_var = mean_variance_over_draws(quiet_outputs)
    loud_var = mean_variance_over_draws(loud_outputs)

    assert loud_var > quiet_var

def test_noisy_linear_serialization(self, device):
    """Check that a NoisyLinear state dict survives a save/load round trip.

    The state dict is written to a file inside a temporary directory and
    loaded into a second, independently initialised layer. All mu/sigma
    parameters of the two layers must then match.
    """
    import os
    import tempfile

    from torchrl.modules.models.exploration import NoisyLinear

    torch.manual_seed(0)
    layer = NoisyLinear(10, 5, device=device)

    # A temporary directory is used instead of NamedTemporaryFile(delete=False):
    # no handle is held open while torch writes/reads the path (re-opening an
    # open named temporary file is platform dependent), and the directory is
    # removed even if saving or loading raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "noisy_linear.pt")
        torch.save(layer.state_dict(), path)
        layer_loaded = NoisyLinear(10, 5, device=device)
        layer_loaded.load_state_dict(torch.load(path))

    # The loaded layer must carry the same parameters as the saved one.
    for name in ("weight_mu", "weight_sigma", "bias_mu", "bias_sigma"):
        assert torch.allclose(
            getattr(layer, name), getattr(layer_loaded, name), atol=1e-6
        )
17 changes: 17 additions & 0 deletions torchrl/modules/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,20 @@ def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool:
"_reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset"
)
return any_reset

def primers_from_module(module: nn.Module, target_cls: T) -> list[TensorDictPrimer]:
    """Get primers from a module.

    Intended behavior: iterate over the module's children and return the primers
    of the children that are instances of the target class. These primers will
    write some data to be used by the models at reset time.

    .. note::
        This function is currently a stub: its body is a bare ``...``, so calling
        it returns ``None`` rather than the list described below.

    TODO: document which component sets the tensors within the model during the
    policy call.

    Args:
        module (nn.Module): the module to get primers from.
        target_cls (T): the class whose instances are looked for among the
            module's children (presumably matched with ``isinstance`` -- to be
            confirmed once the function is implemented).

    Returns:
        list[TensorDictPrimer]: the primers from the module.
    """
    # TODO: not implemented yet -- the ellipsis below is a placeholder body.
    ...
Loading