Skip to content

Commit d3dca73

Browse files
author
Vincent Moens
committed
[Feature] Capture wrong spec transforms (1/N)
ghstack-source-id: f2d938b Pull Request resolved: #2805
1 parent e0d3eee commit d3dca73

File tree

3 files changed

+67
-12
lines changed

3 files changed

+67
-12
lines changed

test/test_env.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@
5555
from torchrl.envs.gym_like import default_info_dict_reader
5656
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
5757
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
58-
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
59-
from torchrl.envs.transforms.transforms import (
58+
from torchrl.envs.transforms import (
6059
AutoResetEnv,
6160
AutoResetTransform,
61+
Compose,
62+
StepCounter,
6263
Tokenizer,
6364
Transform,
65+
TransformedEnv,
6466
UnsqueezeTransform,
6567
)
6668
from torchrl.envs.utils import (
@@ -3770,6 +3772,28 @@ def test_str2str_rb_slicesampler(self):
37703772
else:
37713773
raise RuntimeError("Failed to sample both trajs")
37723774

3775+
def test_env_with_str_append(self):
    """Rollout with a transform that appends to a non-tensor string obs.

    A custom transform registers a ``"str"`` entry (a ``NonTensor`` spec)
    and grows it by one "-<count>" token per step. The test checks that
    the string is threaded consistently through the rollout: each step's
    ``("next", "str")`` extends the step's ``"str"``, and the following
    step picks up exactly where the previous one left off.
    """

    class StrAppender(Transform):
        # Register the new "str" entry in the observation spec so that
        # rollouts account for the extra non-tensor output.
        def transform_observation_spec(self, observation_spec):
            return observation_spec.set("str", NonTensor(example_data="a string"))

        def _step(self, td, next_td):
            previous = td["str"]
            # Parse the trailing counter and append the incremented value.
            counter = int(previous.split("-")[-1]) + 1
            next_td["str"] = previous + "-" + str(counter)
            return next_td

        def _reset(self, td, reset_td):
            # Start every episode from the "0" seed string.
            return reset_td.set("str", "0")

    env = TransformedEnv(CountingEnv(), StrAppender())
    rollout = env.rollout(10)
    steps = rollout.unbind(0)
    for current, following in zip(steps[:-1], steps[1:]):
        assert current["next", "str"].startswith(current["str"])
        assert following["str"] == current["next", "str"]
37733797
def test_env_with_tensorclass(self):
37743798
env = EnvWithTensorClass()
37753799
env.check_env_specs()

torchrl/envs/transforms/rlhf.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,6 @@ def _step(
187187
forward = _call
188188

189189
def transform_output_spec(self, output_spec: Composite) -> Composite:
190-
output_spec = super().transform_output_spec(output_spec)
191-
# todo: here we'll need to use the reward_key once it's implemented
192-
# parent = self.parent
193190
in_key = unravel_key(self.in_keys[0])
194191
out_key = unravel_key(self.out_keys[0])
195192

@@ -205,18 +202,13 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
205202
)
206203
elif in_key == "reward":
207204
parent = self.parent
208-
reward_spec = Unbounded(
209-
device=output_spec.device,
210-
shape=output_spec["full_reward_spec"][parent.reward_key].shape,
211-
)
205+
reward_spec = output_spec["full_reward_spec"][parent.reward_key].clone()
212206
# then we need to populate the output keys
213207
observation_spec = output_spec["full_observation_spec"]
214208
observation_spec[out_key] = reward_spec
215209
else:
216210
observation_spec = output_spec["full_observation_spec"]
217-
reward_spec = Unbounded(
218-
device=output_spec.device, shape=observation_spec[in_key].shape
219-
)
211+
reward_spec = observation_spec[in_key].clone()
220212
# then we need to populate the output keys
221213
observation_spec[out_key] = reward_spec
222214
return output_spec

torchrl/envs/transforms/transforms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
560560
output_spec["full_done_spec"] = self.transform_done_spec(
561561
output_spec["full_done_spec"]
562562
)
563+
output_spec_keys = [
564+
unravel_key(k[1:]) for k in output_spec.keys(True) if isinstance(k, tuple)
565+
]
566+
out_keys = {unravel_key(k) for k in self.out_keys}
567+
in_keys = {unravel_key(k) for k in self.in_keys}
568+
for key in out_keys - in_keys:
569+
if unravel_key(key) not in output_spec_keys:
570+
warnings.warn(
571+
f"The key '{key}' is unaccounted for by the transform (expected keys {output_spec_keys}). "
572+
f"Every new entry in the tensordict resulting from a call to a transform must be "
573+
f"registered in the specs for torchrl rollouts to be consistently built. "
574+
f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
575+
"This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.",
576+
category=FutureWarning,
577+
)
563578
return output_spec
564579

565580
def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
@@ -1468,33 +1483,57 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
14681483
# the action spec from the env, map it using t0 then t1 (going from in to out).
14691484
for t in self.transforms:
14701485
input_spec = t.transform_input_spec(input_spec)
1486+
if not isinstance(input_spec, Composite):
1487+
raise TypeError(
1488+
f"Expected Compose but got {type(input_spec)} with transform {t}"
1489+
)
14711490
return input_spec
14721491

14731492
def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
    """Thread the action spec through every child transform in order.

    Each child's ``transform_action_spec`` result is validated to be a
    ``TensorSpec`` before moving on, so a misbehaving transform is
    reported immediately rather than failing downstream.
    """
    # To understand why we don't invert, look up at transform_input_spec
    for transform in self.transforms:
        action_spec = transform.transform_action_spec(action_spec)
        if isinstance(action_spec, TensorSpec):
            continue
        raise TypeError(
            f"Expected TensorSpec but got {type(action_spec)} with transform {transform}"
        )
    return action_spec
14781501

14791502
def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec:
    """Thread the state spec through every child transform in order.

    Each child's ``transform_state_spec`` result must be a ``Composite``;
    a non-conforming transform raises a ``TypeError`` naming the culprit.
    """
    # To understand why we don't invert, look up at transform_input_spec
    for t in self.transforms:
        state_spec = t.transform_state_spec(state_spec)
        if not isinstance(state_spec, Composite):
            # Bug fix: the message previously said "Expected Compose",
            # but the check (and the contract) is for a Composite spec.
            raise TypeError(
                f"Expected Composite but got {type(state_spec)} with transform {t}"
            )
    return state_spec
14841511

14851512
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
    """Thread the observation spec through every child transform in order.

    Validates after each child that the running value is still a
    ``TensorSpec``, raising a ``TypeError`` that names the offending
    transform otherwise.
    """
    for transform in self.transforms:
        observation_spec = transform.transform_observation_spec(observation_spec)
        if isinstance(observation_spec, TensorSpec):
            continue
        raise TypeError(
            f"Expected TensorSpec but got {type(observation_spec)} with transform {transform}"
        )
    return observation_spec
14891520

14901521
def transform_output_spec(self, output_spec: TensorSpec) -> TensorSpec:
    """Thread the output spec through every child transform in order.

    Each child's ``transform_output_spec`` result must be a ``Composite``;
    a non-conforming transform raises a ``TypeError`` naming the culprit.
    """
    for t in self.transforms:
        output_spec = t.transform_output_spec(output_spec)
        if not isinstance(output_spec, Composite):
            # Bug fix: the message previously said "Expected Compose",
            # but the check (and the contract) is for a Composite spec.
            raise TypeError(
                f"Expected Composite but got {type(output_spec)} with transform {t}"
            )
    return output_spec
14941529

14951530
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
    """Thread the reward spec through every child transform in order.

    After each child's ``transform_reward_spec`` call, the running value
    is checked to still be a ``TensorSpec``; otherwise a ``TypeError``
    naming the offending transform is raised.
    """
    for transform in self.transforms:
        reward_spec = transform.transform_reward_spec(reward_spec)
        if isinstance(reward_spec, TensorSpec):
            continue
        raise TypeError(
            f"Expected TensorSpec but got {type(reward_spec)} with transform {transform}"
        )
    return reward_spec
14991538

15001539
def __getitem__(self, item: Union[int, slice, List]) -> Union:

0 commit comments

Comments
 (0)