Skip to content

Commit ecedcf1

Browse files
authored
[Naming] Fixing key names (#668)
* fixing key names * fixing key names
1 parent e803ffa commit ecedcf1

File tree

9 files changed

+168
-168
lines changed

9 files changed

+168
-168
lines changed

test/test_transforms.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,12 +1407,12 @@ def test_insert(self):
14071407
class TestR3M:
14081408
@pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]])
14091409
def test_r3m_instantiation(self, model, tensor_pixels_key, device):
1410-
keys_in = ["next_pixels"]
1411-
keys_out = ["next_vec"]
1410+
in_keys = ["next_pixels"]
1411+
out_keys = ["next_vec"]
14121412
r3m = R3MTransform(
14131413
model,
1414-
in_keys=keys_in,
1415-
keys_out=keys_out,
1414+
in_keys=in_keys,
1415+
out_keys=out_keys,
14161416
tensor_pixels_keys=tensor_pixels_key,
14171417
)
14181418
base_env = DiscreteActionConvMockEnvNumpy().to(device)
@@ -1438,12 +1438,12 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device):
14381438
],
14391439
)
14401440
def test_r3m_mult_images(self, model, device, stack_images, parallel):
1441-
keys_in = ["next_pixels", "next_pixels2"]
1442-
keys_out = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
1441+
in_keys = ["next_pixels", "next_pixels2"]
1442+
out_keys = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
14431443
r3m = R3MTransform(
14441444
model,
1445-
in_keys=keys_in,
1446-
keys_out=keys_out,
1445+
in_keys=in_keys,
1446+
out_keys=out_keys,
14471447
stack_images=stack_images,
14481448
)
14491449

@@ -1487,13 +1487,13 @@ def base_env_constructor():
14871487
transformed_env.close()
14881488

14891489
def test_r3m_parallel(self, model, device):
1490-
keys_in = ["next_pixels"]
1491-
keys_out = ["next_vec"]
1490+
in_keys = ["next_pixels"]
1491+
out_keys = ["next_vec"]
14921492
tensor_pixels_key = None
14931493
r3m = R3MTransform(
14941494
model,
1495-
in_keys=keys_in,
1496-
keys_out=keys_out,
1495+
in_keys=in_keys,
1496+
out_keys=out_keys,
14971497
tensor_pixels_keys=tensor_pixels_key,
14981498
)
14991499
base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device))
@@ -1562,12 +1562,12 @@ def test_r3mnet_transform_observation_spec(
15621562

15631563
@pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]])
15641564
def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
1565-
keys_in = ["next_pixels"]
1566-
keys_out = ["next_vec"]
1565+
in_keys = ["next_pixels"]
1566+
out_keys = ["next_vec"]
15671567
r3m = R3MTransform(
15681568
model,
1569-
in_keys=keys_in,
1570-
keys_out=keys_out,
1569+
in_keys=in_keys,
1570+
out_keys=out_keys,
15711571
tensor_pixels_keys=tensor_pixels_key,
15721572
)
15731573
base_env = DiscreteActionConvMockEnvNumpy().to(device)
@@ -1588,12 +1588,12 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
15881588
class TestVIP:
15891589
@pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]])
15901590
def test_vip_instantiation(self, model, tensor_pixels_key, device):
1591-
keys_in = ["next_pixels"]
1592-
keys_out = ["next_vec"]
1591+
in_keys = ["next_pixels"]
1592+
out_keys = ["next_vec"]
15931593
vip = VIPTransform(
15941594
model,
1595-
in_keys=keys_in,
1596-
keys_out=keys_out,
1595+
in_keys=in_keys,
1596+
out_keys=out_keys,
15971597
tensor_pixels_keys=tensor_pixels_key,
15981598
)
15991599
base_env = DiscreteActionConvMockEnvNumpy().to(device)
@@ -1613,12 +1613,12 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device):
16131613
@pytest.mark.parametrize("stack_images", [True, False])
16141614
@pytest.mark.parametrize("parallel", [True, False])
16151615
def test_vip_mult_images(self, model, device, stack_images, parallel):
1616-
keys_in = ["next_pixels", "next_pixels2"]
1617-
keys_out = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
1616+
in_keys = ["next_pixels", "next_pixels2"]
1617+
out_keys = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
16181618
vip = VIPTransform(
16191619
model,
1620-
in_keys=keys_in,
1621-
keys_out=keys_out,
1620+
in_keys=in_keys,
1621+
out_keys=out_keys,
16221622
stack_images=stack_images,
16231623
)
16241624

