@@ -406,7 +406,7 @@ class TestTransforms:
406
406
def test_resize (self , interpolation , keys , nchannels , batch , device ):
407
407
torch .manual_seed (0 )
408
408
dont_touch = torch .randn (* batch , nchannels , 16 , 16 , device = device )
409
- resize = Resize (w = 20 , h = 21 , interpolation = interpolation , keys_in = keys )
409
+ resize = Resize (w = 20 , h = 21 , interpolation = interpolation , in_keys = keys )
410
410
td = TensorDict (
411
411
{
412
412
key : torch .randn (* batch , nchannels , 16 , 16 , device = device )
@@ -444,7 +444,7 @@ def test_resize(self, interpolation, keys, nchannels, batch, device):
444
444
def test_centercrop (self , keys , h , nchannels , batch , device ):
445
445
torch .manual_seed (0 )
446
446
dont_touch = torch .randn (* batch , nchannels , 16 , 16 , device = device )
447
- cc = CenterCrop (w = 20 , h = h , keys_in = keys )
447
+ cc = CenterCrop (w = 20 , h = h , in_keys = keys )
448
448
if h is None :
449
449
h = 20
450
450
td = TensorDict (
@@ -485,7 +485,7 @@ def test_flatten(self, keys, size, nchannels, batch, device):
485
485
torch .manual_seed (0 )
486
486
dont_touch = torch .randn (* batch , * size , nchannels , 16 , 16 , device = device )
487
487
start_dim = - 3 - len (size )
488
- flatten = FlattenObservation (start_dim , - 3 , keys_in = keys )
488
+ flatten = FlattenObservation (start_dim , - 3 , in_keys = keys )
489
489
td = TensorDict (
490
490
{
491
491
key : torch .randn (* batch , * size , nchannels , 16 , 16 , device = device )
@@ -527,7 +527,7 @@ def test_flatten(self, keys, size, nchannels, batch, device):
527
527
def test_unsqueeze (self , keys , size , nchannels , batch , device , unsqueeze_dim ):
528
528
torch .manual_seed (0 )
529
529
dont_touch = torch .randn (* batch , * size , nchannels , 16 , 16 , device = device )
530
- unsqueeze = UnsqueezeTransform (unsqueeze_dim , keys_in = keys )
530
+ unsqueeze = UnsqueezeTransform (unsqueeze_dim , in_keys = keys )
531
531
td = TensorDict (
532
532
{
533
533
key : torch .randn (* batch , * size , nchannels , 16 , 16 , device = device )
@@ -586,7 +586,7 @@ def test_unsqueeze_inv(
586
586
torch .manual_seed (0 )
587
587
keys_total = set (keys + keys_inv )
588
588
unsqueeze = UnsqueezeTransform (
589
- unsqueeze_dim , keys_in = keys , keys_inv_in = keys_inv
589
+ unsqueeze_dim , in_keys = keys , in_keys_inv = keys_inv
590
590
)
591
591
td = TensorDict (
592
592
{
@@ -621,7 +621,7 @@ def test_unsqueeze_inv(
621
621
def test_squeeze (self , keys , keys_inv , size , nchannels , batch , device , squeeze_dim ):
622
622
torch .manual_seed (0 )
623
623
keys_total = set (keys + keys_inv )
624
- squeeze = SqueezeTransform (squeeze_dim , keys_in = keys , keys_inv_in = keys_inv )
624
+ squeeze = SqueezeTransform (squeeze_dim , in_keys = keys , in_keys_inv = keys_inv )
625
625
td = TensorDict (
626
626
{
627
627
key : torch .randn (* batch , * size , nchannels , 16 , 16 , device = device )
@@ -656,7 +656,7 @@ def test_squeeze_inv(
656
656
):
657
657
torch .manual_seed (0 )
658
658
keys_total = set (keys + keys_inv )
659
- squeeze = SqueezeTransform (squeeze_dim , keys_in = keys , keys_inv_in = keys_inv )
659
+ squeeze = SqueezeTransform (squeeze_dim , in_keys = keys , in_keys_inv = keys_inv )
660
660
td = TensorDict (
661
661
{
662
662
key : torch .randn (* batch , * size , nchannels , 16 , 16 , device = device )
@@ -687,7 +687,7 @@ def test_squeeze_inv(
687
687
def test_grayscale (self , keys , device ):
688
688
torch .manual_seed (0 )
689
689
nchannels = 3
690
- gs = GrayScale (keys_in = keys )
690
+ gs = GrayScale (in_keys = keys )
691
691
dont_touch = torch .randn (1 , nchannels , 16 , 16 , device = device )
692
692
td = TensorDict (
693
693
{key : torch .randn (1 , nchannels , 16 , 16 , device = device ) for key in keys },
@@ -720,7 +720,7 @@ def test_grayscale(self, keys, device):
720
720
def test_totensorimage (self , keys , batch , device ):
721
721
torch .manual_seed (0 )
722
722
nchannels = 3
723
- totensorimage = ToTensorImage (keys_in = keys )
723
+ totensorimage = ToTensorImage (in_keys = keys )
724
724
dont_touch = torch .randn (* batch , nchannels , 16 , 16 , device = device )
725
725
td = TensorDict (
726
726
{
@@ -764,7 +764,7 @@ def test_totensorimage(self, keys, batch, device):
764
764
@pytest .mark .parametrize ("device" , get_available_devices ())
765
765
def test_compose (self , keys , batch , device , nchannels = 1 , N = 4 ):
766
766
torch .manual_seed (0 )
767
- t1 = CatFrames (keys_in = keys , N = 4 )
767
+ t1 = CatFrames (in_keys = keys , N = 4 )
768
768
t2 = FiniteTensorDictCheck ()
769
769
compose = Compose (t1 , t2 )
770
770
dont_touch = torch .randn (* batch , nchannels , 16 , 16 , device = device )
@@ -818,8 +818,8 @@ def test_compose_inv(self, keys_inv_1, keys_inv_2, device):
818
818
torch .manual_seed (0 )
819
819
keys_to_transform = set (keys_inv_1 + keys_inv_2 )
820
820
keys_total = set (["action_1" , "action_2" , "dont_touch" ])
821
- double2float_1 = DoubleToFloat (keys_inv_in = keys_inv_1 )
822
- double2float_2 = DoubleToFloat (keys_inv_in = keys_inv_2 )
821
+ double2float_1 = DoubleToFloat (in_keys_inv = keys_inv_1 )
822
+ double2float_2 = DoubleToFloat (in_keys_inv = keys_inv_2 )
823
823
compose = Compose (double2float_1 , double2float_2 )
824
824
td = TensorDict (
825
825
{
@@ -861,7 +861,7 @@ def test_observationnorm(
861
861
loc = loc .to (device )
862
862
if isinstance (scale , Tensor ):
863
863
scale = scale .to (device )
864
- on = ObservationNorm (loc , scale , keys_in = keys , standard_normal = standard_normal )
864
+ on = ObservationNorm (loc , scale , in_keys = keys , standard_normal = standard_normal )
865
865
dont_touch = torch .randn (1 , nchannels , 16 , 16 , device = device )
866
866
td = TensorDict (
867
867
{key : torch .zeros (1 , nchannels , 16 , 16 , device = device ) for key in keys }, [1 ]
@@ -910,7 +910,7 @@ def test_catframes_transform_observation_spec(self):
910
910
key1 = "first key"
911
911
key2 = "second key"
912
912
keys = [key1 , key2 ]
913
- cat_frames = CatFrames (N = N , keys_in = keys )
913
+ cat_frames = CatFrames (N = N , in_keys = keys )
914
914
mins = [0 , 0.5 ]
915
915
maxes = [0.5 , 1 ]
916
916
observation_spec = CompositeSpec (
@@ -953,7 +953,7 @@ def test_catframes_buffer_check_latest_frame(self, device):
953
953
key2_tensor = torch .ones (1 , 1 , 3 , 3 , device = device )
954
954
key_tensors = [key1_tensor , key2_tensor ]
955
955
td = TensorDict (dict (zip (keys , key_tensors )), [1 ], device = device )
956
- cat_frames = CatFrames (N = N , keys_in = keys )
956
+ cat_frames = CatFrames (N = N , in_keys = keys )
957
957
958
958
cat_frames (td )
959
959
latest_frame = td .get (key2 )
@@ -973,7 +973,7 @@ def test_catframes_reset(self, device):
973
973
key2_tensor = torch .ones (1 , 1 , 3 , 3 , device = device )
974
974
key_tensors = [key1_tensor , key2_tensor ]
975
975
td = TensorDict (dict (zip (keys , key_tensors )), [1 ], device = device )
976
- cat_frames = CatFrames (N = N , keys_in = keys )
976
+ cat_frames = CatFrames (N = N , in_keys = keys )
977
977
978
978
cat_frames (td )
979
979
buffer_length1 = len (cat_frames .buffer )
@@ -1014,7 +1014,7 @@ def test_finitetensordictcheck(self, device):
1014
1014
def test_double2float (self , keys , keys_inv , device ):
1015
1015
torch .manual_seed (0 )
1016
1016
keys_total = set (keys + keys_inv )
1017
- double2float = DoubleToFloat (keys_in = keys , keys_inv_in = keys_inv )
1017
+ double2float = DoubleToFloat (in_keys = keys , in_keys_inv = keys_inv )
1018
1018
dont_touch = torch .randn (1 , 3 , 3 , dtype = torch .double , device = device )
1019
1019
td = TensorDict (
1020
1020
{
@@ -1066,7 +1066,7 @@ def test_double2float(self, keys, keys_inv, device):
1066
1066
],
1067
1067
)
1068
1068
def test_cattensors (self , keys , device ):
1069
- cattensors = CatTensors (keys_in = keys , out_key = "observation_out" , dim = - 2 )
1069
+ cattensors = CatTensors (in_keys = keys , out_key = "observation_out" , dim = - 2 )
1070
1070
1071
1071
dont_touch = torch .randn (1 , 3 , 3 , dtype = torch .double , device = device )
1072
1072
td = TensorDict (
@@ -1235,7 +1235,7 @@ def test_reward_scaling(self, batch, scale, loc, keys, device):
1235
1235
keys_total = set ([])
1236
1236
else :
1237
1237
keys_total = set (keys )
1238
- reward_scaling = RewardScaling (keys_in = keys , scale = scale , loc = loc )
1238
+ reward_scaling = RewardScaling (in_keys = keys , scale = scale , loc = loc )
1239
1239
td = TensorDict (
1240
1240
{
1241
1241
** {key : torch .randn (* batch , 1 , device = device ) for key in keys_total },
@@ -1276,7 +1276,7 @@ def test_append(self):
1276
1276
key = list (obs_spec .keys ())[0 ]
1277
1277
1278
1278
env = TransformedEnv (env )
1279
- env .append_transform (CatFrames (N = 4 , cat_dim = - 1 , keys_in = [key ]))
1279
+ env .append_transform (CatFrames (N = 4 , cat_dim = - 1 , in_keys = [key ]))
1280
1280
assert isinstance (env .transform , Compose )
1281
1281
assert len (env .transform ) == 1
1282
1282
obs_spec = env .observation_spec
@@ -1301,7 +1301,7 @@ def test_insert(self):
1301
1301
assert env ._observation_spec is not None
1302
1302
assert env ._reward_spec is not None
1303
1303
1304
- env .insert_transform (0 , CatFrames (N = 4 , cat_dim = - 1 , keys_in = [key ]))
1304
+ env .insert_transform (0 , CatFrames (N = 4 , cat_dim = - 1 , in_keys = [key ]))
1305
1305
1306
1306
# transformed envs do not have spec after insert -- they need to be computed
1307
1307
assert env ._input_spec is None
@@ -1348,7 +1348,7 @@ def test_insert(self):
1348
1348
assert env ._observation_spec is None
1349
1349
assert env ._reward_spec is None
1350
1350
1351
- env .insert_transform (- 5 , CatFrames (N = 4 , cat_dim = - 1 , keys_in = [key ]))
1351
+ env .insert_transform (- 5 , CatFrames (N = 4 , cat_dim = - 1 , in_keys = [key ]))
1352
1352
assert isinstance (env .transform , Compose )
1353
1353
assert len (env .transform ) == 6
1354
1354
@@ -1411,7 +1411,7 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device):
1411
1411
keys_out = ["next_vec" ]
1412
1412
r3m = R3MTransform (
1413
1413
model ,
1414
- keys_in = keys_in ,
1414
+ in_keys = keys_in ,
1415
1415
keys_out = keys_out ,
1416
1416
tensor_pixels_keys = tensor_pixels_key ,
1417
1417
)
@@ -1442,7 +1442,7 @@ def test_r3m_mult_images(self, model, device, stack_images, parallel):
1442
1442
keys_out = ["next_vec" ] if stack_images else ["next_vec" , "next_vec2" ]
1443
1443
r3m = R3MTransform (
1444
1444
model ,
1445
- keys_in = keys_in ,
1445
+ in_keys = keys_in ,
1446
1446
keys_out = keys_out ,
1447
1447
stack_images = stack_images ,
1448
1448
)
@@ -1492,7 +1492,7 @@ def test_r3m_parallel(self, model, device):
1492
1492
tensor_pixels_key = None
1493
1493
r3m = R3MTransform (
1494
1494
model ,
1495
- keys_in = keys_in ,
1495
+ in_keys = keys_in ,
1496
1496
keys_out = keys_out ,
1497
1497
tensor_pixels_keys = tensor_pixels_key ,
1498
1498
)
@@ -1566,7 +1566,7 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
1566
1566
keys_out = ["next_vec" ]
1567
1567
r3m = R3MTransform (
1568
1568
model ,
1569
- keys_in = keys_in ,
1569
+ in_keys = keys_in ,
1570
1570
keys_out = keys_out ,
1571
1571
tensor_pixels_keys = tensor_pixels_key ,
1572
1572
)
@@ -1592,7 +1592,7 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device):
1592
1592
keys_out = ["next_vec" ]
1593
1593
vip = VIPTransform (
1594
1594
model ,
1595
- keys_in = keys_in ,
1595
+ in_keys = keys_in ,
1596
1596
keys_out = keys_out ,
1597
1597
tensor_pixels_keys = tensor_pixels_key ,
1598
1598
)
@@ -1617,7 +1617,7 @@ def test_vip_mult_images(self, model, device, stack_images, parallel):
1617
1617
keys_out = ["next_vec" ] if stack_images else ["next_vec" , "next_vec2" ]
1618
1618
vip = VIPTransform (
1619
1619
model ,
1620
- keys_in = keys_in ,
1620
+ in_keys = keys_in ,
1621
1621
keys_out = keys_out ,
1622
1622
stack_images = stack_images ,
1623
1623
)
@@ -1667,7 +1667,7 @@ def test_vip_parallel(self, model, device):
1667
1667
tensor_pixels_key = None
1668
1668
vip = VIPTransform (
1669
1669
model ,
1670
- keys_in = keys_in ,
1670
+ in_keys = keys_in ,
1671
1671
keys_out = keys_out ,
1672
1672
tensor_pixels_keys = tensor_pixels_key ,
1673
1673
)
@@ -1741,7 +1741,7 @@ def test_vip_spec_against_real(self, model, tensor_pixels_key, device):
1741
1741
keys_out = ["next_vec" ]
1742
1742
vip = VIPTransform (
1743
1743
model ,
1744
- keys_in = keys_in ,
1744
+ in_keys = keys_in ,
1745
1745
keys_out = keys_out ,
1746
1746
tensor_pixels_keys = tensor_pixels_key ,
1747
1747
)
@@ -1762,7 +1762,7 @@ def test_batch_locked_transformed(device):
1762
1762
env = TransformedEnv (
1763
1763
MockBatchedLockedEnv (device ),
1764
1764
Compose (
1765
- ObservationNorm (keys_in = ["next_observation" ], loc = 0.5 , scale = 1.1 ),
1765
+ ObservationNorm (in_keys = ["next_observation" ], loc = 0.5 , scale = 1.1 ),
1766
1766
RewardClipping (0 , 0.1 ),
1767
1767
),
1768
1768
)
@@ -1786,7 +1786,7 @@ def test_batch_unlocked_transformed(device):
1786
1786
env = TransformedEnv (
1787
1787
MockBatchedUnLockedEnv (device ),
1788
1788
Compose (
1789
- ObservationNorm (keys_in = ["next_observation" ], loc = 0.5 , scale = 1.1 ),
1789
+ ObservationNorm (in_keys = ["next_observation" ], loc = 0.5 , scale = 1.1 ),
1790
1790
RewardClipping (0 , 0.1 ),
1791
1791
),
1792
1792
)
@@ -1806,7 +1806,7 @@ def test_batch_unlocked_with_batch_size_transformed(device):
1806
1806
env = TransformedEnv (
1807
1807
MockBatchedUnLockedEnv (device , batch_size = torch .Size ([2 ])),
1808
1808
Compose (
1809
- ObservationNorm (keys_in = ["next_observation" ], loc = 0.5 , scale = 1.1 ),
1809
+ ObservationNorm (in_keys = ["next_observation" ], loc = 0.5 , scale = 1.1 ),
1810
1810
RewardClipping (0 , 0.1 ),
1811
1811
),
1812
1812
)
0 commit comments