Skip to content

Commit be40b1c

Browse files
committed
Fixes
1 parent 51c0167 commit be40b1c

File tree

7 files changed

+241
-201
lines changed

7 files changed

+241
-201
lines changed

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ dependencies:
2727
- jax
2828
- mujoco
2929
- mujoco-py<2.2,>=2.1
30-
- minigrid
30+
- minigrid

.github/unittest/linux_libs/scripts_minari/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ minari[gcs,hdf5,hf,create]
1717
gymnasium>=1.2.0
1818
ale-py
1919
gymnasium-robotics
20-
mujoco
20+
mujoco

.github/unittest/linux_libs/scripts_minari/run_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ source "${this_dir}/install.sh"
1313
source "${this_dir}/run_test.sh"
1414
source "${this_dir}/post_process.sh"
1515

16-
echo "Minari tests completed successfully!"
16+
echo "Minari tests completed successfully!"

test/_utils_internal.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import unittest
1414
import warnings
1515
from functools import wraps
16+
from typing import Callable
1617

1718
import pytest
1819
import torch
@@ -214,7 +215,12 @@ def generate_seeds(seed, repeat):
214215

215216

216217
# Decorator to retry upon certain Exceptions.
217-
def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
218+
def retry(
219+
ExceptionToCheck: type[Exception],
220+
tries: int = 3,
221+
delay: int = 3,
222+
skip_after_retries: bool = False,
223+
) -> Callable[[Callable], Callable]:
218224
def deco_retry(f):
219225
@wraps(f)
220226
def f_retry(*args, **kwargs):

test/test_libs.py

Lines changed: 67 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3408,16 +3408,16 @@ def test_d4rl_iteration(self, task, split_trajs):
34083408
]
34093409

34103410

3411-
def _minari_init():
3411+
def _minari_init() -> tuple[bool, Exception | None]:
34123412
"""Initialize Minari datasets list. Returns True if already initialized."""
34133413
global _MINARI_DATASETS
34143414
if _MINARI_DATASETS and not all(
34153415
isinstance(x, str) and x.isdigit() for x in _MINARI_DATASETS
34163416
):
3417-
return True # Already initialized with real dataset names
3417+
return True, None # Already initialized with real dataset names
34183418

34193419
if not _has_minari or not _has_gymnasium:
3420-
return False
3420+
return False, ImportError("Minari or Gymnasium not found")
34213421

34223422
try:
34233423
import minari
@@ -3434,9 +3434,9 @@ def _minari_init():
34343434

34353435
assert len(keys) > 5, keys
34363436
_MINARI_DATASETS[:] = keys # Replace the placeholder values
3437-
return True
3438-
except Exception:
3439-
return False
3437+
return True, None
3438+
except Exception as err:
3439+
return False, err
34403440

34413441

34423442
def get_random_minigrid_datasets():
@@ -3607,6 +3607,7 @@ def test_load(self, dataset_idx, split):
36073607
if cleanup_needed:
36083608
minari.delete_dataset(dataset_id=dataset_id)
36093609

