Commit 39fe662

vmoens and rohitnig authored
[Refactor] Buffers tensorclass compat and tutorial (#1101)
Co-authored-by: Rohit Nigam <rohitnigam@meta.com>
Co-authored-by: Rohit Nigam <rohitnigam@gmail.com>
1 parent aad6684 commit 39fe662

14 files changed, +878 -86 lines


test/_utils_internal.py

Lines changed: 14 additions & 0 deletions
@@ -16,6 +16,8 @@
 import pytest
 import torch
 import torch.cuda
+
+from tensordict import tensorclass
 from torchrl._utils import implement_for, seed_generator

 from torchrl.envs import ObservationNorm
@@ -295,3 +297,15 @@ def t_out():
     )

     return t_out
+
+
+def make_tc(td):
+    """Makes a tensorclass from a tensordict instance."""
+
+    class MyClass:
+        pass
+
+    MyClass.__annotations__ = {}
+    for key in td.keys():
+        MyClass.__annotations__[key] = torch.Tensor
+    return tensorclass(MyClass)
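
For context, a minimal sketch of how this helper is used (the TensorDict contents below are illustrative, not part of the commit): make_tc builds a tensorclass type with one torch.Tensor field per key of the input tensordict, and the resulting class can be instantiated directly from the tensordict's contents, mirroring the call pattern in the tests below.

    import torch
    from tensordict import TensorDict

    td = TensorDict({"obs": torch.randn(3, 4), "reward": torch.randn(3, 1)}, batch_size=[3])
    MyTC = make_tc(td)  # tensorclass type with fields "obs" and "reward"
    tc = MyTC(**td, batch_size=td.batch_size)  # instantiate from the tensordict's contents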

test/test_rb.py

Lines changed: 46 additions & 13 deletions
@@ -12,7 +12,7 @@
 import numpy as np
 import pytest
 import torch
-from _utils_internal import get_available_devices
+from _utils_internal import get_available_devices, make_tc
 from tensordict import is_tensorclass, tensorclass
 from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from torchrl.data import (
@@ -129,9 +129,14 @@ def test_add(self, rb_type, sampler, writer, storage, size):
         )
         data = self._get_datum(rb_type)
         rb.add(data)
-        s = rb._storage[0]
+        s = rb.sample(1)
+        assert s.ndim, s
+        s = s[0]
         if isinstance(s, TensorDictBase):
-            assert (s == data.select(*s.keys())).all()
+            s = s.select(*data.keys(True), strict=False)
+            data = data.select(*s.keys(True), strict=False)
+            assert (s == data).all()
+            assert list(s.keys(True, True))
         else:
             assert (s == data).all()
@@ -373,14 +378,22 @@ def test_prototype_prb(priority_key, contiguous, device):


 @pytest.mark.parametrize("stack", [False, True])
+@pytest.mark.parametrize("datatype", ["tc", "tb"])
 @pytest.mark.parametrize("reduction", ["min", "max", "median", "mean"])
-def test_replay_buffer_trajectories(stack, reduction):
+def test_replay_buffer_trajectories(stack, reduction, datatype):
     traj_td = TensorDict(
         {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
         batch_size=[3, 4],
     )
+    if datatype == "tc":
+        c = make_tc(traj_td)
+        traj_td = c(**traj_td, batch_size=traj_td.batch_size)
+        assert is_tensorclass(traj_td)
+    elif datatype != "tb":
+        raise NotImplementedError
+
     if stack:
-        traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0)
+        traj_td = torch.stack(list(traj_td), 0)

     rb = TensorDictReplayBuffer(
         sampler=samplers.PrioritizedSampler(
@@ -394,6 +407,10 @@ def test_replay_buffer_trajectories(stack, reduction):
     )
     rb.extend(traj_td)
     sampled_td = rb.sample()
+    if datatype == "tc":
+        assert is_tensorclass(traj_td)
+        return
+
     sampled_td.set("td_error", torch.rand(sampled_td.shape))
     rb.update_tensordict_priority(sampled_td)
     sampled_td = rb.sample(include_info=True)
@@ -510,9 +527,12 @@ def test_add(self, rbtype, storage, size, prefetch):
         rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
         data = self._get_datum(rbtype)
         rb.add(data)
-        s = rb._storage[0]
+        s = rb.sample(1)[0]
         if isinstance(s, TensorDictBase):
-            assert (s == data.select(*s.keys())).all()
+            s = s.select(*data.keys(True), strict=False)
+            data = data.select(*s.keys(True), strict=False)
+            assert (s == data).all()
+            assert list(s.keys(True, True))
         else:
             assert (s == data).all()

@@ -649,6 +669,7 @@ def test_prb(priority_key, contiguous, device):
         },
         batch_size=[3],
     ).to(device)
+
     rb.extend(td1)
     s = rb.sample()
     assert s.batch_size == torch.Size([5])
@@ -838,17 +859,29 @@ def test_insert_transform():

 @pytest.mark.parametrize("transform", transforms)
 def test_smoke_replay_buffer_transform(transform):
-    rb = ReplayBuffer(transform=transform(in_keys="observation"), batch_size=1)
+    rb = TensorDictReplayBuffer(
+        transform=transform(in_keys=["observation"]), batch_size=1
+    )

     # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, [])
-    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
+    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, [])
     rb.add(td)
-    rb.sample()

-    rb._transform = mock.MagicMock()
-    rb._transform.__len__ = lambda *args: 3
+    m = mock.Mock()
+    m.side_effect = [td.unsqueeze(0)]
+    rb._transform.forward = m
+    # rb._transform.__len__ = lambda *args: 3
     rb.sample()
-    assert rb._transform.called
+    assert rb._transform.forward.called
+
+    # was_called = [False]
+    # forward = rb._transform.forward
+    # def new_forward(*args, **kwargs):
+    #     was_called[0] = True
+    #     return forward(*args, **kwargs)
+    # rb._transform.forward = new_forward
+    # rb.sample()
+    # assert was_called[0]


 transforms = [
test/test_rb_distributed.py

Lines changed: 9 additions & 2 deletions
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+import argparse
 import os
 import sys
 import time
@@ -53,7 +53,9 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl
     _, inserted = _add_random_tensor_dict_to_buffer(buffer)
     sampled = _sample_from_buffer(buffer, 1)
     assert type(sampled) is type(inserted) is TensorDict
-    assert (sampled["a"] == inserted["a"]).all()
+    a_sample = sampled["a"]
+    a_insert = inserted["a"]
+    assert (a_sample == a_insert).all()


 @pytest.mark.skipif(
@@ -131,3 +133,8 @@ def _sample_from_buffer(buffer, batch_size):
     return rpc.rpc_sync(
         buffer.owner(), ReplayBufferNode.sample, args=(buffer, batch_size)
     )
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
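
The new __main__ block lets this module be launched directly rather than through the pytest CLI, forwarding any unrecognized command-line arguments to pytest, e.g. (illustrative invocation):

    python test/test_rb_distributed.py -k sample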

test/test_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -289,7 +289,7 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
         trainer._process_batch_hook(td)
         td_out = trainer._process_optim_batch_hook(td)
         if prioritized:
-            td_out.set(replay_buffer.priority_key, torch.rand(N))
+            td_out.unlock_().set(replay_buffer.priority_key, torch.rand(N))
         trainer._post_loss_hook(td_out)

         trainer2 = mocking_trainer()
@@ -424,7 +424,7 @@ def make_storage():
         # sample from rb
         td_out = trainer._process_optim_batch_hook(td)
         if prioritized:
-            td_out.set(replay_buffer.priority_key, torch.rand(N))
+            td_out.unlock_().set(replay_buffer.priority_key, torch.rand(N))
         trainer._post_loss_hook(td_out)
         trainer.save_trainer(True)
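
Both call sites fix the same failure mode: the tensordict coming out of the optim-batch hook can be locked, and set() on a locked TensorDict raises. Since unlock_() unlocks in place and returns the same tensordict, it chains directly with set(). A minimal standalone sketch of that behavior, assuming tensordict's lock_()/unlock_() pair:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.zeros(4)}, batch_size=[4]).lock_()
    # td.set("td_error", torch.rand(4))  # would raise: td is locked
    td.unlock_().set("td_error", torch.rand(4))  # unlock in place, then write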

torchrl/_utils.py

Lines changed: 8 additions & 7 deletions
@@ -454,13 +454,14 @@ def context_decorator(ctx, func):
     be a multi-shot context manager that can be directly invoked multiple times)
     or a callable that produces a context manager.
     """
-    assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
-        f"Passed in {ctx} is both callable and also a valid context manager "
-        "(has __enter__), making it ambiguous which interface to use. If you "
-        "intended to pass a context manager factory, rewrite your call as "
-        "context_decorator(lambda: ctx()); if you intended to pass a context "
-        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
-    )
+    if callable(ctx) and hasattr(ctx, "__enter__"):
+        raise RuntimeError(
+            f"Passed in {ctx} is both callable and also a valid context manager "
+            "(has __enter__), making it ambiguous which interface to use. If you "
+            "intended to pass a context manager factory, rewrite your call as "
+            "context_decorator(lambda: ctx()); if you intended to pass a context "
+            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
+        )

     if not callable(ctx):
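
The substance of this change: an assert is stripped when Python runs with -O, so the guard is promoted to a RuntimeError that always fires. The ambiguity it rejects is an object that is simultaneously callable and a context manager; an illustrative case (this class is hypothetical, not from the codebase):

    import contextlib

    class Tracer(contextlib.ContextDecorator):
        # ContextDecorator instances are callable (usable as decorators)
        # and, with __enter__/__exit__ defined, also valid context managers.
        def __enter__(self):
            return self

        def __exit__(self, *exc):
            return False

    # context_decorator(Tracer(), func) now raises RuntimeError; passing a
    # factory, context_decorator(lambda: Tracer(), func), resolves the ambiguity.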