@@ -3408,16 +3408,16 @@ def test_d4rl_iteration(self, task, split_trajs):
3408
3408
]
3409
3409
3410
3410
3411
- def _minari_init ():
3411
+ def _minari_init () -> tuple [ bool , Exception | None ] :
3412
3412
"""Initialize Minari datasets list. Returns True if already initialized."""
3413
3413
global _MINARI_DATASETS
3414
3414
if _MINARI_DATASETS and not all (
3415
3415
isinstance (x , str ) and x .isdigit () for x in _MINARI_DATASETS
3416
3416
):
3417
- return True # Already initialized with real dataset names
3417
+ return True , None # Already initialized with real dataset names
3418
3418
3419
3419
if not _has_minari or not _has_gymnasium :
3420
- return False
3420
+ return False , ImportError ( "Minari or Gymnasium not found" )
3421
3421
3422
3422
try :
3423
3423
import minari
@@ -3434,9 +3434,9 @@ def _minari_init():
3434
3434
3435
3435
assert len (keys ) > 5 , keys
3436
3436
_MINARI_DATASETS [:] = keys # Replace the placeholder values
3437
- return True
3438
- except Exception :
3439
- return False
3437
+ return True , None
3438
+ except Exception as err :
3439
+ return False , err
3440
3440
3441
3441
3442
3442
def get_random_minigrid_datasets ():
@@ -3607,6 +3607,7 @@ def test_load(self, dataset_idx, split):
3607
3607
if cleanup_needed :
3608
3608
minari .delete_dataset (dataset_id = dataset_id )
3609
3609
3610
+ @retry (Exception , tries = 3 , delay = 1 )
3610
3611
def test_minari_preproc (self , tmpdir ):
3611
3612
dataset = MinariExperienceReplay (
3612
3613
"D4RL/pointmaze/large-v2" ,
@@ -3656,63 +3657,70 @@ def fn(data):
3656
3657
@pytest .mark .skipif (
3657
3658
not _has_minari or not _has_gymnasium , reason = "Minari or Gym not available"
3658
3659
)
3659
- def test_local_minari_dataset_loading (self ):
3660
- import minari
3661
- from minari import DataCollector
3662
-
3663
- if not _minari_init ():
3664
- pytest .skip ("Failed to initialize Minari datasets" )
3665
-
3666
- dataset_id = "cartpole/test-local-v1"
3667
-
3668
- # Create dataset using Gym + DataCollector
3669
- env = gymnasium .make ("CartPole-v1" )
3670
- env = DataCollector (env , record_infos = True )
3671
- for _ in range (50 ):
3672
- env .reset (seed = 123 )
3673
- while True :
3674
- action = env .action_space .sample ()
3675
- obs , rew , terminated , truncated , info = env .step (action )
3676
- if terminated or truncated :
3677
- break
3678
-
3679
- env .create_dataset (
3680
- dataset_id = dataset_id ,
3681
- algorithm_name = "RandomPolicy" ,
3682
- code_permalink = "https://github.com/Farama-Foundation/Minari" ,
3683
- author = "Farama" ,
3684
- author_email = "contact@farama.org" ,
3685
- eval_env = "CartPole-v1" ,
3686
- )
3687
-
3688
- # Load from local cache
3689
- data = MinariExperienceReplay (
3690
- dataset_id = dataset_id ,
3691
- split_trajs = False ,
3692
- batch_size = 32 ,
3693
- download = False ,
3694
- sampler = SamplerWithoutReplacement (drop_last = True ),
3695
- prefetch = 2 ,
3696
- load_from_local_minari = True ,
3697
- )
3660
+ def test_local_minari_dataset_loading (self , tmpdir ):
3661
+ MINARI_DATASETS_PATH = os .environ .get ("MINARI_DATASETS_PATH" )
3662
+ os .environ ["MINARI_DATASETS_PATH" ] = str (tmpdir )
3663
+ try :
3664
+ import minari
3665
+ from minari import DataCollector
3666
+
3667
+ success , err = _minari_init ()
3668
+ if not success :
3669
+ pytest .skip (f"Failed to initialize Minari datasets: { err } " )
3670
+
3671
+ dataset_id = "cartpole/test-local-v1"
3672
+
3673
+ # Create dataset using Gym + DataCollector
3674
+ env = gymnasium .make ("CartPole-v1" )
3675
+ env = DataCollector (env , record_infos = True )
3676
+ for _ in range (50 ):
3677
+ env .reset (seed = 123 )
3678
+ while True :
3679
+ action = env .action_space .sample ()
3680
+ obs , rew , terminated , truncated , info = env .step (action )
3681
+ if terminated or truncated :
3682
+ break
3683
+
3684
+ env .create_dataset (
3685
+ dataset_id = dataset_id ,
3686
+ algorithm_name = "RandomPolicy" ,
3687
+ code_permalink = "https://github.com/Farama-Foundation/Minari" ,
3688
+ author = "Farama" ,
3689
+ author_email = "contact@farama.org" ,
3690
+ eval_env = "CartPole-v1" ,
3691
+ )
3698
3692
3699
- t0 = time .time ()
3700
- for i , sample in enumerate (data ):
3701
- t1 = time .time ()
3702
- torchrl_logger .info (
3703
- f"[Local Minari] Sampling time { 1000 * (t1 - t0 ):4.4f} ms"
3693
+ # Load from local cache
3694
+ data = MinariExperienceReplay (
3695
+ dataset_id = dataset_id ,
3696
+ split_trajs = False ,
3697
+ batch_size = 32 ,
3698
+ download = False ,
3699
+ sampler = SamplerWithoutReplacement (drop_last = True ),
3700
+ prefetch = 2 ,
3701
+ load_from_local_minari = True ,
3704
3702
)
3705
- assert data .metadata ["action_space" ].is_in (
3706
- sample ["action" ]
3707
- ), "Invalid action sample"
3708
- assert data .metadata ["observation_space" ].is_in (
3709
- sample ["observation" ]
3710
- ), "Invalid observation sample"
3703
+
3711
3704
t0 = time .time ()
3712
- if i == 10 :
3713
- break
3705
+ for i , sample in enumerate (data ):
3706
+ t1 = time .time ()
3707
+ torchrl_logger .info (
3708
+ f"[Local Minari] Sampling time { 1000 * (t1 - t0 ):4.4f} ms"
3709
+ )
3710
+ assert data .metadata ["action_space" ].is_in (
3711
+ sample ["action" ]
3712
+ ), "Invalid action sample"
3713
+ assert data .metadata ["observation_space" ].is_in (
3714
+ sample ["observation" ]
3715
+ ), "Invalid observation sample"
3716
+ t0 = time .time ()
3717
+ if i == 10 :
3718
+ break
3714
3719
3715
- minari .delete_dataset (dataset_id = "cartpole/test-local-v1" )
3720
+ minari .delete_dataset (dataset_id = "cartpole/test-local-v1" )
3721
+ finally :
3722
+ if MINARI_DATASETS_PATH :
3723
+ os .environ ["MINARI_DATASETS_PATH" ] = MINARI_DATASETS_PATH
3716
3724
3717
3725
3718
3726
@pytest .mark .slow
0 commit comments