3610+
@retry(Exception, tries=3, delay=1)
36103611
def test_minari_preproc(self, tmpdir):
36113612
dataset = MinariExperienceReplay(
36123613
"D4RL/pointmaze/large-v2",
@@ -3656,63 +3657,70 @@ def fn(data):
36563657
@pytest.mark.skipif(
36573658
not _has_minari or not _has_gymnasium, reason="Minari or Gym not available"
36583659
)
3659-
def test_local_minari_dataset_loading(self):
3660-
import minari
3661-
from minari import DataCollector
3662-
3663-
if not _minari_init():
3664-
pytest.skip("Failed to initialize Minari datasets")
3665-
3666-
dataset_id = "cartpole/test-local-v1"
3667-
3668-
# Create dataset using Gym + DataCollector
3669-
env = gymnasium.make("CartPole-v1")
3670-
env = DataCollector(env, record_infos=True)
3671-
for _ in range(50):
3672-
env.reset(seed=123)
3673-
while True:
3674-
action = env.action_space.sample()
3675-
obs, rew, terminated, truncated, info = env.step(action)
3676-
if terminated or truncated:
3677-
break
3678-
3679-
env.create_dataset(
3680-
dataset_id=dataset_id,
3681-
algorithm_name="RandomPolicy",
3682-
code_permalink="https://github.com/Farama-Foundation/Minari",
3683-
author="Farama",
3684-
author_email="contact@farama.org",
3685-
eval_env="CartPole-v1",
3686-
)
3687-
3688-
# Load from local cache
3689-
data = MinariExperienceReplay(
3690-
dataset_id=dataset_id,
3691-
split_trajs=False,
3692-
batch_size=32,
3693-
download=False,
3694-
sampler=SamplerWithoutReplacement(drop_last=True),
3695-
prefetch=2,
3696-
load_from_local_minari=True,
3697-
)
3660+
def test_local_minari_dataset_loading(self, tmpdir):
3661+
MINARI_DATASETS_PATH = os.environ.get("MINARI_DATASETS_PATH")
3662+
os.environ["MINARI_DATASETS_PATH"] = str(tmpdir)
3663+
try:
3664+
import minari
3665+
from minari import DataCollector
3666+
3667+
success, err = _minari_init()
3668+
if not success:
3669+
pytest.skip(f"Failed to initialize Minari datasets: {err}")
3670+
3671+
dataset_id = "cartpole/test-local-v1"
3672+
3673+
# Create dataset using Gym + DataCollector
3674+
env = gymnasium.make("CartPole-v1")
3675+
env = DataCollector(env, record_infos=True)
3676+
for _ in range(50):
3677+
env.reset(seed=123)
3678+
while True:
3679+
action = env.action_space.sample()
3680+
obs, rew, terminated, truncated, info = env.step(action)
3681+
if terminated or truncated:
3682+
break
3683+
3684+
env.create_dataset(
3685+
dataset_id=dataset_id,
3686+
algorithm_name="RandomPolicy",
3687+
code_permalink="https://github.com/Farama-Foundation/Minari",
3688+
author="Farama",
3689+
author_email="contact@farama.org",
3690+
eval_env="CartPole-v1",
3691+
)
36983692

3699-
t0 = time.time()
3700-
for i, sample in enumerate(data):
3701-
t1 = time.time()
3702-
torchrl_logger.info(
3703-
f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms"
3693+
# Load from local cache
3694+
data = MinariExperienceReplay(
3695+
dataset_id=dataset_id,
3696+
split_trajs=False,
3697+
batch_size=32,
3698+
download=False,
3699+
sampler=SamplerWithoutReplacement(drop_last=True),
3700+
prefetch=2,
3701+
load_from_local_minari=True,
37043702
)
3705-
assert data.metadata["action_space"].is_in(
3706-
sample["action"]
3707-
), "Invalid action sample"
3708-
assert data.metadata["observation_space"].is_in(
3709-
sample["observation"]
3710-
), "Invalid observation sample"
3703+
37113704
t0 = time.time()
3712-
if i == 10:
3713-
break
3705+
for i, sample in enumerate(data):
3706+
t1 = time.time()
3707+
torchrl_logger.info(
3708+
f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms"
3709+
)
3710+
assert data.metadata["action_space"].is_in(
3711+
sample["action"]
3712+
), "Invalid action sample"
3713+
assert data.metadata["observation_space"].is_in(
3714+
sample["observation"]
3715+
), "Invalid observation sample"
3716+
t0 = time.time()
3717+
if i == 10:
3718+
break
37143719

3715-
minari.delete_dataset(dataset_id="cartpole/test-local-v1")
3720+
minari.delete_dataset(dataset_id="cartpole/test-local-v1")
3721+
finally:
3722+
if MINARI_DATASETS_PATH:
3723+
os.environ["MINARI_DATASETS_PATH"] = MINARI_DATASETS_PATH
37163724

37173725

37183726
@pytest.mark.slow

0 commit comments

Comments
 (0)