@@ -690,6 +690,39 @@ def make_env():
690
690
del env
691
691
692
692
693
+ @pytest .mark .parametrize (
694
+ "break_when_any_done,break_when_all_done" ,
695
+ [[True , False ], [False , True ], [False , False ]],
696
+ )
697
+ @pytest .mark .parametrize ("n_envs" , [1 , 4 ])
698
+ def test_collector_outplace_policy (n_envs , break_when_any_done , break_when_all_done ):
699
+ def policy_inplace (td ):
700
+ td .set ("action" , torch .ones (td .shape + (1 ,)))
701
+ return td
702
+
703
+ def policy_outplace (td ):
704
+ return td .empty ().set ("action" , torch .ones (td .shape + (1 ,)))
705
+
706
+ if n_envs == 1 :
707
+ env = CountingEnv (10 )
708
+ else :
709
+ env = SerialEnv (
710
+ n_envs ,
711
+ [functools .partial (CountingEnv , 10 + i ) for i in range (n_envs )],
712
+ )
713
+ env .reset ()
714
+ c_inplace = SyncDataCollector (
715
+ env , policy_inplace , frames_per_batch = 10 , total_frames = 100
716
+ )
717
+ d_inplace = torch .cat (list (c_inplace ), dim = 0 )
718
+ env .reset ()
719
+ c_outplace = SyncDataCollector (
720
+ env , policy_outplace , frames_per_batch = 10 , total_frames = 100
721
+ )
722
+ d_outplace = torch .cat (list (c_outplace ), dim = 0 )
723
+ assert_allclose_td (d_inplace , d_outplace )
724
+
725
+
693
726
# Deprecated reset_when_done
694
727
# @pytest.mark.parametrize("num_env", [1, 2])
695
728
# @pytest.mark.parametrize("env_name", ["vec"])
0 commit comments