
Commit c166c4f

[BugFix] Fix minari dataloading (#3054)
1 parent 32f7d72 commit c166c4f

4 files changed: +56 −20 lines

.github/unittest/linux_libs/scripts_openx/environment.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -21,5 +21,5 @@ dependencies:
   - hydra-core
   - tqdm
   - h5py
-  - datasets
+  - datasets<4.0.0
   - pillow
```
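The only change here is pinning `datasets` below 4.0.0, matching the error handling added in `openx.py` below ("Dataset scripts are no longer supported" in `datasets>=4.0.0`). As a rough sketch, user code could guard against an incompatible install up front; the helper name and the use of `packaging` are assumptions, not part of this commit:

```python
# Hypothetical guard (not part of this commit): fail early if the installed
# `datasets` package is newer than the pinned upper bound.
import importlib.metadata

from packaging.version import Version


def check_datasets_version(upper_bound: str = "4.0.0") -> None:
    installed = Version(importlib.metadata.version("datasets"))
    if installed >= Version(upper_bound):
        raise RuntimeError(
            f"datasets=={installed} is too new; install `datasets<{upper_bound}`."
        )
```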

test/test_libs.py

Lines changed: 39 additions & 15 deletions
```diff
@@ -3342,33 +3342,57 @@ def test_d4rl_iteration(self, task, split_trajs):
 _MINARI_DATASETS = []


-def _minari_selected_datasets():
-    if not _has_minari or not _has_gymnasium:
-        return
+def _minari_init():
+    """Initialize Minari datasets list. Returns True if already initialized."""
     global _MINARI_DATASETS
-    import minari
+    if _MINARI_DATASETS and not all(
+        isinstance(x, str) and x.isdigit() for x in _MINARI_DATASETS
+    ):
+        return True  # Already initialized with real dataset names

-    torch.manual_seed(0)
+    if not _has_minari or not _has_gymnasium:
+        return False

-    total_keys = sorted(
-        minari.list_remote_datasets(latest_version=True, compatible_minari_version=True)
-    )
-    indices = torch.randperm(len(total_keys))[:20]
-    keys = [total_keys[idx] for idx in indices]
+    try:
+        import minari
+
+        torch.manual_seed(0)

-    assert len(keys) > 5, keys
-    _MINARI_DATASETS += keys
+        total_keys = sorted(
+            minari.list_remote_datasets(
+                latest_version=True, compatible_minari_version=True
+            )
+        )
+        indices = torch.randperm(len(total_keys))[:20]
+        keys = [total_keys[idx] for idx in indices]

+        assert len(keys) > 5, keys
+        _MINARI_DATASETS[:] = keys  # Replace the placeholder values
+        return True
+    except Exception:
+        return False

-_minari_selected_datasets()
+
+# Initialize with placeholder values for parametrization
+# These will be replaced with actual dataset names when the first Minari test runs
+_MINARI_DATASETS = [str(i) for i in range(20)]


 @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found")
 @pytest.mark.slow
 class TestMinari:
     @pytest.mark.parametrize("split", [False, True])
-    @pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
-    def test_load(self, selected_dataset, split):
+    @pytest.mark.parametrize("dataset_idx", range(20))
+    def test_load(self, dataset_idx, split):
+        # Initialize Minari datasets if not already done
+        if not _minari_init():
+            pytest.skip("Failed to initialize Minari datasets")
+
+        # Get the actual dataset name from the initialized list
+        if dataset_idx >= len(_MINARI_DATASETS):
+            pytest.skip(f"Dataset index {dataset_idx} out of range")
+
+        selected_dataset = _MINARI_DATASETS[dataset_idx]
         torchrl_logger.info(f"dataset {selected_dataset}")
         data = MinariExperienceReplay(
             selected_dataset, batch_size=32, split_trajs=split
```
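The test change works around the fact that `pytest.mark.parametrize` arguments are evaluated at collection time: listing remote Minari datasets at import time could fail before any test runs. The test now parametrizes over stable indices and resolves the real dataset names lazily on first use. A standalone sketch of the same pattern (dataset names and the resolver body are placeholders, not TorchRL code):

```python
# Standalone sketch of the lazy-parametrization pattern above; the dataset
# names and the resolver body are placeholders, not TorchRL code.
import pytest

_DATASETS = [str(i) for i in range(3)]  # placeholders used at collection time


def _resolve_datasets() -> bool:
    if _DATASETS and not all(x.isdigit() for x in _DATASETS):
        return True  # already resolved to real names
    try:
        # In the real test this would query minari.list_remote_datasets().
        _DATASETS[:] = ["cartpole-v1", "pendulum-v1", "mountaincar-v1"]
        return True
    except Exception:
        return False


@pytest.mark.parametrize("dataset_idx", range(3))
def test_dataset(dataset_idx):
    if not _resolve_datasets():
        pytest.skip("could not resolve dataset names")
    if dataset_idx >= len(_DATASETS):
        pytest.skip(f"dataset index {dataset_idx} out of range")
    name = _DATASETS[dataset_idx]
    assert not name.isdigit()  # the placeholder has been replaced
```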

torchrl/data/datasets/minari_data.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -463,5 +463,9 @@ def _patch_info(info_td):
     val_td_sel = val_td_sel.apply(
         lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1]
     )
-    val_td_sel.update(val_td.select(*unique_shapes[max_shape]))
+    source = val_td.select(*unique_shapes[max_shape])
+    # make sure source has no batch size
+    source.batch_size = ()
+    if not source.is_empty():
+        val_td_sel.update(source, update_batch_size=True)
     return val_td_sel
```
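In `_patch_info`, entries with the shorter leading dimension are padded with a zero row before the longer-shaped entries are merged in, and the update now clears the source batch size and passes `update_batch_size=True` so mismatched batch sizes no longer fail. A torch-only illustration of the padding step (the shapes are assumptions chosen for the example, not taken from a real Minari dataset):

```python
# Torch-only illustration of the zero-row padding used above; shapes are
# assumptions for the example, not taken from a real Minari dataset.
import torch

min_shape = 4
short = torch.arange(min_shape, dtype=torch.float32).unsqueeze(-1)  # shape [4, 1]
padded = torch.cat([torch.zeros_like(short[:1]), short], 0)         # shape [5, 1]
assert padded.shape == (min_shape + 1, 1)
```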

torchrl/data/datasets/openx.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -577,9 +577,17 @@ def _init(self):
         )
         import datasets

-        dataset = datasets.load_dataset(
-            self.repo, self.dataset_id, streaming=True, split=self.split
-        )
+        try:
+            dataset = datasets.load_dataset(
+                self.repo, self.dataset_id, streaming=True, split=self.split
+            )
+        except Exception as e:
+            if "Dataset scripts are no longer supported" in str(e):
+                raise RuntimeError(
+                    f"Failed to load dataset {self.dataset_id}. Your version of `datasets` is too new - please downgrade to <4.0.0."
+                ) from e
+            raise e
+
         if self.shuffle:
             dataset = dataset.shuffle()
         self.dataset = dataset
```
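The OpenX loader now wraps `datasets.load_dataset` so that the opaque "Dataset scripts are no longer supported" error raised by `datasets>=4.0.0` is re-raised with an actionable hint, while other failures propagate unchanged. A generic sketch of the same error-translation pattern (`load_fn`, `repo` and `dataset_id` are placeholder names, not the TorchRL API):

```python
# Generic sketch of the error-translation pattern above; the names are
# placeholders, not the TorchRL API.
def load_streaming_dataset(load_fn, repo, dataset_id, split=None):
    try:
        return load_fn(repo, dataset_id, streaming=True, split=split)
    except Exception as e:
        if "Dataset scripts are no longer supported" in str(e):
            raise RuntimeError(
                f"Failed to load dataset {dataset_id}. Your version of "
                "`datasets` is too new - please downgrade to <4.0.0."
            ) from e
        raise
```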
