Skip to content

Commit 40c04ef

Browse files
authored
[Feature] Default collate_fn (#688)
* init * amend
1 parent f5d98af commit 40c04ef

File tree

5 files changed

+68
-54
lines changed

5 files changed

+68
-54
lines changed

test/test_rb.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@
3030
from torchrl.data.replay_buffers.writers import RoundRobinWriter
3131

3232

33-
collate_fn_dict = {
34-
ListStorage: lambda x: torch.stack(x, 0),
35-
LazyTensorStorage: lambda x: x,
36-
LazyMemmapStorage: lambda x: x,
37-
None: lambda x: torch.stack(x, 0),
38-
}
33+
# collate_fn_dict = {
34+
# ListStorage: lambda x: torch.stack(x, 0),
35+
# LazyTensorStorage: lambda x: x,
36+
# LazyMemmapStorage: lambda x: x,
37+
# None: lambda x: torch.stack(x, 0),
38+
# }
3939

4040

4141
@pytest.mark.parametrize(
@@ -54,7 +54,6 @@
5454
@pytest.mark.parametrize("size", [3, 100])
5555
class TestPrototypeBuffers:
5656
def _get_rb(self, rb_type, size, sampler, writer, storage):
57-
collate_fn = collate_fn_dict[storage]
5857

5958
if storage is not None:
6059
storage = storage(size)
@@ -65,9 +64,7 @@ def _get_rb(self, rb_type, size, sampler, writer, storage):
6564

6665
sampler = sampler(**sampler_args)
6766
writer = writer()
68-
rb = rb_type(
69-
collate_fn=collate_fn, storage=storage, sampler=sampler, writer=writer
70-
)
67+
rb = rb_type(storage=storage, sampler=sampler, writer=writer)
7168
return rb
7269

7370
def _get_datum(self, rb_type):
@@ -192,7 +189,6 @@ def test_prototype_prb(priority_key, contiguous, device):
192189
np.random.seed(0)
193190
rb = rb_prototype.TensorDictReplayBuffer(
194191
sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9),
195-
collate_fn=None if contiguous else lambda x: torch.stack(x, 0),
196192
priority_key=priority_key,
197193
)
198194
td1 = TensorDict(
@@ -271,7 +267,6 @@ def test_rb_prototype_trajectories(stack):
271267
alpha=0.7,
272268
beta=0.9,
273269
),
274-
collate_fn=lambda x: torch.stack(x, 0),
275270
priority_key="td_error",
276271
)
277272
rb.extend(traj_td)
@@ -315,7 +310,6 @@ class TestBuffers:
315310
_default_params_td_prb = {"alpha": 0.8, "beta": 0.9}
316311

317312
def _get_rb(self, rbtype, size, storage, prefetch):
318-
collate_fn = collate_fn_dict[storage]
319313
if storage is not None:
320314
storage = storage(size)
321315
if rbtype is ReplayBuffer:
@@ -328,13 +322,7 @@ def _get_rb(self, rbtype, size, storage, prefetch):
328322
params = self._default_params_td_prb
329323
else:
330324
raise NotImplementedError(rbtype)
331-
rb = rbtype(
332-
size=size,
333-
storage=storage,
334-
prefetch=prefetch,
335-
collate_fn=collate_fn,
336-
**params
337-
)
325+
rb = rbtype(size=size, storage=storage, prefetch=prefetch, **params)
338326
return rb
339327

340328
def _get_datum(self, rbtype):
@@ -460,7 +448,6 @@ def test_prb(priority_key, contiguous, device):
460448
5,
461449
alpha=0.7,
462450
beta=0.9,
463-
collate_fn=None if contiguous else lambda x: torch.stack(x, 0),
464451
priority_key=priority_key,
465452
)
466453
td1 = TensorDict(
@@ -537,7 +524,6 @@ def test_rb_trajectories(stack):
537524
5,
538525
alpha=0.7,
539526
beta=0.9,
540-
collate_fn=lambda x: torch.stack(x, 0),
541527
priority_key="td_error",
542528
)
543529
rb.extend(traj_td)
@@ -565,10 +551,14 @@ def test_shared_storage_prioritized_sampler():
565551
sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1)
566552

567553
rb0 = rb_prototype.ReplayBuffer(
568-
storage=storage, writer=writer, sampler=sampler0, collate_fn=lambda x: x
554+
storage=storage,
555+
writer=writer,
556+
sampler=sampler0,
569557
)
570558
rb1 = rb_prototype.ReplayBuffer(
571-
storage=storage, writer=writer, sampler=sampler1, collate_fn=lambda x: x
559+
storage=storage,
560+
writer=writer,
561+
sampler=sampler1,
572562
)
573563

574564
data = TensorDict({"a": torch.arange(50)}, [50])
@@ -593,9 +583,11 @@ def test_legacy_rb_does_not_attach():
593583
storage = LazyMemmapStorage(n)
594584
writer = RoundRobinWriter()
595585
sampler = RandomSampler()
596-
rb = ReplayBuffer(storage=storage, size=n, prefetch=0, collate_fn=lambda x: x)
586+
rb = ReplayBuffer(storage=storage, size=n, prefetch=0)
597587
prb = rb_prototype.ReplayBuffer(
598-
storage=storage, writer=writer, sampler=sampler, collate_fn=lambda x: x
588+
storage=storage,
589+
writer=writer,
590+
sampler=sampler,
599591
)
600592

601593
assert len(storage._attached_entities) == 1

