Skip to content

Commit 59d29b8

Browse files
Author: Vincent Moens
[Feature] Fix DType casting lazy init (#1589)
1 parent db1a7d4 commit 59d29b8

File tree

5 files changed

+393
-176
lines changed

5 files changed

+393
-176
lines changed

test/test_transforms.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2259,12 +2259,6 @@ def test_double2float(self, keys, keys_inv, device):
22592259
)
22602260
action_spec = double2float.transform_input_spec(input_spec)
22612261
assert action_spec.dtype == torch.float
2262-
2263-
elif len(keys) == 1:
2264-
observation_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
2265-
observation_spec = double2float.transform_observation_spec(observation_spec)
2266-
assert observation_spec.dtype == torch.float
2267-
22682262
else:
22692263
observation_spec = CompositeSpec(
22702264
{
@@ -2274,7 +2268,7 @@ def test_double2float(self, keys, keys_inv, device):
22742268
)
22752269
observation_spec = double2float.transform_observation_spec(observation_spec)
22762270
for key in keys:
2277-
assert observation_spec[key].dtype == torch.float
2271+
assert observation_spec[key].dtype == torch.float, key
22782272

22792273
@pytest.mark.parametrize("device", get_default_devices())
22802274
@pytest.mark.parametrize(
@@ -2326,6 +2320,7 @@ def test_single_env_no_inkeys(self):
23262320
base_env.state_spec[key] = spec.to(torch.float64)
23272321
if base_env.action_spec.dtype == torch.float32:
23282322
base_env.action_spec = base_env.action_spec.to(torch.float64)
2323+
check_env_specs(base_env)
23292324
env = TransformedEnv(
23302325
base_env,
23312326
DoubleToFloat(),
@@ -2335,6 +2330,8 @@ def test_single_env_no_inkeys(self):
23352330
for spec in env.state_spec.values(True, True):
23362331
assert spec.dtype == torch.float32
23372332
assert env.action_spec.dtype != torch.float64
2333+
assert env.transform.in_keys == env.transform.out_keys
2334+
assert env.transform.in_keys_inv == env.transform.out_keys_inv
23382335
check_env_specs(env)
23392336

23402337
def test_single_trans_env_check(self, dtype_fixture): # noqa: F811

torchrl/envs/transforms/rlhf.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from copy import deepcopy
5+
from copy import copy, deepcopy
66

77
import torch
88
from tensordict import TensorDictBase, unravel_key
@@ -93,24 +93,22 @@ def __init__(
9393
if in_keys is None:
9494
in_keys = self.DEFAULT_IN_KEYS
9595
if out_keys is None:
96-
out_keys = in_keys
97-
if not isinstance(in_keys, list):
98-
in_keys = [in_keys]
99-
if not isinstance(out_keys, list):
100-
out_keys = [out_keys]
101-
if not is_seq_of_nested_key(in_keys) or not is_seq_of_nested_key(out_keys):
96+
out_keys = copy(in_keys)
97+
super().__init__(in_keys=in_keys, out_keys=out_keys)
98+
if not is_seq_of_nested_key(self.in_keys) or not is_seq_of_nested_key(
99+
self.out_keys
100+
):
102101
raise ValueError(
103-
f"invalid in_keys / out_keys:\nin_keys={in_keys} \nout_keys={out_keys}"
102+
f"invalid in_keys / out_keys:\nin_keys={self.in_keys} \nout_keys={self.out_keys}"
104103
)
105-
if len(in_keys) != 1 or len(out_keys) != 1:
104+
if len(self.in_keys) != 1 or len(self.out_keys) != 1:
106105
raise ValueError(
107-
f"Only one in_key/out_key is allowed, got in_keys={in_keys}, out_keys={out_keys}."
106+
f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}."
108107
)
109-
super().__init__(in_keys=in_keys, out_keys=out_keys)
110108
# for convenience, convert out_keys to tuples
111-
self.out_keys = [
109+
self._out_keys = [
112110
out_key if isinstance(out_key, tuple) else (out_key,)
113-
for out_key in self.out_keys
111+
for out_key in self._out_keys
114112
]
115113

116114
# update the in_keys for dispatch etc

0 commit comments

Comments
 (0)