Skip to content

Commit 0795ac5

Browse files
author
Vincent Moens
committed
[BugFix] Fix auto-batch-size
1 parent c83c04c commit 0795ac5

File tree

5 files changed

27 additions and 13 deletions

5 files changed

27 additions and 13 deletions

torchrl/data/datasets/d4rl.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -291,15 +291,18 @@ def _get_dataset_direct(self, name, env_kwargs):
291291
k: torch.from_numpy(item)
292292
for k, item in dataset.items()
293293
if isinstance(item, np.ndarray)
294-
}
294+
},
295+
auto_batch_size=True,
295296
)
296297
dataset = dataset.unflatten_keys("/")
297298
if "metadata" in dataset.keys():
298299
metadata = dataset.get("metadata")
299300
dataset = dataset.exclude("metadata")
300301
self.metadata = metadata
301302
# find batch size
302-
dataset = make_tensordict(dataset.flatten_keys("/").to_dict())
303+
dataset = make_tensordict(
304+
dataset.flatten_keys("/").to_dict(), auto_batch_size=True
305+
)
303306
dataset = dataset.unflatten_keys("/")
304307
else:
305308
self.metadata = {}
@@ -361,7 +364,8 @@ def _get_dataset_from_env(self, name, env_kwargs):
361364
k: torch.from_numpy(item)
362365
for k, item in env.get_dataset().items()
363366
if isinstance(item, np.ndarray)
364-
}
367+
},
368+
auto_batch_size=True,
365369
)
366370
dataset = dataset.unflatten_keys("/")
367371
dataset = self._process_data_from_env(dataset, env)
@@ -373,7 +377,9 @@ def _process_data_from_env(self, dataset, env=None):
373377
dataset = dataset.exclude("metadata")
374378
self.metadata = metadata
375379
# find batch size
376-
dataset = make_tensordict(dataset.flatten_keys("/").to_dict())
380+
dataset = make_tensordict(
381+
dataset.flatten_keys("/").to_dict(), auto_batch_size=True
382+
)
377383
dataset = dataset.unflatten_keys("/")
378384
else:
379385
self.metadata = {}

torchrl/data/datasets/gen_dgrl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -272,7 +272,7 @@ def _unpack_category_file(
272272
npybuffer = tar.extractfile(member=member)
273273
# npyfile = Path(download_folder) / member.name
274274
npfile = np.load(npybuffer, allow_pickle=True)
275-
td = TensorDict.from_dict(npfile.tolist())
275+
td = TensorDict.from_dict(npfile.tolist(), auto_batch_size=True)
276276
td.set("observations", td.get("observations").to(torch.uint8))
277277
td.set(("next", "observation"), td.get("observations")[1:])
278278
td.set("observations", td.get("observations")[:-1])

torchrl/data/datasets/openx.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -787,4 +787,4 @@ def _make_tensordict_image_conv(data):
787787
data["observation"]["image"] = tensor
788788
except KeyError:
789789
pass
790-
return make_tensordict(data)
790+
return make_tensordict(data, auto_batch_size=True)

torchrl/data/datasets/vd4rl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -418,7 +418,7 @@ def _is_downloaded(self):
418418
def _from_npz(npz_path):
419419
npz = np.load(npz_path)
420420
npz_dict = {file: npz[file] for file in npz.files}
421-
return TensorDict.from_dict(npz_dict)
421+
return TensorDict.from_dict(npz_dict, auto_batch_size=True)
422422

423423

424424
_NAME_MATCH = KeyDependentDefaultDict(lambda x: x)

torchrl/data/rlhf/dataset.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -148,9 +148,10 @@ def load(self):
148148
dataset = self._load_dataset()
149149
dataset = self._tokenize(dataset)
150150
prefix = (split, str(max_length))
151-
return self.dataset_to_tensordict(
151+
result = self.dataset_to_tensordict(
152152
dataset, data_dir=data_dir, prefix=prefix, valid_mask_key="valid_sample"
153-
)[prefix]
153+
)
154+
return result[prefix]
154155

155156
def _load_dataset(self):
156157
"""Loads a text dataset from ``datasets``.
@@ -213,7 +214,9 @@ def _tokenize(
213214
for key, value in dataset_dict.items()
214215
if key not in excluded_features
215216
}
216-
dataset = TensorDict.from_dict(dataset_dict)
217+
dataset = TensorDict.from_dict(
218+
dataset_dict, auto_batch_size=True, batch_dims=1
219+
)
217220
elif excluded_features:
218221
dataset = dataset.exclude(*excluded_features)
219222
# keep non empty rows (i.e. where at least one token is not eos)
@@ -294,14 +297,16 @@ def dataset_to_tensordict(
294297
if prefix is None:
295298
prefix = ()
296299
data_dict = {key: torch.as_tensor(dataset[key]) for key in features}
297-
out = TensorDict.from_dict(data_dict, batch_dims=batch_dims)
300+
out = TensorDict.from_dict(
301+
data_dict, batch_dims=batch_dims, auto_batch_size=True
302+
)
298303
else:
299304
out = dataset
300305
if valid_mask_key is not None and valid_mask_key in out.keys(
301306
include_nested=True
302307
):
303308
out = out[out.get(valid_mask_key)]
304-
out = TensorDict({prefix: out}, [])
309+
out = TensorDict({prefix: out})
305310
out.memmap_(prefix=data_dir)
306311
return out
307312

@@ -481,6 +486,9 @@ def __call__(self, sample):
481486
batch_size = [] if isinstance(input, str) else [len(input)]
482487
if self.return_tensordict:
483488
return TensorDict.from_dict(
484-
dict(tokenized_sample), batch_size=batch_size, device=self.device
489+
dict(tokenized_sample),
490+
batch_size=batch_size,
491+
device=self.device,
492+
auto_batch_size=True,
485493
)
486494
return tokenized_sample

0 commit comments

Comments
 (0)