 import numpy as np
 import pytest
 import torch
-from _utils_internal import get_available_devices
+from _utils_internal import get_available_devices, make_tc
 from tensordict import is_tensorclass, tensorclass
 from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from torchrl.data import (
@@ -129,9 +129,14 @@ def test_add(self, rb_type, sampler, writer, storage, size):
         )
         data = self._get_datum(rb_type)
         rb.add(data)
-        s = rb._storage[0]
+        s = rb.sample(1)
+        assert s.ndim, s
+        s = s[0]
         if isinstance(s, TensorDictBase):
-            assert (s == data.select(*s.keys())).all()
+            s = s.select(*data.keys(True), strict=False)
+            data = data.select(*s.keys(True), strict=False)
+            assert (s == data).all()
+            assert list(s.keys(True, True))
         else:
             assert (s == data).all()
 
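Note: the rewritten assertion above compares only the keys the two tensordicts have in common, since the sampled item and the original datum may not expose exactly the same key set. A minimal standalone sketch of that pattern (illustrative only, not part of the diff; the key names are made up):

    import torch
    from tensordict import TensorDict

    data = TensorDict({"obs": torch.zeros(3), "extra": torch.ones(3)}, [])
    s = TensorDict({"obs": torch.zeros(3), "index": torch.zeros(3)}, [])
    # strict=False skips requested keys that are absent, so the two selects
    # reduce both sides to the intersection of their key sets before comparing.
    s = s.select(*data.keys(True), strict=False)
    data = data.select(*s.keys(True), strict=False)
    assert (s == data).all()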
@@ -373,14 +378,22 @@ def test_prototype_prb(priority_key, contiguous, device):
 
 
 @pytest.mark.parametrize("stack", [False, True])
+@pytest.mark.parametrize("datatype", ["tc", "tb"])
 @pytest.mark.parametrize("reduction", ["min", "max", "median", "mean"])
-def test_replay_buffer_trajectories(stack, reduction):
+def test_replay_buffer_trajectories(stack, reduction, datatype):
     traj_td = TensorDict(
         {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
         batch_size=[3, 4],
     )
+    if datatype == "tc":
+        c = make_tc(traj_td)
+        traj_td = c(**traj_td, batch_size=traj_td.batch_size)
+        assert is_tensorclass(traj_td)
+    elif datatype != "tb":
+        raise NotImplementedError
+
     if stack:
-        traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0)
+        traj_td = torch.stack(list(traj_td), 0)
 
     rb = TensorDictReplayBuffer(
         sampler=samplers.PrioritizedSampler(
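The datatype="tc" branch relies on a make_tc helper imported from _utils_internal, which this diff does not show. As a rough, assumed sketch, such a helper would build a @tensorclass type whose fields mirror the top-level keys of the TensorDict, so the same data can be repacked as a tensorclass instance:

    import torch
    from tensordict import tensorclass

    def make_tc(td):
        """Hypothetical helper: tensorclass type with one tensor field per key of td."""
        cls = type("MyTC", (), {"__annotations__": {key: torch.Tensor for key in td.keys()}})
        return tensorclass(cls)

With a type built this way, c(**traj_td, batch_size=traj_td.batch_size) moves the tensordict's entries into the fields of the new class, as done above.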
@@ -394,6 +407,10 @@ def test_replay_buffer_trajectories(stack, reduction):
     )
     rb.extend(traj_td)
     sampled_td = rb.sample()
+    if datatype == "tc":
+        assert is_tensorclass(traj_td)
+        return
+
     sampled_td.set("td_error", torch.rand(sampled_td.shape))
     rb.update_tensordict_priority(sampled_td)
     sampled_td = rb.sample(include_info=True)
@@ -510,9 +527,12 @@ def test_add(self, rbtype, storage, size, prefetch):
         rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
         data = self._get_datum(rbtype)
         rb.add(data)
-        s = rb._storage[0]
+        s = rb.sample(1)[0]
         if isinstance(s, TensorDictBase):
-            assert (s == data.select(*s.keys())).all()
+            s = s.select(*data.keys(True), strict=False)
+            data = data.select(*s.keys(True), strict=False)
+            assert (s == data).all()
+            assert list(s.keys(True, True))
         else:
             assert (s == data).all()
 
@@ -649,6 +669,7 @@ def test_prb(priority_key, contiguous, device):
         },
         batch_size=[3],
     ).to(device)
+
     rb.extend(td1)
     s = rb.sample()
     assert s.batch_size == torch.Size([5])
@@ -838,17 +859,29 @@ def test_insert_transform():
 
 @pytest.mark.parametrize("transform", transforms)
 def test_smoke_replay_buffer_transform(transform):
-    rb = ReplayBuffer(transform=transform(in_keys="observation"), batch_size=1)
+    rb = TensorDictReplayBuffer(
+        transform=transform(in_keys=["observation"]), batch_size=1
+    )
 
     # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, [])
-    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
+    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, [])
     rb.add(td)
-    rb.sample()
 
-    rb._transform = mock.MagicMock()
-    rb._transform.__len__ = lambda *args: 3
+    m = mock.Mock()
+    m.side_effect = [td.unsqueeze(0)]
+    rb._transform.forward = m
+    # rb._transform.__len__ = lambda *args: 3
     rb.sample()
-    assert rb._transform.called
+    assert rb._transform.forward.called
+
+    # was_called = [False]
+    # forward = rb._transform.forward
+    # def new_forward(*args, **kwargs):
+    #     was_called[0] = True
+    #     return forward(*args, **kwargs)
+    # rb._transform.forward = new_forward
+    # rb.sample()
+    # assert was_called[0]
 
 
 transforms = [
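In the rewritten smoke test, only the transform's forward pass is mocked: side_effect is a one-element list, so the single rb.sample() call gets td.unsqueeze(0) back and the call itself is recorded on the mock. A tiny standalone illustration of that mock.Mock behaviour (assumed example, unrelated to the buffer):

    from unittest import mock

    m = mock.Mock()
    m.side_effect = ["first", "second"]  # each call consumes the next value
    assert m() == "first"
    assert m() == "second"
    assert m.called  # the mock records that it was invoked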