@@ -4700,6 +4700,180 @@ def policy(td):
4700
4700
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
4701
4701
assert r .ndim == 1
4702
4702
4703
+ @pytest .mark .parametrize (
4704
+ "str2str,stack_method" ,
4705
+ [
4706
+ [True , None ],
4707
+ [False , "as_padded_tensor" ],
4708
+ # TODO: a bit experimental, fails with check_env_specs
4709
+ # [False, "as_nested_tensor"],
4710
+ [False , None ],
4711
+ ],
4712
+ )
4713
+ @pytest .mark .parametrize ("batched" , [True , False ])
4714
+ @pytest .mark .parametrize ("device" , [None , "cpu" ])
4715
+ @pytest .mark .parametrize ("batch_size" , [0 , 4 ])
4716
+ @pytest .mark .parametrize ("repeats" , [3 ])
4717
+ def test_llm_from_dataloader_repeats (
4718
+ self , str2str , batched , stack_method , device , batch_size , repeats
4719
+ ):
4720
+ if str2str :
4721
+ kwargs = {
4722
+ "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4723
+ "data_keys" : ["observation" ],
4724
+ "example_data" : "a string!" ,
4725
+ "repeats" : repeats ,
4726
+ }
4727
+ else :
4728
+ if stack_method is None :
4729
+ stack_method = as_padded_tensor
4730
+ kwargs = {
4731
+ "dataloader" : self .DummyTensorDataLoader (
4732
+ padding = True , batch_size = batch_size
4733
+ ),
4734
+ "data_keys" : ["observation" ],
4735
+ "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4736
+ "stack_method" : stack_method ,
4737
+ "repeats" : repeats ,
4738
+ }
4739
+ kwargs .update ({"str2str" : str2str , "device" : device })
4740
+ env = LLMEnv .from_dataloader (** kwargs )
4741
+ assert env .transform .repeats == repeats
4742
+
4743
+ max_steps = 3
4744
+ env .append_transform (StepCounter (max_steps = max_steps ))
4745
+
4746
+ def policy (td ):
4747
+ if str2str :
4748
+ if not td .shape :
4749
+ td ["action" ] = "<nothing>"
4750
+ else :
4751
+ td ["action" ] = NonTensorStack (
4752
+ * ["<nothing>" for _ in range (td .shape [0 ])]
4753
+ )
4754
+ else :
4755
+ td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4756
+ return td
4757
+
4758
+ if batched :
4759
+ r = env .rollout (
4760
+ 100 ,
4761
+ policy ,
4762
+ tensordict = TensorDict (batch_size = [3 ]),
4763
+ break_when_any_done = False ,
4764
+ )
4765
+ else :
4766
+ r = env .rollout (100 , policy , break_when_any_done = False )
4767
+ # check that r at reset is always the same
4768
+ r_reset = r [..., ::max_steps ]
4769
+ if not batched :
4770
+ if str2str :
4771
+ assert r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4772
+ assert r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4773
+ assert r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4774
+ else :
4775
+ assert (
4776
+ r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4777
+ ).all ()
4778
+ assert (
4779
+ r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4780
+ ).all ()
4781
+ assert (
4782
+ r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4783
+ ).any ()
4784
+ else :
4785
+ # When batched, each block contains the 3 reset packs
4786
+ if str2str :
4787
+ assert r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4788
+ assert r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4789
+ assert r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4790
+ else :
4791
+ assert (
4792
+ r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4793
+ ).all ()
4794
+ assert (
4795
+ r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4796
+ ).all ()
4797
+ assert (
4798
+ r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4799
+ ).any ()
4800
+
4801
+ @pytest .mark .parametrize (
4802
+ "str2str,stack_method" ,
4803
+ [
4804
+ [True , None ],
4805
+ [False , "as_padded_tensor" ],
4806
+ ],
4807
+ )
4808
+ @pytest .mark .parametrize ("batched" , [True ])
4809
+ @pytest .mark .parametrize ("device" , [None ])
4810
+ @pytest .mark .parametrize ("batch_size" , [4 ])
4811
+ @pytest .mark .parametrize ("repeats" , [3 ])
4812
+ @pytest .mark .parametrize (
4813
+ "assign_reward,assign_done" , [[True , False ], [True , True ], [False , True ]]
4814
+ )
4815
+ def test_done_and_reward (
4816
+ self ,
4817
+ str2str ,
4818
+ batched ,
4819
+ stack_method ,
4820
+ device ,
4821
+ batch_size ,
4822
+ repeats ,
4823
+ assign_reward ,
4824
+ assign_done ,
4825
+ ):
4826
+ with pytest .raises (
4827
+ ValueError , match = "str2str"
4828
+ ) if str2str else contextlib .nullcontext ():
4829
+ if str2str :
4830
+ kwargs = {
4831
+ "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4832
+ "data_keys" : ["observation" ],
4833
+ "example_data" : "a string!" ,
4834
+ "repeats" : repeats ,
4835
+ "assign_reward" : assign_reward ,
4836
+ "assign_done" : assign_done ,
4837
+ }
4838
+ else :
4839
+ if stack_method is None :
4840
+ stack_method = as_padded_tensor
4841
+ kwargs = {
4842
+ "dataloader" : self .DummyTensorDataLoader (
4843
+ padding = True , batch_size = batch_size
4844
+ ),
4845
+ "data_keys" : ["observation" ],
4846
+ "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4847
+ "stack_method" : stack_method ,
4848
+ "repeats" : repeats ,
4849
+ "assign_reward" : assign_reward ,
4850
+ "assign_done" : assign_done ,
4851
+ }
4852
+ kwargs .update ({"str2str" : str2str , "device" : device })
4853
+ env = LLMEnv .from_dataloader (** kwargs )
4854
+ # We want to make sure that transforms that rely on the done state work appropriately
4855
+ env .append_transform (StepCounter (max_steps = 10 ))
4856
+
4857
+ def policy (td ):
4858
+ td ["action" ] = torch .ones (
4859
+ td .shape + (torch .randint (10 , (1 ,)).item (),), dtype = torch .int64
4860
+ )
4861
+ return td
4862
+
4863
+ if batched :
4864
+ r = env .rollout (
4865
+ 100 ,
4866
+ policy ,
4867
+ tensordict = TensorDict (batch_size = [3 ]),
4868
+ break_when_any_done = False ,
4869
+ )
4870
+ else :
4871
+ r = env .rollout (100 , policy , break_when_any_done = False )
4872
+ if assign_done :
4873
+ assert "terminated" in r
4874
+ assert "done" in r
4875
+ print (r )
4876
+
4703
4877
4704
4878
if __name__ == "__main__" :
4705
4879
args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments