@@ -1407,12 +1407,12 @@ def test_insert(self):
1407
1407
class TestR3M :
1408
1408
@pytest .mark .parametrize ("tensor_pixels_key" , [None , ["funny_key" ]])
1409
1409
def test_r3m_instantiation (self , model , tensor_pixels_key , device ):
1410
- keys_in = ["next_pixels" ]
1411
- keys_out = ["next_vec" ]
1410
+ in_keys = ["next_pixels" ]
1411
+ out_keys = ["next_vec" ]
1412
1412
r3m = R3MTransform (
1413
1413
model ,
1414
- in_keys = keys_in ,
1415
- keys_out = keys_out ,
1414
+ in_keys = in_keys ,
1415
+ out_keys = out_keys ,
1416
1416
tensor_pixels_keys = tensor_pixels_key ,
1417
1417
)
1418
1418
base_env = DiscreteActionConvMockEnvNumpy ().to (device )
@@ -1438,12 +1438,12 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device):
1438
1438
],
1439
1439
)
1440
1440
def test_r3m_mult_images (self , model , device , stack_images , parallel ):
1441
- keys_in = ["next_pixels" , "next_pixels2" ]
1442
- keys_out = ["next_vec" ] if stack_images else ["next_vec" , "next_vec2" ]
1441
+ in_keys = ["next_pixels" , "next_pixels2" ]
1442
+ out_keys = ["next_vec" ] if stack_images else ["next_vec" , "next_vec2" ]
1443
1443
r3m = R3MTransform (
1444
1444
model ,
1445
- in_keys = keys_in ,
1446
- keys_out = keys_out ,
1445
+ in_keys = in_keys ,
1446
+ out_keys = out_keys ,
1447
1447
stack_images = stack_images ,
1448
1448
)
1449
1449
@@ -1487,13 +1487,13 @@ def base_env_constructor():
1487
1487
transformed_env .close ()
1488
1488
1489
1489
def test_r3m_parallel (self , model , device ):
1490
- keys_in = ["next_pixels" ]
1491
- keys_out = ["next_vec" ]
1490
+ in_keys = ["next_pixels" ]
1491
+ out_keys = ["next_vec" ]
1492
1492
tensor_pixels_key = None
1493
1493
r3m = R3MTransform (
1494
1494
model ,
1495
- in_keys = keys_in ,
1496
- keys_out = keys_out ,
1495
+ in_keys = in_keys ,
1496
+ out_keys = out_keys ,
1497
1497
tensor_pixels_keys = tensor_pixels_key ,
1498
1498
)
1499
1499
base_env = ParallelEnv (4 , lambda : DiscreteActionConvMockEnvNumpy ().to (device ))
@@ -1562,12 +1562,12 @@ def test_r3mnet_transform_observation_spec(
1562
1562
1563
1563
@pytest .mark .parametrize ("tensor_pixels_key" , [None , ["funny_key" ]])
1564
1564
def test_r3m_spec_against_real (self , model , tensor_pixels_key , device ):
1565
- keys_in = ["next_pixels" ]
1566
- keys_out = ["next_vec" ]
1565
+ in_keys = ["next_pixels" ]
1566
+ out_keys = ["next_vec" ]
1567
1567
r3m = R3MTransform (
1568
1568
model ,
1569
- in_keys = keys_in ,
1570
- keys_out = keys_out ,
1569
+ in_keys = in_keys ,
1570
+ out_keys = out_keys ,
1571
1571
tensor_pixels_keys = tensor_pixels_key ,
1572
1572
)
1573
1573
base_env = DiscreteActionConvMockEnvNumpy ().to (device )
@@ -1588,12 +1588,12 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
1588
1588
class TestVIP :
1589
1589
@pytest .mark .parametrize ("tensor_pixels_key" , [None , ["funny_key" ]])
1590
1590
def test_vip_instantiation (self , model , tensor_pixels_key , device ):
1591
- keys_in = ["next_pixels" ]
1592
- keys_out = ["next_vec" ]
1591
+ in_keys = ["next_pixels" ]
1592
+ out_keys = ["next_vec" ]
1593
1593
vip = VIPTransform (
1594
1594
model ,
1595
- in_keys = keys_in ,
1596
- keys_out = keys_out ,
1595
+ in_keys = in_keys ,
1596
+ out_keys = out_keys ,
1597
1597
tensor_pixels_keys = tensor_pixels_key ,
1598
1598
)
1599
1599
base_env = DiscreteActionConvMockEnvNumpy ().to (device )
@@ -1613,12 +1613,12 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device):
1613
1613
@pytest .mark .parametrize ("stack_images" , [True , False ])
1614
1614
@pytest .mark .parametrize ("parallel" , [True , False ])
1615
1615
def test_vip_mult_images (self , model , device , stack_images , parallel ):
1616
- keys_in = ["next_pixels" , "next_pixels2" ]
1617
- keys_out = ["next_vec" ] if stack_images else ["next_vec" , "next_vec2" ]
1616
+ in_keys = ["next_pixels" , "next_pixels2" ]
1617
+ out_keys = ["next_vec" ] if stack_images else ["next_vec" , "next_vec2" ]
1618
1618
vip = VIPTransform (
1619
1619
model ,
1620
- in_keys = keys_in ,
1621
- keys_out = keys_out ,
1620
+ in_keys = in_keys ,
1621
+ out_keys = out_keys ,
1622
1622
stack_images = stack_images ,
1623
1623
)
1624
1624
@@ -1662,13 +1662,13 @@ def base_env_constructor():
1662
1662
transformed_env .close ()
1663
1663
1664
1664
def test_vip_parallel (self , model , device ):
1665
- keys_in = ["next_pixels" ]
1666
- keys_out = ["next_vec" ]
1665
+ in_keys = ["next_pixels" ]
1666
+ out_keys = ["next_vec" ]
1667
1667
tensor_pixels_key = None
1668
1668
vip = VIPTransform (
1669
1669
model ,
1670
- in_keys = keys_in ,
1671
- keys_out = keys_out ,
1670
+ in_keys = in_keys ,
1671
+ out_keys = out_keys ,
1672
1672
tensor_pixels_keys = tensor_pixels_key ,
1673
1673
)
1674
1674
base_env = ParallelEnv (4 , lambda : DiscreteActionConvMockEnvNumpy ().to (device ))
@@ -1688,13 +1688,13 @@ def test_vip_parallel(self, model, device):
1688
1688
del transformed_env
1689
1689
1690
1690
def test_vip_parallel_reward (self , model , device ):
1691
- keys_in = ["next_pixels" ]
1692
- keys_out = ["next_vec" ]
1691
+ in_keys = ["next_pixels" ]
1692
+ out_keys = ["next_vec" ]
1693
1693
tensor_pixels_key = None
1694
1694
vip = VIPRewardTransform (
1695
1695
model ,
1696
- keys_in = keys_in ,
1697
- keys_out = keys_out ,
1696
+ in_keys = in_keys ,
1697
+ out_keys = out_keys ,
1698
1698
tensor_pixels_keys = tensor_pixels_key ,
1699
1699
)
1700
1700
base_env = ParallelEnv (4 , lambda : DiscreteActionConvMockEnvNumpy ().to (device ))
@@ -1802,12 +1802,12 @@ def test_vipnet_transform_observation_spec(
1802
1802
1803
1803
@pytest .mark .parametrize ("tensor_pixels_key" , [None , ["funny_key" ]])
1804
1804
def test_vip_spec_against_real (self , model , tensor_pixels_key , device ):
1805
- keys_in = ["next_pixels" ]
1806
- keys_out = ["next_vec" ]
1805
+ in_keys = ["next_pixels" ]
1806
+ out_keys = ["next_vec" ]
1807
1807
vip = VIPTransform (
1808
1808
model ,
1809
- in_keys = keys_in ,
1810
- keys_out = keys_out ,
1809
+ in_keys = in_keys ,
1810
+ out_keys = out_keys ,
1811
1811
tensor_pixels_keys = tensor_pixels_key ,
1812
1812
)
1813
1813
base_env = DiscreteActionConvMockEnvNumpy ().to (device )
0 commit comments