File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -3940,10 +3940,17 @@ def test_device_cast(self):
3940
3940
assert comp ["nontensor" ].device == torch .device ("cpu" )
3941
3941
3942
3942
def test_encode (self ):
3943
+ nt = NonTensor (shape = (1 ,))
3944
+ r = nt .encode ("a string" )
3945
+ assert isinstance (r , NonTensorData )
3946
+ assert r .shape == nt .shape
3947
+
3943
3948
comp = Composite (device = "cpu" )
3944
- comp ["nontensor" ] = NonTensor ( shape = ())
3949
+ comp ["nontensor" ] = nt
3945
3950
r = comp .encode ({"nontensor" : "a string" })
3946
- assert isinstance (r ["nontensor" ], str )
3951
+ assert isinstance (r , TensorDict )
3952
+ assert isinstance (r .get ("nontensor" ), NonTensorData )
3953
+ assert r .get ("nontensor" ).shape == (1 ,)
3947
3954
3948
3955
3949
3956
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "not cuda device" )
Original file line number Diff line number Diff line change @@ -2918,7 +2918,7 @@ def _encode_eager(
2918
2918
* ,
2919
2919
ignore_device : bool = False ,
2920
2920
) -> torch .Tensor | TensorDictBase :
2921
- return val
2921
+ return NonTensorData ( val , device = self . device , batch_size = self . shape )
2922
2922
2923
2923
2924
2924
class _UnboundedMeta (abc .ABCMeta ):
You can’t perform that action at this time.
0 commit comments