Skip to content

Commit 6196a95

Browse files
authored
[Naming] Rename keys_in to in_keys in transforms.py and related modules (#656)
* Rename `keys_in` to `in_keys`
* Update formatting via the pre-commit command
1 parent d9b6ed9 commit 6196a95

File tree

8 files changed

+219
-219
lines changed

8 files changed

+219
-219
lines changed

examples/dreamer/dreamer_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,13 @@ def make_env_transforms(
9393
if cfg.grayscale:
9494
env.append_transform(GrayScale())
9595
env.append_transform(FlattenObservation())
96-
env.append_transform(CatFrames(N=cfg.catframes, keys_in=["next_pixels"]))
96+
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["next_pixels"]))
9797
if stats is None:
9898
obs_stats = {"loc": 0.0, "scale": 1.0}
9999
else:
100100
obs_stats = stats
101101
obs_stats["standard_normal"] = True
102-
env.append_transform(ObservationNorm(**obs_stats, keys_in=["next_pixels"]))
102+
env.append_transform(ObservationNorm(**obs_stats, in_keys=["next_pixels"]))
103103
if norm_rewards:
104104
reward_scaling = 1.0
105105
reward_loc = 0.0
@@ -118,7 +118,7 @@ def make_env_transforms(
118118
]
119119
float_to_double_list += ["action"] # DMControl requires double-precision
120120
env.append_transform(
121-
DoubleToFloat(keys_in=double_to_float_list, keys_inv_in=float_to_double_list)
121+
DoubleToFloat(in_keys=double_to_float_list, in_keys_inv=float_to_double_list)
122122
)
123123

124124
default_dict = {

test/test_env.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def create_env_fn():
254254
GymEnv(env_name, frame_skip=frame_skip, device=device),
255255
Compose(
256256
ObservationNorm(
257-
keys_in=["next_observation"], loc=0.5, scale=1.1
257+
in_keys=["next_observation"], loc=0.5, scale=1.1
258258
),
259259
RewardClipping(0, 0.1),
260260
),
@@ -275,7 +275,7 @@ def t_out():
275275
Compose(*[ToTensorImage(), RewardClipping(0, 0.1)])
276276
if not transformed_in
277277
else Compose(
278-
*[ObservationNorm(keys_in=["next_pixels"], loc=0, scale=1)]
278+
*[ObservationNorm(in_keys=["next_pixels"], loc=0, scale=1)]
279279
)
280280
)
281281

@@ -297,14 +297,14 @@ def t_out():
297297
return (
298298
Compose(
299299
ObservationNorm(
300-
keys_in=["next_observation"], loc=0.5, scale=1.1
300+
in_keys=["next_observation"], loc=0.5, scale=1.1
301301
),
302302
RewardClipping(0, 0.1),
303303
)
304304
if not transformed_in
305305
else Compose(
306306
ObservationNorm(
307-
keys_in=["next_observation"], loc=1.0, scale=1.0
307+
in_keys=["next_observation"], loc=1.0, scale=1.0
308308
)
309309
)
310310
)
@@ -458,8 +458,8 @@ def env1_maker():
458458
CatTensors(env1_obs_keys, "next_observation_stand", del_keys=False),
459459
CatTensors(env1_obs_keys, "next_observation"),
460460
DoubleToFloat(
461-
keys_in=["next_observation_stand", "next_observation"],
462-
keys_inv_in=["action"],
461+
in_keys=["next_observation_stand", "next_observation"],
462+
in_keys_inv=["action"],
463463
),
464464
),
465465
)
@@ -471,8 +471,8 @@ def env2_maker():
471471
CatTensors(env2_obs_keys, "next_observation_walk", del_keys=False),
472472
CatTensors(env2_obs_keys, "next_observation"),
473473
DoubleToFloat(
474-
keys_in=["next_observation_walk", "next_observation"],
475-
keys_inv_in=["action"],
474+
in_keys=["next_observation_walk", "next_observation"],
475+
in_keys_inv=["action"],
476476
),
477477
),
478478
)

