Skip to content

Commit dfde953

Browse files
author
Vincent Moens
committed
[BugFix] Fix composite setitem
ghstack-source-id: f33b49b Pull Request resolved: #2778 (cherry picked from commit c2a149d)
1 parent 57c7d2d commit dfde953

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4818,7 +4818,7 @@ def get(self, item, default=NO_DEFAULT):
48184818
def __setitem__(self, key, value):
48194819
dest = self
48204820
if isinstance(key, tuple) and len(key) > 1:
4821-
while key[0] not in self.keys():
4821+
while key[0] not in dest.keys():
48224822
dest[key[0]] = dest = Composite(shape=self.shape, device=self.device)
48234823
if len(key) > 2:
48244824
key = key[1:]

torchrl/envs/transforms/transforms.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8015,9 +8015,11 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
80158015
input_spec["full_action_spec"][out_key] = input_spec[
80168016
"full_action_spec"
80178017
][action_key].clone()
8018-
if not self.create_copy:
8018+
if not self.create_copy:
8019+
for action_key in self.parent.action_keys:
8020+
if action_key in self.in_keys_inv:
80198021
del input_spec["full_action_spec"][action_key]
8020-
for state_key in self.parent.full_state_spec.keys(True):
8022+
for state_key in self.parent.full_state_spec.keys(True, True):
80218023
if state_key in self.in_keys_inv:
80228024
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
80238025
if self.in_keys_inv[i] == state_key:
@@ -8028,7 +8030,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
80288030
input_spec["full_state_spec"][out_key] = input_spec["full_state_spec"][
80298031
state_key
80308032
].clone()
8031-
if not self.create_copy:
8033+
if not self.create_copy:
8034+
for state_key in self.parent.full_state_spec.keys(True, True):
8035+
if state_key in self.in_keys_inv:
80328036
del input_spec["full_state_spec"][state_key]
80338037
return input_spec
80348038

0 commit comments

Comments
 (0)