@@ -1662,13 +1662,13 @@ def base_env_constructor():
16621662
transformed_env.close()
16631663

16641664
def test_vip_parallel(self, model, device):
1665-
keys_in = ["next_pixels"]
1666-
keys_out = ["next_vec"]
1665+
in_keys = ["next_pixels"]
1666+
out_keys = ["next_vec"]
16671667
tensor_pixels_key = None
16681668
vip = VIPTransform(
16691669
model,
1670-
in_keys=keys_in,
1671-
keys_out=keys_out,
1670+
in_keys=in_keys,
1671+
out_keys=out_keys,
16721672
tensor_pixels_keys=tensor_pixels_key,
16731673
)
16741674
base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device))
@@ -1688,13 +1688,13 @@ def test_vip_parallel(self, model, device):
16881688
del transformed_env
16891689

16901690
def test_vip_parallel_reward(self, model, device):
1691-
keys_in = ["next_pixels"]
1692-
keys_out = ["next_vec"]
1691+
in_keys = ["next_pixels"]
1692+
out_keys = ["next_vec"]
16931693
tensor_pixels_key = None
16941694
vip = VIPRewardTransform(
16951695
model,
1696-
keys_in=keys_in,
1697-
keys_out=keys_out,
1696+
in_keys=in_keys,
1697+
out_keys=out_keys,
16981698
tensor_pixels_keys=tensor_pixels_key,
16991699
)
17001700
base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device))
@@ -1802,12 +1802,12 @@ def test_vipnet_transform_observation_spec(
18021802

18031803
@pytest.mark.parametrize("tensor_pixels_key", [None, ["funny_key"]])
18041804
def test_vip_spec_against_real(self, model, tensor_pixels_key, device):
1805-
keys_in = ["next_pixels"]
1806-
keys_out = ["next_vec"]
1805+
in_keys = ["next_pixels"]
1806+
out_keys = ["next_vec"]
18071807
vip = VIPTransform(
18081808
model,
1809-
in_keys=keys_in,
1810-
keys_out=keys_out,
1809+
in_keys=in_keys,
1810+
out_keys=out_keys,
18111811
tensor_pixels_keys=tensor_pixels_key,
18121812
)
18131813
base_env = DiscreteActionConvMockEnvNumpy().to(device)

torchrl/envs/transforms/r3m.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
6161
f"model {model_name} is currently not supported by R3M"
6262
)
6363
convnet.fc = Identity()
64-
super().__init__(in_keys=in_keys, keys_out=out_keys)
64+
super().__init__(in_keys=in_keys, out_keys=out_keys)
6565
self.convnet = convnet
6666
self.del_keys = del_keys
6767

@@ -92,11 +92,11 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
9292

9393
observation_spec = CompositeSpec(**observation_spec)
9494
if self.del_keys:
95-
for key_in in keys:
96-
del observation_spec[key_in]
95+
for in_key in keys:
96+
del observation_spec[in_key]
9797