test/test_trainer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,20 +249,22 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
249249
S = 100
250250
if storage_type == "list":
251251
storage = ListStorage(S)
252-
collate_fn = lambda x: torch.stack(x, 0)
253252
elif storage_type == "memmap":
254253
storage = LazyMemmapStorage(S)
255-
collate_fn = lambda x: x
256254
else:
257255
raise NotImplementedError
258256

259257
if prioritized:
260258
replay_buffer = TensorDictPrioritizedReplayBuffer(
261-
S, 1.1, 0.9, storage=storage, collate_fn=collate_fn
259+
S,
260+
1.1,
261+
0.9,
262+
storage=storage,
262263
)
263264
else:
264265
replay_buffer = TensorDictReplayBuffer(
265-
S, storage=storage, collate_fn=collate_fn
266+
S,
267+
storage=storage,
266268
)
267269

268270
N = 9

torchrl/data/replay_buffers/rb_prototype.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import torch
77
from tensordict.tensordict import TensorDictBase, LazyStackedTensorDict
88

9-
from .replay_buffers import pin_memory_output, stack_tensors, stack_td
9+
from .replay_buffers import pin_memory_output
1010
from .samplers import Sampler, RandomSampler
11-
from .storages import Storage, ListStorage
11+
from .storages import Storage, ListStorage, _get_default_collate
1212
from .utils import INT_CLASSES, _to_numpy, accept_remote_rref_udf_invocation
1313
from .writers import Writer, RoundRobinWriter
1414

@@ -47,7 +47,11 @@ def __init__(
4747
self._writer = writer if writer is not None else RoundRobinWriter()
4848
self._writer.register_storage(self._storage)
4949

50-
self._collate_fn = collate_fn or stack_tensors
50+
self._collate_fn = (
51+
collate_fn
52+
if collate_fn is not None
53+
else _get_default_collate(self._storage)
54+
)
5155
self._pin_memory = pin_memory
5256

5357
self._prefetch = bool(prefetch)
@@ -169,12 +173,6 @@ class TensorDictReplayBuffer(ReplayBuffer):
169173
"""
170174

171175
def __init__(self, priority_key: str = "td_error", **kw) -> None:
172-
if not kw.get("collate_fn"):
173-
174-
def collate_fn(x):
175-
return stack_td(x, 0, contiguous=True)
176-
177-
kw["collate_fn"] = collate_fn
178176
super().__init__(**kw)
179177
self.priority_key = priority_key
180178

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch
1414
from tensordict.tensordict import (
1515
TensorDictBase,
16-
_stack as stack_td,
1716
LazyStackedTensorDict,
1817
)
1918
from torch import Tensor
@@ -24,7 +23,11 @@
2423
SumSegmentTreeFp32,
2524
SumSegmentTreeFp64,
2625
)
27-
from torchrl.data.replay_buffers.storages import Storage, ListStorage
26+
from torchrl.data.replay_buffers.storages import (
27+
Storage,
28+
ListStorage,
29+
_get_default_collate,
30+
)
2831
from torchrl.data.replay_buffers.utils import INT_CLASSES
2932
from torchrl.data.replay_buffers.utils import (
3033
_to_numpy,
@@ -118,9 +121,11 @@ def __init__(
118121
self._storage = storage
119122
self._capacity = size
120123
self._cursor = 0
121-
if collate_fn is None:
122-
collate_fn = stack_tensors
123-
self._collate_fn = collate_fn
124+
self._collate_fn = (
125+
collate_fn
126+
if collate_fn is not None
127+
else _get_default_collate(self._storage)
128+
)
124129
self._pin_memory = pin_memory
125130

126131
self._prefetch = prefetch is not None and prefetch > 0
@@ -558,11 +563,6 @@ def __init__(
558563
prefetch: Optional[int] = None,
559564
storage: Optional[Storage] = None,
560565
):
561-
if collate_fn is None:
562-
563-
def collate_fn(x):
564-
return stack_td(x, 0, contiguous=True)
565-
566566
super().__init__(size, collate_fn, pin_memory, prefetch, storage=storage)
567567

568568

@@ -606,11 +606,6 @@ def __init__(
606606
prefetch: Optional[int] = None,
607607
storage: Optional[Storage] = None,
608608
) -> None:
609-
if collate_fn is None:
610-
611-
def collate_fn(x):
612-
return stack_td(x, 0, contiguous=True)
613-
614609
super(TensorDictPrioritizedReplayBuffer, self).__init__(
615610
size=size,
616611
alpha=alpha,

torchrl/data/replay_buffers/storages.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,3 +414,30 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
414414
)
415415
elif _CKPT_BACKEND == "torch":
416416
return mem_map_tensor._tensor
417+
418+
419+
def _collate_list_tensordict(x):
420+
out = torch.stack(x, 0)
421+
if isinstance(out, TensorDictBase):
422+
return out.to_tensordict()
423+
return out
424+
425+
426+
def _collate_list_tensors(*x):
427+
return tuple(torch.stack(_x, 0) for _x in zip(*x))
428+
429+
430+
def _collate_contiguous(x):
431+
if isinstance(x, TensorDictBase):
432+
return x.to_tensordict()
433+
return x.clone()
434+
435+
436+
def _get_default_collate(storage, _is_tensordict=True):
437+
if isinstance(storage, ListStorage):
438+
if _is_tensordict:
439+
return _collate_list_tensordict
440+
else:
441+
return _collate_list_tensors
442+
elif isinstance(storage, (LazyTensorStorage, LazyMemmapStorage)):
443+
return _collate_contiguous

0 commit comments

Comments
 (0)