Skip to content

Commit 32f7d72

Browse files
authored
[BugFix] Brax memory leak fix (#3052)
1 parent 3ebe93d commit 32f7d72

File tree

6 files changed

+257
-38
lines changed

6 files changed

+257
-38
lines changed

.github/unittest/linux_libs/scripts_brax/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ dependencies:
2121
- hydra-core
2222
- jax[cuda12]
2323
- brax
24+
- psutil

.github/unittest/linux_libs/scripts_brax/run_test.sh

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ conda activate ./env
88

99
export PYTORCH_TEST_WITH_SLOW='1'
1010
export LAZY_LEGACY_OP=False
11+
12+
# Configure JAX for proper GPU initialization
13+
export XLA_PYTHON_CLIENT_PREALLOCATE=false
14+
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
15+
export TF_FORCE_GPU_ALLOW_GROWTH=true
16+
export CUDA_VISIBLE_DEVICES=0
17+
1118
python -m torch.utils.collect_env
1219
# Avoid error: "fatal: unsafe repository"
1320
git config --global --add safe.directory '*'
@@ -28,7 +35,33 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON
2835
# this workflow only tests the libs
2936
python -c "import brax"
3037
python -c "import brax.envs"
31-
python -c "import jax"
38+
39+
# Initialize JAX with proper GPU configuration and report which backend it sees.
python -c "
import os

# Allocator settings are placed in the environment before jax is imported so
# they are guaranteed to be visible when the backend is created.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

import jax
import jax.numpy as jnp  # imported only to smoke-test the numpy API import

# Test JAX GPU availability
try:
    devices = jax.devices()
    print(f'JAX devices: {devices}')
    # Decide from the device platform, not from the device count: with
    # CUDA_VISIBLE_DEVICES=0 a healthy GPU setup exposes exactly one device.
    if any(d.platform in ('gpu', 'cuda', 'rocm') for d in devices):
        print('JAX GPU is available')
    else:
        print('JAX CPU only')
except Exception as e:
    print(f'JAX initialization error: {e}')
    # Fallback to CPU. NOTE: this only changes this short-lived interpreter;
    # nothing set here is propagated back to the shell or to later commands.
    os.environ['JAX_PLATFORM_NAME'] = 'cpu'
    jax.config.update('jax_platform_name', 'cpu')
    print('Falling back to JAX CPU')
"
64+
3265
python3 -c 'import torch;t = torch.ones([2,2], device="cuda:0");print(t);print("tensor device:" + str(t.device))'
3366

3467
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips

.github/workflows/test-linux-libs.yml

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,39 @@ permissions:
2121

2222
jobs:
2323

24-
unittests-atari-dqn:
25-
strategy:
26-
matrix:
27-
python_version: ["3.10"]
28-
cuda_arch_version: ["12.8"]
29-
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
30-
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
31-
with:
32-
repository: pytorch/rl
33-
runner: "linux.g5.4xlarge.nvidia.gpu"
34-
docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04"
35-
timeout: 120
36-
script: |
37-
if [[ "${{ github.ref }}" =~ release/* ]]; then
38-
export RELEASE=1
39-
export TORCH_VERSION=stable
40-
else
41-
export RELEASE=0
42-
export TORCH_VERSION=nightly
43-
fi
44-
45-
set -euo pipefail
46-
export PYTHON_VERSION="3.10"
47-
export CU_VERSION="cu128"
48-
export TAR_OPTIONS="--no-same-owner"
49-
export UPLOAD_CHANNEL="nightly"
50-
export TF_CPP_MIN_LOG_LEVEL=0
51-
export TD_GET_DEFAULTS_TO_NONE=1
52-
53-
bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh
54-
bash .github/unittest/linux_libs/scripts_ataridqn/install.sh
55-
bash .github/unittest/linux_libs/scripts_ataridqn/run_test.sh
56-
bash .github/unittest/linux_libs/scripts_ataridqn/post_process.sh
24+
# unittests-atari-dqn:
25+
# strategy:
26+
# matrix:
27+
# python_version: ["3.10"]
28+
# cuda_arch_version: ["12.8"]
29+
# if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }}
30+
# uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
31+
# with:
32+
# repository: pytorch/rl
33+
# runner: "linux.g5.4xlarge.nvidia.gpu"
34+
# docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04"
35+
# timeout: 120
36+
# script: |
37+
# if [[ "${{ github.ref }}" =~ release/* ]]; then
38+
# export RELEASE=1
39+
# export TORCH_VERSION=stable
40+
# else
41+
# export RELEASE=0
42+
# export TORCH_VERSION=nightly
43+
# fi
44+
45+
# set -euo pipefail
46+
# export PYTHON_VERSION="3.10"
47+
# export CU_VERSION="cu128"
48+
# export TAR_OPTIONS="--no-same-owner"
49+
# export UPLOAD_CHANNEL="nightly"
50+
# export TF_CPP_MIN_LOG_LEVEL=0
51+
# export TD_GET_DEFAULTS_TO_NONE=1
52+
53+
# bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh
54+
# bash .github/unittest/linux_libs/scripts_ataridqn/install.sh
55+
# bash .github/unittest/linux_libs/scripts_ataridqn/run_test.sh
56+
# bash .github/unittest/linux_libs/scripts_ataridqn/post_process.sh
5757

5858
unittests-brax:
5959
strategy:

test/test_libs.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import functools
99
import gc
1010
import importlib.util
11+
import os
1112
import urllib.error
1213

14+
1315
_has_isaac = importlib.util.find_spec("isaacgym") is not None
1416

1517
if _has_isaac:
@@ -19,7 +21,6 @@
1921
from torchrl.envs.libs.isaacgym import IsaacGymEnv
2022
import argparse
2123
import importlib
22-
import os
2324

2425
import time
2526
import urllib
@@ -2414,6 +2415,28 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):
24142415
@pytest.mark.parametrize("device", get_available_devices())
24152416
@pytest.mark.parametrize("envname", ["fast"])
24162417
class TestBrax:
2418+
@pytest.fixture(autouse=True)
def _setup_jax(self):
    """Prepare the JAX runtime before each test of this class.

    Fills in GPU-related environment variables that are not already set,
    then probes ``jax.devices()``. If the probe raises, JAX is pointed at
    the CPU platform instead. Control is then handed to the test; there is
    no teardown step.
    """
    import os

    import jax

    # ``setdefault`` leaves any value that is already present untouched.
    gpu_env_defaults = (
        ("XLA_PYTHON_CLIENT_PREALLOCATE", "false"),
        ("XLA_PYTHON_CLIENT_ALLOCATOR", "platform"),
        ("TF_FORCE_GPU_ALLOW_GROWTH", "true"),
    )
    for var_name, default_value in gpu_env_defaults:
        os.environ.setdefault(var_name, default_value)

    # Probe the backend; on any failure switch JAX over to CPU.
    try:
        jax.devices()
    except Exception:
        os.environ["JAX_PLATFORM_NAME"] = "cpu"
        jax.config.update("jax_platform_name", "cpu")

    yield
2439+
24172440
@pytest.mark.parametrize("requires_grad", [False, True])
24182441
def test_brax_constructor(self, envname, requires_grad, device):
24192442
env0 = BraxEnv(envname, requires_grad=requires_grad, device=device)
@@ -2545,6 +2568,75 @@ def make_brax():
25452568
tensordict = env.rollout(3)
25462569
assert tensordict.shape == torch.Size([n, *batch_size, 3])
25472570

2571+
def test_brax_memory_leak(self, envname, device):
    """Check that process memory stays bounded during a differentiable rollout.

    Steps a ``BraxEnv`` (batch size 10, ``requires_grad=True``) 200 times
    under a linear policy, running a backward pass every 50th step, and
    asserts that the resident set size grew by less than 100 MB.
    """
    import psutil

    proc = psutil.Process(os.getpid())

    def rss_mb():
        # Resident set size of this process, converted from bytes to MB.
        return proc.memory_info().rss / 1024 / 1024

    env = BraxEnv(envname, batch_size=[10], requires_grad=True, device=device)
    env.clear_cache()
    gc.collect()
    env.set_seed(0)
    next_td = env.reset()

    # Linear map from the first observation entry to an action.
    obs_key = env.observation_keys[0]
    linear = torch.nn.Linear(
        env.observation_spec[obs_key].shape[-1],
        env.action_spec.shape[-1],
        device=device,
    )
    policy = TensorDictModule(linear, in_keys=[obs_key], out_keys=["action"])

    initial_memory = rss_mb()
    for step in range(200):
        policy(next_td)
        out_td, next_td = env.step_and_maybe_reset(next_td)
        if step % 50 == 0:
            out_td["next", "observation"].sum().backward()
            # Continue the rollout from a detached copy of the next state.
            next_td = next_td.detach().clone()
    memory_increase = rss_mb() - initial_memory

    assert (
        memory_increase < 100
    ), f"Memory leak with automatic clearing: {memory_increase:.2f} MB"
2610+
2611+
def test_brax_cache_clearing(self, envname, device):
    """Smoke-test that ``clear_cache`` can be called repeatedly without raising."""
    env = BraxEnv(envname, batch_size=[1], requires_grad=True, device=device)
    # One call right after construction, then five more back to back.
    for _ in range(1 + 5):
        env.clear_cache()
2616+
2617+
@pytest.mark.parametrize("freq", [10, None, False])
def test_brax_automatic_cache_clearing_parameter(self, envname, device, freq):
    """Check how ``cache_clear_frequency`` is stored and that steps are counted.

    ``None`` is expected to resolve to 20, ``False`` to be kept as ``False``,
    and any other value to be stored unchanged. Afterwards ten steps are
    taken with random actions and ``_step_count`` must equal the number of
    steps performed so far.
    """
    env = BraxEnv(
        envname,
        batch_size=[1],
        requires_grad=True,
        device=device,
        cache_clear_frequency=freq,
    )

    stored_freq = env._cache_clear_frequency
    if freq is None:
        assert stored_freq == 20  # Default value
    elif freq is False:
        assert stored_freq is False
    else:
        assert stored_freq == freq

    env.set_seed(0)
    next_td = env.reset()
    for steps_done in range(1, 11):
        next_td["action"] = env.action_spec.rand()
        _, next_td = env.step_and_maybe_reset(next_td)
        assert env._step_count == steps_done
2639+
25482640

25492641
@pytest.mark.skipif(not _has_vmas, reason="vmas not installed")
25502642
class TestVmas:

torchrl/data/datasets/atari_dqn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ def __init__(
411411
mp_start_method: str = "fork",
412412
**kwargs,
413413
):
414+
import warnings
415+
416+
warnings.warn(
417+
"This dataset is no longer available. We are working on a fix, or possibly a deprecation.",
418+
DeprecationWarning,
419+
)
414420
if dataset_id not in self.available_datasets:
415421
raise ValueError(
416422
"The dataseet_id is not part of the available datasets. The dataset should be named <game_name>/<run> "

0 commit comments

Comments
 (0)