Skip to content

Commit c2a149d

Browse files
author
Vincent Moens
committed
[BugFix] Fix composite setitem
ghstack-source-id: f33b49b Pull Request resolved: #2778
1 parent b27ee6d commit c2a149d

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4818,7 +4818,7 @@ def get(self, item, default=NO_DEFAULT):
48184818
def __setitem__(self, key, value):
48194819
dest = self
48204820
if isinstance(key, tuple) and len(key) > 1:
4821-
while key[0] not in self.keys():
4821+
while key[0] not in dest.keys():
48224822
dest[key[0]] = dest = Composite(shape=self.shape, device=self.device)
48234823
if len(key) > 2:
48244824
key = key[1:]

torchrl/envs/transforms/transforms.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8037,9 +8037,11 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
80378037
input_spec["full_action_spec"][out_key] = input_spec[
80388038
"full_action_spec"
80398039
][action_key].clone()
8040-
if not self.create_copy:
8040+
if not self.create_copy:
8041+
for action_key in self.parent.action_keys:
8042+
if action_key in self.in_keys_inv:
80418043
del input_spec["full_action_spec"][action_key]
8042-
for state_key in self.parent.full_state_spec.keys(True):
8044+
for state_key in self.parent.full_state_spec.keys(True, True):
80438045
if state_key in self.in_keys_inv:
80448046
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
80458047
if self.in_keys_inv[i] == state_key:
@@ -8050,7 +8052,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
80508052
input_spec["full_state_spec"][out_key] = input_spec["full_state_spec"][
80518053
state_key
80528054
].clone()
8053-
if not self.create_copy:
8055+
if not self.create_copy:
8056+
for state_key in self.parent.full_state_spec.keys(True, True):
8057+
if state_key in self.in_keys_inv:
80548058
del input_spec["full_state_spec"][state_key]
80558059
return input_spec
80568060

0 commit comments

Comments
 (0)