Skip to content

Commit ac4b987

Browse files
authored
[BugFix] Update to strict select (#675)
* init * strict=False * amend * amend
1 parent d40b473 commit ac4b987

File tree

4 files changed

+57
-31
lines changed

4 files changed

+57
-31
lines changed

test/test_rb.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,18 @@ def test_extend(self, rb_type, sampler, writer, storage, size):
125125
found_similar = False
126126
for b in rb._storage:
127127
if isinstance(b, TensorDictBase):
128-
b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
129-
d = d.select(*set(d.keys()).intersection(b.keys()))
128+
keys = set(d.keys()).intersection(b.keys())
129+
b = b.exclude("index").select(*keys, strict=False)
130+
keys = set(d.keys()).intersection(b.keys())
131+
d = d.select(*keys, strict=False)
130132

131133
value = b == d
132134
if isinstance(value, (torch.Tensor, TensorDictBase)):
133135
value = value.all()
134136
if value:
135-
found_similar = True
136137
break
137-
assert found_similar
138+
else:
139+
raise RuntimeError("did not find match")
138140

139141
def test_sample(self, rb_type, sampler, writer, storage, size):
140142
torch.manual_seed(0)
@@ -152,18 +154,18 @@ def test_sample(self, rb_type, sampler, writer, storage, size):
152154
for b in data:
153155
print(b, d)
154156
if isinstance(b, TensorDictBase):
155-
b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
156-
d = d.select(*set(d.keys()).intersection(b.keys()))
157+
keys = set(d.keys()).intersection(b.keys())
158+
b = b.exclude("index").select(*keys, strict=False)
159+
keys = set(d.keys()).intersection(b.keys())
160+
d = d.select(*keys, strict=False)
157161

158162
value = b == d
159163
if isinstance(value, (torch.Tensor, TensorDictBase)):
160164
value = value.all()
161165
if value:
162-
found_similar = True
163166
break
164-
if not found_similar:
165-
d
166-
assert found_similar, (d, data)
167+
else:
168+
raise RuntimeError("did not find match")
167169

168170
def test_index(self, rb_type, sampler, writer, storage, size):
169171
torch.manual_seed(0)
@@ -394,16 +396,18 @@ def test_extend(self, rbtype, storage, size, prefetch):
394396
found_similar = False
395397
for b in rb._storage:
396398
if isinstance(b, TensorDictBase):
397-
b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
398-
d = d.select(*set(d.keys()).intersection(b.keys()))
399+
keys = set(d.keys()).intersection(b.keys())
400+
b = b.exclude("index").select(*keys, strict=False)
401+
keys = set(d.keys()).intersection(b.keys())
402+
d = d.select(*keys, strict=False)
399403

400404
value = b == d
401405
if isinstance(value, (torch.Tensor, TensorDictBase)):
402406
value = value.all()
403407
if value:
404-
found_similar = True
405408
break
406-
assert found_similar
409+
else:
410+
raise RuntimeError("did not find match")
407411

408412
def test_sample(self, rbtype, storage, size, prefetch):
409413
torch.manual_seed(0)
@@ -418,18 +422,18 @@ def test_sample(self, rbtype, storage, size, prefetch):
418422
found_similar = False
419423
for b in data:
420424
if isinstance(b, TensorDictBase):
421-
b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
422-
d = d.select(*set(d.keys()).intersection(b.keys()))
425+
keys = set(d.keys()).intersection(b.keys())
426+
b = b.exclude("index").select(*keys, strict=False)
427+
keys = set(d.keys()).intersection(b.keys())
428+
d = d.select(*keys, strict=False)
423429

424430
value = b == d
425431
if isinstance(value, (torch.Tensor, TensorDictBase)):
426432
value = value.all()
427433
if value:
428-
found_similar = True
429434
break
430-
if not found_similar:
431-
d
432-
assert found_similar, (d, data)
435+
else:
436+
raise RuntimeError("did not find matching value")
433437

434438
def test_index(self, rbtype, storage, size, prefetch):
435439
torch.manual_seed(0)

torchrl/collectors/collectors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,8 @@ def iterator(self) -> Iterator[TensorDictBase]:
536536
def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase:
537537
policy_device = self.device
538538
if hasattr(self.policy, "in_keys"):
539-
td = td.select(*self.policy.in_keys)
539+
# some keys may be absent -- TensorDictModule is resilient to missing keys
540+
td = td.select(*self.policy.in_keys, strict=False)
540541
if self._td_policy is None:
541542
self._td_policy = td.to(policy_device)
542543
else:

torchrl/envs/vec_env.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -421,13 +421,14 @@ def _create_td(self) -> None:
421421
)
422422
if self._single_task:
423423
shared_tensordict_parent = shared_tensordict_parent.select(
424-
*self.selected_keys
424+
*self.selected_keys,
425+
strict=False,
425426
)
426427
self.shared_tensordict_parent = shared_tensordict_parent.to(self.device)
427428
else:
428429
shared_tensordict_parent = torch.stack(
429430
[
430-
tensordict.select(*selected_keys).to(self.device)
431+
tensordict.select(*selected_keys, strict=False).to(self.device)
431432
for tensordict, selected_keys in zip(
432433
shared_tensordict_parent, self.selected_keys
433434
)
@@ -573,7 +574,10 @@ def _step(
573574
) -> TensorDict:
574575
self._assert_tensordict_shape(tensordict)
575576

576-
tensordict_in = tensordict.select(*self.env_input_keys)
577+
tensordict_in = tensordict.select(
578+
*self.env_input_keys,
579+
strict=False,
580+
)
577581
tensordict_out = []
578582
for i in range(self.num_workers):
579583
_tensordict_out = self._envs[i].step(tensordict_in[i])
@@ -611,7 +615,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
611615
keys = keys.union(_td.keys())
612616
self.shared_tensordicts[i].update_(_td)
613617

614-
return self.shared_tensordict_parent.select(*keys).clone()
618+
return self.shared_tensordict_parent.select(
619+
*keys,
620+
strict=False,
621+
).clone()
615622

616623
def __getattr__(self, attr: str) -> Any:
617624
if attr in self.__dir__():
@@ -740,7 +747,12 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
740747
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
741748
self._assert_tensordict_shape(tensordict)
742749

743-
self.shared_tensordict_parent.update_(tensordict.select(*self.env_input_keys))
750+
self.shared_tensordict_parent.update_(
751+
tensordict.select(
752+
*self.env_input_keys,
753+
strict=False,
754+
)
755+
)
744756
for i in range(self.num_workers):
745757
self.parent_channels[i].send(("step", None))
746758

@@ -756,7 +768,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
756768
keys = keys.union(data)
757769
# We must pass a clone of the tensordict, as the values of this tensordict
758770
# will be modified in-place at further steps
759-
return self.shared_tensordict_parent.select(*keys).clone()
771+
return self.shared_tensordict_parent.select(
772+
*keys,
773+
strict=False,
774+
).clone()
760775

761776
@_check_start
762777
def _shutdown_workers(self) -> None:
@@ -829,7 +844,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
829844
# there might be some delay between writing the shared tensordict
830845
# and reading the updated value on the main process
831846
sleep(0.01)
832-
return self.shared_tensordict_parent.select(*keys).clone()
847+
return self.shared_tensordict_parent.select(
848+
*keys,
849+
strict=False,
850+
).clone()
833851

834852
def __reduce__(self):
835853
if not self.is_closed:
@@ -979,7 +997,10 @@ def _run_worker_pipe_shared_mem(
979997
if not initialized:
980998
raise RuntimeError("called 'init' before step")
981999
i += 1
982-
_td = tensordict.select(*env_input_keys)
1000+
_td = tensordict.select(
1001+
*env_input_keys,
1002+
strict=False,
1003+
)
9831004
if env.is_done and not allow_step_when_done:
9841005
raise RuntimeError(
9851006
f"calling step when env is done, just reset = {just_reset}"
@@ -989,7 +1010,7 @@ def _run_worker_pipe_shared_mem(
9891010
step_keys = set(_td.keys()) - set(env_input_keys)
9901011
if pin_memory:
9911012
_td.pin_memory()
992-
tensordict.update_(_td.select(*step_keys))
1013+
tensordict.update_(_td.select(*step_keys, strict=False))
9931014
if _td.get("done"):
9941015
msg = "done"
9951016
else:

torchrl/modules/models/model_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def forward(self, tensordict):
200200
tensordict_out.append(_tensordict)
201201
if t < time_steps - 1:
202202
_tensordict = step_mdp(
203-
_tensordict.select(*self.out_keys), keep_other=False
203+
_tensordict.select(*self.out_keys, strict=False), keep_other=False
204204
)
205205
_tensordict = update_values[..., t + 1].update(_tensordict)
206206

0 commit comments

Comments
 (0)