@@ -4581,11 +4581,13 @@ def __next__(self):
     @pytest.mark.parametrize("batch_size", [0, 4])
     @pytest.mark.parametrize("device", [None, "cpu"])
     def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
-        env = LLMEnv(str2str=str2str, device=device)
+        env = LLMEnv(
+            str2str=str2str, device=device, has_attention=False, no_stack=False
+        )
         if str2str:
             primer = DataLoadingPrimer(
                 dataloader=self.DummyDataLoader(batch_size=batch_size),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_STR_KEY],
                 example_data="a string!",
             )
         else:
@@ -4595,7 +4597,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
                 dataloader=self.DummyTensorDataLoader(
                     batch_size=batch_size, padding=True
                 ),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
                 data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                 stack_method=stack_method,
             )
@@ -4605,7 +4607,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         if batched:
             td = env.reset(TensorDict(batch_size=[3]))
             env.check_env_specs(break_when_any_done="both", tensordict=td)
-            r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
+            env.rollout(10, tensordict=TensorDict(batch_size=[3]))
         else:
             env.check_env_specs(break_when_any_done="both")
@@ -4628,7 +4630,7 @@ def test_llm_from_dataloader(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
             }
         else:
@@ -4638,11 +4640,18 @@ def test_llm_from_dataloader(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         assert not env.batch_locked
         if batched:
@@ -4655,46 +4664,64 @@ def test_llm_from_dataloader(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td

         if batched:
             # Tell the env that we want 3 sub-envs
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
             assert r.ndim == 2
             if str2str:
-                assert isinstance(r[0, 0]["observation"], str)
-                assert isinstance(r[0, 1]["observation"], str)
+                assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
+                assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
                 assert (
-                    r[0, 0]["observation"]
-                    == r[0, 1]["observation"][: -len(r[0, 0]["action"])]
+                    r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[0, 1]["observation"]
-                    == r[0, 2]["observation"][: -len(r[0, 1]["action"])]
+                    r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 0]["observation"]
-                    == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
+                    r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 1]["observation"]
-                    == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
+                    r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
             else:
-                assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
-                assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
                 assert (
-                    r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
+                    r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
                 assert (
-                    r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
+                    r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
             else:
                 r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
@@ -4720,7 +4747,7 @@ def test_llm_from_dataloader_repeats(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
                 "repeats": repeats,
             }
@@ -4731,12 +4758,19 @@ def test_llm_from_dataloader_repeats(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
                 "repeats": repeats,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         assert env.transform.repeats == repeats
@@ -4746,13 +4780,15 @@ def test_llm_from_dataloader_repeats(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td

         if batched:
@@ -4768,34 +4804,58 @@ def policy(td):
         r_reset = r[..., ::max_steps]
         if not batched:
             if str2str:
-                assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
-                assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
-                assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
+                )
             else:
                 assert (
-                    r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).any()
         else:
             # When batched, each block contains the 3 reset packs
             if str2str:
-                assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
-                assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
-                assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                )
             else:
                 assert (
-                    r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).any()

     @pytest.mark.parametrize(
@@ -4829,7 +4889,7 @@ def test_done_and_reward(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
                 "repeats": repeats,
                 "assign_reward": assign_reward,
@@ -4842,20 +4902,27 @@ def test_done_and_reward(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
                 "repeats": repeats,
                 "assign_reward": assign_reward,
                 "assign_done": assign_done,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         # We want to make sure that transforms that rely on the done state work appropriately
         env.append_transform(StepCounter(max_steps=10))

         def policy(td):
-            td["action"] = torch.ones(
+            td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
                 td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
             )
             return td