test/test_transforms.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ class TestTransforms:
406406
def test_resize(self, interpolation, keys, nchannels, batch, device):
407407
torch.manual_seed(0)
408408
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
409-
resize = Resize(w=20, h=21, interpolation=interpolation, keys_in=keys)
409+
resize = Resize(w=20, h=21, interpolation=interpolation, in_keys=keys)
410410
td = TensorDict(
411411
{
412412
key: torch.randn(*batch, nchannels, 16, 16, device=device)
@@ -444,7 +444,7 @@ def test_resize(self, interpolation, keys, nchannels, batch, device):
444444
def test_centercrop(self, keys, h, nchannels, batch, device):
445445
torch.manual_seed(0)
446446
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
447-
cc = CenterCrop(w=20, h=h, keys_in=keys)
447+
cc = CenterCrop(w=20, h=h, in_keys=keys)
448448
if h is None:
449449
h = 20
450450
td = TensorDict(
@@ -485,7 +485,7 @@ def test_flatten(self, keys, size, nchannels, batch, device):
485485
torch.manual_seed(0)
486486
dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device)
487487
start_dim = -3 - len(size)
488-
flatten = FlattenObservation(start_dim, -3, keys_in=keys)
488+
flatten = FlattenObservation(start_dim, -3, in_keys=keys)
489489
td = TensorDict(
490490
{
491491
key: torch.randn(*batch, *size, nchannels, 16, 16, device=device)
@@ -527,7 +527,7 @@ def test_flatten(self, keys, size, nchannels, batch, device):
527527
def test_unsqueeze(self, keys, size, nchannels, batch, device, unsqueeze_dim):
528528
torch.manual_seed(0)
529529
dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device)
530-
unsqueeze = UnsqueezeTransform(unsqueeze_dim, keys_in=keys)
530+
unsqueeze = UnsqueezeTransform(unsqueeze_dim, in_keys=keys)
531531
td = TensorDict(
532532
{
533533
key: torch.randn(*batch, *size, nchannels, 16, 16, device=device)
@@ -586,7 +586,7 @@ def test_unsqueeze_inv(
586586
torch.manual_seed(0)
587587
keys_total = set(keys + keys_inv)
588588
unsqueeze = UnsqueezeTransform(
589-
unsqueeze_dim, keys_in=keys, keys_inv_in=keys_inv
589+
unsqueeze_dim, in_keys=keys, in_keys_inv=keys_inv
590590
)
591591
td = TensorDict(
592592
{
@@ -621,7 +621,7 @@ def test_unsqueeze_inv(
621621
def test_squeeze(self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim):
622622
torch.manual_seed(0)
623623
keys_total = set(keys + keys_inv)
624-
squeeze = SqueezeTransform(squeeze_dim, keys_in=keys, keys_inv_in=keys_inv)
624+
squeeze = SqueezeTransform(squeeze_dim, in_keys=keys, in_keys_inv=keys_inv)
625625
td = TensorDict(
626626
{
627627
key: torch.randn(*batch, *size, nchannels, 16, 16, device=device)
@@ -656,7 +656,7 @@ def test_squeeze_inv(
656656
):
657657
torch.manual_seed(0)
658658
keys_total = set(keys + keys_inv)
659-
squeeze = SqueezeTransform(squeeze_dim, keys_in=keys, keys_inv_in=keys_inv)
659+
squeeze = SqueezeTransform(squeeze_dim, in_keys=keys, in_keys_inv=keys_inv)
660660
td = TensorDict(
661661
{
662662
key: torch.randn(*batch, *size, nchannels, 16, 16, device=device)
@@ -687,7 +687,7 @@ def test_squeeze_inv(
687687
def test_grayscale(self, keys, device):
688688
torch.manual_seed(0)
689689
nchannels = 3
690-
gs = GrayScale(keys_in=keys)
690+
gs = GrayScale(in_keys=keys)
691691
dont_touch = torch.randn(1, nchannels, 16, 16, device=device)
692692
td = TensorDict(
693693
{key: torch.randn(1, nchannels, 16, 16, device=device) for key in keys},
@@ -720,7 +720,7 @@ def test_grayscale(self, keys, device):
720720
def test_totensorimage(self, keys, batch, device):
721721
torch.manual_seed(0)
722722
nchannels = 3
723-
totensorimage = ToTensorImage(keys_in=keys)
723+
totensorimage = ToTensorImage(in_keys=keys)
724724
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
725725
td = TensorDict(
726726
{
@@ -764,7 +764,7 @@ def test_totensorimage(self, keys, batch, device):
764764
@pytest.mark.parametrize("device", get_available_devices())
765765
def test_compose(self, keys, batch, device, nchannels=1, N=4):
766766
torch.manual_seed(0)
767-
t1 = CatFrames(keys_in=keys, N=4)
767+
t1 = CatFrames(in_keys=keys, N=4)
768768
t2 = FiniteTensorDictCheck()
769769
compose = Compose(t1, t2)
770770
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
@@ -818,8 +818,8 @@ def test_compose_inv(self, keys_inv_1, keys_inv_2, device):
818818
torch.manual_seed(0)
819819
keys_to_transform = set(keys_inv_1 + keys_inv_2)
820820
keys_total = set(["action_1", "action_2", "dont_touch"])
821-
double2float_1 = DoubleToFloat(keys_inv_in=keys_inv_1)
822-
double2float_2 = DoubleToFloat(keys_inv_in=keys_inv_2)
821+
double2float_1 = DoubleToFloat(in_keys_inv=keys_inv_1)
822+
double2float_2 = DoubleToFloat(in_keys_inv=keys_inv_2)
823823
compose = Compose(double2float_1, double2float_2)
824824
td = TensorDict(
825825
{
@@ -861,7 +861,7 @@ def test_observationnorm(
861861
loc = loc.to(device)
862862
if isinstance(scale, Tensor):
863863
scale = scale.to(device)
864-
on = ObservationNorm(loc, scale, keys_in=keys, standard_normal=standard_normal)
864+
on = ObservationNorm(loc, scale, in_keys=keys, standard_normal=standard_normal)
865865
dont_touch = torch.randn(1, nchannels, 16, 16, device=device)
866866
td = TensorDict(
867867
{key: torch.zeros(1, nchannels, 16, 16, device=device) for key in keys}, [1]
@@ -910,7 +910,7 @@ def test_catframes_transform_observation_spec(self):
910910
key1 = "first key"
911911
key2 = "second key"
912912
keys = [key1, key2]
913-
cat_frames = CatFrames(N=N, keys_in=keys)
913+
cat_frames = CatFrames(N=N, in_keys=keys)
914914
mins = [0, 0.5]
915915
maxes = [0.5, 1]
916916
observation_spec = CompositeSpec(
@@ -953,7 +953,7 @@ def test_catframes_buffer_check_latest_frame(self, device):
953953
key2_tensor = torch.ones(1, 1, 3, 3, device=device)
954954
key_tensors = [key1_tensor, key2_tensor]
955955
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
956-
cat_frames = CatFrames(N=N, keys_in=keys)
956+
cat_frames = CatFrames(N=N, in_keys=keys)
957957

958958
cat_frames(td)
959959
latest_frame = td.get(key2)
@@ -973,7 +973,7 @@ def test_catframes_reset(self, device):
973973
key2_tensor = torch.ones(1, 1, 3, 3, device=device)
974974
key_tensors = [key1_tensor, key2_tensor]
975975
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
976-
cat_frames = CatFrames(N=N, keys_in=keys)
976+
cat_frames = CatFrames(N=N, in_keys=keys)
977977

978978
cat_frames(td)
979979
buffer_length1 = len(cat_frames.buffer)
@@ -1014,7 +1014,7 @@ def test_finitetensordictcheck(self, device):
10141014
def test_double2float(self, keys, keys_inv, device):
10151015
torch.manual_seed(0)
10161016
keys_total = set(keys + keys_inv)
1017-
double2float = DoubleToFloat(keys_in=keys, keys_inv_in=keys_inv)
1017+
double2float = DoubleToFloat(in_keys=keys, in_keys_inv=keys_inv)
10181018
dont_touch = torch.randn(1, 3, 3, dtype=torch.double, device=device)
10191019
td = TensorDict(
10201020
{
@@ -1066,7 +1066,7 @@ def test_double2float(self, keys, keys_inv, device):
10661066
],
10671067
)
10681068
def test_cattensors(self, keys, device):
1069-
cattensors = CatTensors(keys_in=keys, out_key="observation_out", dim=-2)
1069+
cattensors = CatTensors(in_keys=keys, out_key="observation_out", dim=-2)
10701070

10711071
dont_touch = torch.randn(1, 3, 3, dtype=torch.double, device=device)
10721072
td = TensorDict(
@@ -1235,7 +1235,7 @@ def test_reward_scaling(self, batch, scale, loc, keys, device):
12351235
keys_total = set([])
12361236
else:
12371237
keys_total = set(keys)
1238-
reward_scaling = RewardScaling(keys_in=keys, scale=scale, loc=loc)
1238+
reward_scaling = RewardScaling(in_keys=keys, scale=scale, loc=loc)
12391239
td = TensorDict(
12401240
{
12411241
**{key: torch.randn(*batch, 1, device=device) for key in keys_total},
@@ -1276,7 +1276,7 @@ def test_append(self):
12761276
key = list(obs_spec.keys())[0]
12771277

12781278
env = TransformedEnv(env)
1279-
env.append_transform(CatFrames(N=4, cat_dim=-1, keys_in=[key]))
1279+
env.append_transform(CatFrames(N=4, cat_dim=-1, in_keys=[key]))
12801280
assert isinstance(env.transform, Compose)
12811281
assert len(env.transform) == 1
12821282
obs_spec = env.observation_spec
@@ -1301,7 +1301,7 @@ def test_insert(self):
13011301
assert env._observation_spec is not None
13021302
assert env._reward_spec is not None
13031303

1304-
env.insert_transform(0, CatFrames(N=4, cat_dim=-1, keys_in=[key]))
1304+
env.insert_transform(0, CatFrames(N=4, cat_dim=-1, in_keys=[key]))
13051305

13061306
# transformed envs do not have spec after insert -- they need to be computed
13071307
assert env._input_spec is None
@@ -1348,7 +1348,7 @@ def test_insert(self):
13481348
assert env._observation_spec is None
13491349
assert env._reward_spec is None
13501350

1351-
env.insert_transform(-5, CatFrames(N=4, cat_dim=-1, keys_in=[key]))
1351+
env.insert_transform(-5, CatFrames(N=4, cat_dim=-1, in_keys=[key]))
13521352
assert isinstance(env.transform, Compose)
13531353
assert len(env.transform) == 6
13541354

@@ -1411,7 +1411,7 @@ def test_r3m_instantiation(self, model, tensor_pixels_key, device):
14111411
keys_out = ["next_vec"]
14121412
r3m = R3MTransform(
14131413
model,
1414-
keys_in=keys_in,
1414+
in_keys=keys_in,
14151415
keys_out=keys_out,
14161416
tensor_pixels_keys=tensor_pixels_key,
14171417
)
@@ -1442,7 +1442,7 @@ def test_r3m_mult_images(self, model, device, stack_images, parallel):
14421442
keys_out = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
14431443
r3m = R3MTransform(
14441444
model,
1445-
keys_in=keys_in,
1445+
in_keys=keys_in,
14461446
keys_out=keys_out,
14471447
stack_images=stack_images,
14481448
)
@@ -1492,7 +1492,7 @@ def test_r3m_parallel(self, model, device):
14921492
tensor_pixels_key = None
14931493
r3m = R3MTransform(
14941494
model,
1495-
keys_in=keys_in,
1495+
in_keys=keys_in,
14961496
keys_out=keys_out,
14971497
tensor_pixels_keys=tensor_pixels_key,
14981498
)
@@ -1566,7 +1566,7 @@ def test_r3m_spec_against_real(self, model, tensor_pixels_key, device):
15661566
keys_out = ["next_vec"]
15671567
r3m = R3MTransform(
15681568
model,
1569-
keys_in=keys_in,
1569+
in_keys=keys_in,
15701570
keys_out=keys_out,
15711571
tensor_pixels_keys=tensor_pixels_key,
15721572
)
@@ -1592,7 +1592,7 @@ def test_vip_instantiation(self, model, tensor_pixels_key, device):
15921592
keys_out = ["next_vec"]
15931593
vip = VIPTransform(
15941594
model,
1595-
keys_in=keys_in,
1595+
in_keys=keys_in,
15961596
keys_out=keys_out,
15971597
tensor_pixels_keys=tensor_pixels_key,
15981598
)
@@ -1617,7 +1617,7 @@ def test_vip_mult_images(self, model, device, stack_images, parallel):
16171617
keys_out = ["next_vec"] if stack_images else ["next_vec", "next_vec2"]
16181618
vip = VIPTransform(
16191619
model,
1620-
keys_in=keys_in,
1620+
in_keys=keys_in,
16211621
keys_out=keys_out,
16221622
stack_images=stack_images,
16231623
)
@@ -1667,7 +1667,7 @@ def test_vip_parallel(self, model, device):
16671667
tensor_pixels_key = None
16681668
vip = VIPTransform(
16691669
model,
1670-
keys_in=keys_in,
1670+
in_keys=keys_in,
16711671
keys_out=keys_out,
16721672
tensor_pixels_keys=tensor_pixels_key,
16731673
)
@@ -1741,7 +1741,7 @@ def test_vip_spec_against_real(self, model, tensor_pixels_key, device):
17411741
keys_out = ["next_vec"]
17421742
vip = VIPTransform(
17431743
model,
1744-
keys_in=keys_in,
1744+
in_keys=keys_in,
17451745
keys_out=keys_out,
17461746
tensor_pixels_keys=tensor_pixels_key,
17471747
)
@@ -1762,7 +1762,7 @@ def test_batch_locked_transformed(device):
17621762
env = TransformedEnv(
17631763
MockBatchedLockedEnv(device),
17641764
Compose(
1765-
ObservationNorm(keys_in=["next_observation"], loc=0.5, scale=1.1),
1765+
ObservationNorm(in_keys=["next_observation"], loc=0.5, scale=1.1),
17661766
RewardClipping(0, 0.1),
17671767
),
17681768
)
@@ -1786,7 +1786,7 @@ def test_batch_unlocked_transformed(device):
17861786
env = TransformedEnv(
17871787
MockBatchedUnLockedEnv(device),
17881788
Compose(
1789-
ObservationNorm(keys_in=["next_observation"], loc=0.5, scale=1.1),
1789+
ObservationNorm(in_keys=["next_observation"], loc=0.5, scale=1.1),
17901790
RewardClipping(0, 0.1),
17911791
),
17921792
)
@@ -1806,7 +1806,7 @@ def test_batch_unlocked_with_batch_size_transformed(device):
18061806
env = TransformedEnv(
18071807
MockBatchedUnLockedEnv(device, batch_size=torch.Size([2])),
18081808
Compose(
1809-
ObservationNorm(keys_in=["next_observation"], loc=0.5, scale=1.1),
1809+
ObservationNorm(in_keys=["next_observation"], loc=0.5, scale=1.1),
18101810
RewardClipping(0, 0.1),
18111811
),
18121812
)

0 commit comments

Comments (0)