98-
for key_out in self.keys_out:
99-
observation_spec[key_out] = NdUnboundedContinuousTensorSpec(
98+
for out_key in self.out_keys:
99+
observation_spec[out_key] = NdUnboundedContinuousTensorSpec(
100100
shape=torch.Size([self.outdim]), device=device
101101
)
102102

@@ -161,7 +161,7 @@ class R3MTransform(Compose):
161161
model_name (str): one of resnet50, resnet34 or resnet18
162162
in_keys (list of str, optional): list of input keys. If left empty, the
163163
"next_pixels" key is assumed.
164-
keys_out (list of str, optional): list of output keys. If left empty,
164+
out_keys (list of str, optional): list of output keys. If left empty,
165165
"next_r3m_vec" is assumed.
166166
size (int, optional): Size of the image to feed to resnet.
167167
Defaults to 244.
@@ -187,105 +187,105 @@ def __init__(
187187
self,
188188
model_name: str,
189189
in_keys: List[str] = None,
190-
keys_out: List[str] = None,
190+
out_keys: List[str] = None,
191191
size: int = 244,
192192
stack_images: bool = True,
193193
download: bool = False,
194194
download_path: Optional[str] = None,
195195
tensor_pixels_keys: List[str] = None,
196196
):
197197
super().__init__()
198-
self.keys_in = in_keys
198+
self.in_keys = in_keys
199199
self.download = download
200200
self.download_path = download_path
201201
self.model_name = model_name
202-
self.keys_out = keys_out
202+
self.out_keys = out_keys
203203
self.size = size
204204
self.stack_images = stack_images
205205
self.tensor_pixels_keys = tensor_pixels_keys
206206

207207
def _init(self):
208-
keys_in = self.keys_in
208+
in_keys = self.in_keys
209209
model_name = self.model_name
210-
keys_out = self.keys_out
210+
out_keys = self.out_keys
211211
size = self.size
212212
stack_images = self.stack_images
213213
tensor_pixels_keys = self.tensor_pixels_keys
214214

215215
# ToTensor
216216
transforms = []
217217
if tensor_pixels_keys:
218-
for i in range(len(keys_in)):
218+
for i in range(len(in_keys)):
219219
transforms.append(
220220
CatTensors(
221-
in_keys=[keys_in[i]],
221+
in_keys=[in_keys[i]],
222222
out_key=tensor_pixels_keys[i],
223223
del_keys=False,
224224
)
225225
)
226226

227227
totensor = ToTensorImage(
228228
unsqueeze=False,
229-
in_keys=keys_in,
229+
in_keys=in_keys,
230230
)
231231
transforms.append(totensor)
232232

233233
# Normalize
234234
mean = [0.485, 0.456, 0.406]
235235
std = [0.229, 0.224, 0.225]
236236
normalize = ObservationNorm(
237-
in_keys=keys_in,
237+
in_keys=in_keys,
238238
loc=torch.tensor(mean).view(3, 1, 1),
239239
scale=torch.tensor(std).view(3, 1, 1),
240240
standard_normal=True,
241241
)
242242
transforms.append(normalize)
243243

244244
# Resize: note that resize is a no-op if the tensor has the desired size already
245-
resize = Resize(size, size, in_keys=keys_in)
245+
resize = Resize(size, size, in_keys=in_keys)
246246
transforms.append(resize)
247247

248248
# R3M
249-
if keys_out is None:
249+
if out_keys is None:
250250
if stack_images:
251-
keys_out = ["next_r3m_vec"]
251+
out_keys = ["next_r3m_vec"]
252252
else:
253-
keys_out = [f"next_r3m_vec_{i}" for i in range(len(keys_in))]
254-
elif stack_images and len(keys_out) != 1:
253+
out_keys = [f"next_r3m_vec_{i}" for i in range(len(in_keys))]
254+
elif stack_images and len(out_keys) != 1:
255255
raise ValueError(
256-
f"key_out must be of length 1 if stack_images is True. Got keys_out={keys_out}"
256+
f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}"
257257
)
258-
elif not stack_images and len(keys_out) != len(keys_in):
258+
elif not stack_images and len(out_keys) != len(in_keys):
259259
raise ValueError(
260-
"key_out must be of length equal to in_keys if stack_images is False."
260+
"out_key must be of length equal to in_keys if stack_images is False."
261261
)
262262

263-
if stack_images and len(keys_in) > 1:
263+
if stack_images and len(in_keys) > 1:
264264
if self.is_3d:
265265
unsqueeze = UnsqueezeTransform(
266-
in_keys=keys_in,
267-
keys_out=keys_in,
266+
in_keys=in_keys,
267+
out_keys=in_keys,
268268
unsqueeze_dim=-4,
269269
)
270270
transforms.append(unsqueeze)
271271

272272
cattensors = CatTensors(
273-
keys_in,
274-
keys_out[0],
273+
in_keys,
274+
out_keys[0],
275275
dim=-4,
276276
)
277277
network = _R3MNet(
278-
in_keys=keys_out,
279-
out_keys=keys_out,
278+
in_keys=out_keys,
279+
out_keys=out_keys,
280280
model_name=model_name,
281281
del_keys=False,
282282
)
283-
flatten = FlattenObservation(-2, -1, keys_out)
283+
flatten = FlattenObservation(-2, -1, out_keys)
284284
transforms = [*transforms, cattensors, network, flatten]
285285
else:
286286
network = _R3MNet(
287-
in_keys=keys_in,
288-
out_keys=keys_out,
287+
in_keys=in_keys,
288+
out_keys=out_keys,
289289
model_name=model_name,
290290
del_keys=True,
291291
)

0 commit comments

Comments
 (0)