@@ -4616,11 +4616,13 @@ def __next__(self):
4616
4616
@pytest .mark .parametrize ("batch_size" , [0 , 4 ])
4617
4617
@pytest .mark .parametrize ("device" , [None , "cpu" ])
4618
4618
def test_llm_env (self , str2str , batched , stack_method , device , batch_size ):
4619
- env = LLMEnv (str2str = str2str , device = device )
4619
+ env = LLMEnv (
4620
+ str2str = str2str , device = device , has_attention = False , no_stack = False
4621
+ )
4620
4622
if str2str :
4621
4623
primer = DataLoadingPrimer (
4622
4624
dataloader = self .DummyDataLoader (batch_size = batch_size ),
4623
- data_keys = ["observation" ],
4625
+ data_keys = [LLMEnv . _DEFAULT_STR_KEY ],
4624
4626
example_data = "a string!" ,
4625
4627
)
4626
4628
else :
@@ -4630,7 +4632,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4630
4632
dataloader = self .DummyTensorDataLoader (
4631
4633
batch_size = batch_size , padding = True
4632
4634
),
4633
- data_keys = ["observation" ],
4635
+ data_keys = [LLMEnv . _DEFAULT_TOKEN_KEY ],
4634
4636
data_specs = [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4635
4637
stack_method = stack_method ,
4636
4638
)
@@ -4640,7 +4642,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4640
4642
if batched :
4641
4643
td = env .reset (TensorDict (batch_size = [3 ]))
4642
4644
env .check_env_specs (break_when_any_done = "both" , tensordict = td )
4643
- r = env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4645
+ env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4644
4646
else :
4645
4647
env .check_env_specs (break_when_any_done = "both" )
4646
4648
@@ -4663,7 +4665,7 @@ def test_llm_from_dataloader(
4663
4665
if str2str :
4664
4666
kwargs = {
4665
4667
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4666
- "data_keys" : ["observation" ],
4668
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4667
4669
"example_data" : "a string!" ,
4668
4670
}
4669
4671
else :
@@ -4673,11 +4675,18 @@ def test_llm_from_dataloader(
4673
4675
"dataloader" : self .DummyTensorDataLoader (
4674
4676
padding = True , batch_size = batch_size
4675
4677
),
4676
- "data_keys" : ["observation" ],
4678
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4677
4679
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4678
4680
"stack_method" : stack_method ,
4679
4681
}
4680
- kwargs .update ({"str2str" : str2str , "device" : device })
4682
+ kwargs .update (
4683
+ {
4684
+ "str2str" : str2str ,
4685
+ "device" : device ,
4686
+ "has_attention" : False ,
4687
+ "no_stack" : False ,
4688
+ }
4689
+ )
4681
4690
env = LLMEnv .from_dataloader (** kwargs )
4682
4691
assert not env .batch_locked
4683
4692
if batched :
@@ -4690,51 +4699,283 @@ def test_llm_from_dataloader(
4690
4699
def policy (td ):
4691
4700
if str2str :
4692
4701
if not td .shape :
4693
- td ["action" ] = "<nothing>"
4702
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = "<nothing>"
4694
4703
else :
4695
- td ["action" ] = NonTensorStack (
4704
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = NonTensorStack (
4696
4705
* ["<nothing>" for _ in range (td .shape [0 ])]
4697
4706
)
4698
4707
else :
4699
- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4708
+ td [LLMEnv ._DEFAULT_ACTION_KEY ] = torch .ones (
4709
+ td .shape + (1 ,), dtype = torch .int64
4710
+ )
4700
4711
return td
4701
4712
4702
4713
if batched :
4703
4714
# Tell the env that we want 3 sub-envs
4704
4715
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = [3 ]))
4705
4716
assert r .ndim == 2
4706
4717
if str2str :
4707
- assert isinstance (r [0 , 0 ]["observation" ], str )
4708
- assert isinstance (r [0 , 1 ]["observation" ], str )
4718
+ assert isinstance (r [0 , 0 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4719
+ assert isinstance (r [0 , 1 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4709
4720
assert (
4710
- r [0 , 0 ]["observation" ]
4711
- == r [0 , 1 ]["observation" ][: - len (r [0 , 0 ]["action" ])]
4721
+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4722
+ == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4723
+ : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4724
+ ]
4712
4725
)
4713
4726
assert (
4714
- r [0 , 1 ]["observation" ]
4715
- == r [0 , 2 ]["observation" ][: - len (r [0 , 1 ]["action" ])]
4727
+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4728
+ == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4729
+ : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4730
+ ]
4716
4731
)
4717
4732
assert (
4718
- r [- 1 , 0 ]["observation" ]
4719
- == r [- 1 , 1 ]["observation" ][: - len (r [- 1 , 0 ]["action" ])]
4733
+ r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4734
+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4735
+ : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4736
+ ]
4720
4737
)
4721
4738
assert (
4722
- r [- 1 , 1 ]["observation" ]
4723
- == r [- 1 , 2 ]["observation" ][: - len (r [- 1 , 1 ]["action" ])]
4739
+ r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4740
+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4741
+ : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4742
+ ]
4724
4743
)
4725
4744
else :
4726
- assert (r [0 , 0 ]["observation" ] == r [0 , 1 ]["observation" ][:- 1 ]).all ()
4727
- assert (r [0 , 1 ]["observation" ] == r [0 , 2 ]["observation" ][:- 1 ]).all ()
4728
4745
assert (
4729
- r [- 1 , 0 ]["observation" ] == r [- 1 , 1 ]["observation" ][:- 1 ]
4746
+ r [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4747
+ == r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4748
+ ).all ()
4749
+ assert (
4750
+ r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4751
+ == r [0 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4730
4752
).all ()
4731
4753
assert (
4732
- r [- 1 , 1 ]["observation" ] == r [- 1 , 2 ]["observation" ][:- 1 ]
4754
+ r [- 1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4755
+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4756
+ ).all ()
4757
+ assert (
4758
+ r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4759
+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4733
4760
).all ()
4734
4761
else :
4735
4762
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
4736
4763
assert r .ndim == 1
4737
4764
4765
@pytest.mark.parametrize(
    "str2str,stack_method",
    [
        [True, None],
        [False, "as_padded_tensor"],
        # TODO: a bit experimental, fails with check_env_specs
        # [False, "as_nested_tensor"],
        [False, None],
    ],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("repeats", [3])
def test_llm_from_dataloader_repeats(
    self, str2str, batched, stack_method, device, batch_size, repeats
):
    """Check that ``repeats`` makes consecutive resets yield the same data.

    An ``LLMEnv`` is built with ``from_dataloader(..., repeats=repeats)`` and
    a ``StepCounter(max_steps=3)`` is appended. A 100-step rollout is run
    with ``break_when_any_done=False`` and every ``max_steps``-th step is
    sampled (the original comment describes these as the reset steps).

    * Unbatched: sampled entries 0, 1 and 2 must match, entry 3 must differ.
    * Batched (3 sub-envs): sampled entry 0 must match across the three
      batch rows, and entries 0 and 1 of row 0 must differ.
    """
    # Pick the data key and the loader-specific arguments for this mode.
    if str2str:
        data_key = LLMEnv._DEFAULT_STR_KEY
        kwargs = {
            "dataloader": self.DummyDataLoader(batch_size=batch_size),
            "data_keys": [data_key],
            "example_data": "a string!",
        }
    else:
        data_key = LLMEnv._DEFAULT_TOKEN_KEY
        if stack_method is None:
            stack_method = as_padded_tensor
        kwargs = {
            "dataloader": self.DummyTensorDataLoader(
                padding=True, batch_size=batch_size
            ),
            "data_keys": [data_key],
            "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
            "stack_method": stack_method,
        }
    # Arguments shared by both modes.
    kwargs["repeats"] = repeats
    kwargs["str2str"] = str2str
    kwargs["device"] = device
    kwargs["has_attention"] = False
    kwargs["no_stack"] = False

    env = LLMEnv.from_dataloader(**kwargs)
    assert env.transform.repeats == repeats

    max_steps = 3
    env.append_transform(StepCounter(max_steps=max_steps))

    def policy(td):
        # Dummy policy: a constant string (or stack of strings) in str2str
        # mode, a single token of ones otherwise.
        if not str2str:
            action = torch.ones(td.shape + (1,), dtype=torch.int64)
        elif td.shape:
            action = NonTensorStack(*(["<nothing>"] * td.shape[0]))
        else:
            action = "<nothing>"
        td[LLMEnv._DEFAULT_ACTION_KEY] = action
        return td

    rollout_kwargs = {"break_when_any_done": False}
    if batched:
        rollout_kwargs["tensordict"] = TensorDict(batch_size=[3])
    r = env.rollout(100, policy, **rollout_kwargs)

    # check that r at reset is always the same
    r_reset = r[..., ::max_steps]

    def _same(lhs, rhs):
        # Strings compare directly; token tensors must match element-wise.
        return lhs == rhs if str2str else (lhs == rhs).all()

    def _differ(lhs, rhs):
        # Strings compare directly; token tensors need one differing element.
        return lhs != rhs if str2str else (lhs != rhs).any()

    if batched:
        # When batched, each block contains the 3 reset packs
        ref_index = (0, 0)
        matching = [(1, 0), (2, 0)]
        mismatching = (0, 1)
    else:
        ref_index = (..., 0)
        matching = [(..., 1), (..., 2)]
        mismatching = (..., 3)

    reference = r_reset[ref_index][data_key]
    for index in matching:
        assert _same(reference, r_reset[index][data_key])
    assert _differ(reference, r_reset[mismatching][data_key])
4895
+
4896
@pytest.mark.parametrize(
    "str2str,stack_method",
    [
        [True, None],
        [False, "as_padded_tensor"],
    ],
)
@pytest.mark.parametrize("batched", [True])
@pytest.mark.parametrize("device", [None])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("repeats", [3])
@pytest.mark.parametrize(
    "assign_reward,assign_done", [[True, False], [True, True], [False, True]]
)
def test_done_and_reward(
    self,
    str2str,
    batched,
    stack_method,
    device,
    batch_size,
    repeats,
    assign_reward,
    assign_done,
):
    """Exercise ``assign_reward`` / ``assign_done`` on ``LLMEnv.from_dataloader``.

    With ``str2str=True`` the test expects a ``ValueError`` whose message
    matches ``"str2str"``. Otherwise the env is built, a
    ``StepCounter(max_steps=10)`` is appended, a 100-step rollout is run with
    ``break_when_any_done=False``, and the rollout must contain ``"done"``
    (and ``"terminated"`` when ``assign_done`` is set).
    """
    expectation = (
        pytest.raises(ValueError, match="str2str")
        if str2str
        else contextlib.nullcontext()
    )
    # NOTE(review): the ValueError is presumably raised while the env is
    # built; the whole body stays inside the context so the scope of the
    # expectation is unchanged.
    with expectation:
        if str2str:
            kwargs = {
                "dataloader": self.DummyDataLoader(batch_size=batch_size),
                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                "example_data": "a string!",
            }
        else:
            if stack_method is None:
                stack_method = as_padded_tensor
            kwargs = {
                "dataloader": self.DummyTensorDataLoader(
                    padding=True, batch_size=batch_size
                ),
                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                "stack_method": stack_method,
            }
        # Arguments shared by both modes.
        kwargs.update(
            {
                "repeats": repeats,
                "assign_reward": assign_reward,
                "assign_done": assign_done,
                "str2str": str2str,
                "device": device,
                "has_attention": False,
                "no_stack": False,
            }
        )
        env = LLMEnv.from_dataloader(**kwargs)
        # We want to make sure that transforms that rely on the done state
        # work appropriately
        env.append_transform(StepCounter(max_steps=10))

        def policy(td):
            # Action of ones whose trailing length is drawn from
            # torch.randint(10, ...) on every call.
            td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
                td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
            )
            return td

        if batched:
            r = env.rollout(
                100,
                policy,
                tensordict=TensorDict(batch_size=[3]),
                break_when_any_done=False,
            )
        else:
            r = env.rollout(100, policy, break_when_any_done=False)
        if assign_done:
            assert "terminated" in r
        assert "done" in r
4978
+
4738
4979
4739
4980
if __name__ == "__main__" :
4740
4981
args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments