Skip to content

Commit 57dc25a

Browse files
author
Vincent Moens
committed
[Refactor] Refactor trees
ghstack-source-id: 368ba4c Pull Request resolved: #2634
1 parent 19dfefc commit 57dc25a

File tree

6 files changed

+678
-56
lines changed

6 files changed

+678
-56
lines changed

test/test_storage_map.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def _state0(self) -> TensorDict:
301301
def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
302302
done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1)
303303
reward = action.clone()
304+
action = action + torch.arange(action.shape[-1]) / action.shape[-1]
304305

305306
return TensorDict(
306307
{
@@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest:
326327
forest.extend(r4)
327328
return forest
328329

329-
def _make_forest_intersect(self) -> MCTSForest:
330+
def _make_forest_rebranching(self) -> MCTSForest:
330331
"""
331332
├── 0
332333
│ ├── 16
@@ -449,7 +450,7 @@ def test_forest_check_ids(self):
449450

450451
def test_forest_intersect(self):
451452
state0 = self._state0()
452-
forest = self._make_forest_intersect()
453+
forest = self._make_forest_rebranching()
453454
tree = forest.get_tree(state0)
454455
subtree = forest.get_tree(TensorDict(observation=19))
455456

@@ -467,13 +468,110 @@ def test_forest_intersect(self):
467468

468469
def test_forest_intersect_vertices(self):
469470
state0 = self._state0()
470-
forest = self._make_forest_intersect()
471+
forest = self._make_forest_rebranching()
471472
tree = forest.get_tree(state0)
472473
assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash"))
473474
assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash"))
474475
with pytest.raises(ValueError, match="key_type must be"):
475476
tree.vertices(key_type="another key type")
476477

478+
@pytest.mark.skipif(not _has_gym, reason="requires gym")
479+
def test_simple_tree(self):
480+
from torchrl.envs import GymEnv
481+
482+
env = GymEnv("Pendulum-v1")
483+
r = env.rollout(10)
484+
state0 = r[0]
485+
forest = MCTSForest()
486+
forest.extend(r)
487+
# forest = self._make_forest_intersect()
488+
tree = forest.get_tree(state0, compact=False)
489+
assert tree.max_length() == 9
490+
for p in tree.valid_paths():
491+
assert len(p) == 9
492+
493+
@pytest.mark.parametrize(
494+
"tree_type,compact",
495+
[
496+
["simple", False],
497+
["forest", False],
498+
# parent of rebranching trees are still buggy
499+
# ["rebranching", False],
500+
# ["rebranching", True],
501+
],
502+
)
503+
def test_forest_parent(self, tree_type, compact):
504+
if tree_type == "simple":
505+
if not _has_gym:
506+
pytest.skip("requires gym")
507+
from torchrl.envs import GymEnv
508+
509+
env = GymEnv("Pendulum-v1")
510+
r = env.rollout(10)
511+
state0 = r[0]
512+
forest = MCTSForest()
513+
forest.extend(r)
514+
tree = forest.get_tree(state0, compact=compact)
515+
elif tree_type == "forest":
516+
state0 = self._state0()
517+
forest = self._make_forest()
518+
tree = forest.get_tree(state0, compact=compact)
519+
else:
520+
state0 = self._state0()
521+
forest = self._make_forest_rebranching()
522+
tree = forest.get_tree(state0, compact=compact)
523+
# Check access
524+
tree.subtree.parent
525+
tree.subtree.subtree.parent
526+
tree.subtree.subtree.subtree.parent
527+
528+
# check present of weakref
529+
assert tree.subtree[0]._parent is not None
530+
assert tree.subtree[0].subtree[0]._parent is not None
531+
532+
# Check content
533+
assert_close(tree.subtree.parent, tree)
534+
for p in tree.valid_paths():
535+
root = tree
536+
for it in p:
537+
node = root.subtree[it]
538+
assert_close(node.parent, root)
539+
root = node
540+
541+
def test_forest_action_attr(self):
542+
state0 = self._state0()
543+
forest = self._make_forest()
544+
tree = forest.get_tree(state0)
545+
assert tree.branching_action is None
546+
assert (tree.subtree.branching_action != tree.subtree.prev_action).any()
547+
assert (
548+
tree.subtree[0].subtree.branching_action
549+
!= tree.subtree[0].subtree.prev_action
550+
).any()
551+
assert tree.prev_action is None
552+
553+
@pytest.mark.parametrize("intersect", [False, True])
554+
def test_forest_check_obs_match(self, intersect):
555+
state0 = self._state0()
556+
if intersect:
557+
forest = self._make_forest_rebranching()
558+
else:
559+
forest = self._make_forest()
560+
tree = forest.get_tree(state0)
561+
for path in tree.valid_paths():
562+
prev_tree = tree
563+
for p in path:
564+
subtree = prev_tree.subtree[p]
565+
assert (
566+
subtree.node_data["observation"]
567+
== subtree.rollout[..., -1]["next", "observation"]
568+
).all()
569+
assert (
570+
subtree.node_observation
571+
== subtree.rollout[..., -1]["next", "observation"]
572+
).all()
573+
prev_tree = subtree
574+
477575

478576
if __name__ == "__main__":
479577
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/map/hash.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
7575
class SipHash(Module):
7676
"""A Module to Compute SipHash values for given tensors.
7777
78-
A hash function module based on SipHash implementation in python.
78+
A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]``
79+
and the output shape will be ``[batch_size]``.
7980
8081
Args:
8182
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers

torchrl/data/map/tdstorage.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def __init__(
138138
self.collate_fn = collate_fn
139139
self.write_fn = write_fn
140140

141+
@property
142+
def max_size(self):
143+
return self.storage.max_size
144+
141145
@property
142146
def out_keys(self) -> List[NestedKey]:
143147
out_keys = self.__dict__.get("_out_keys_and_lazy")
@@ -177,7 +181,7 @@ def from_tensordict_pair(
177181
collate_fn: Callable[[Any], Any] | None = None,
178182
write_fn: Callable[[Any, Any], Any] | None = None,
179183
consolidated: bool | None = None,
180-
):
184+
) -> TensorDictMap:
181185
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
182186
183187
Args:
@@ -238,7 +242,13 @@ def from_tensordict_pair(
238242
n_feat = 0
239243
hash_module = []
240244
for in_key in in_keys:
241-
n_feat = source[in_key].shape[-1]
245+
entry = source[in_key]
246+
if entry.ndim == source.ndim:
247+
# this is a good example of why td/tc are useful - carrying metadata
248+
# allows us to know if there's a feature dim or not
249+
n_feat = 0
250+
else:
251+
n_feat = entry.shape[-1]
242252
if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT:
243253
_hash_module = RandomProjectionHash()
244254
else:
@@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
308318
if not self._has_lazy_out_keys():
309319
# TODO: make this work with pytrees and avoid calling select if keys match
310320
value = value.select(*self.out_keys, strict=False)
321+
item, value = self._maybe_add_batch(item, value)
322+
index = self._to_index(item, extend=True)
323+
if index.unique().numel() < index.numel():
324+
# If multiple values point to the same place in the storage, we cannot process them by batch
325+
# There could be a better way to deal with this, using unique ids.
326+
vals = []
327+
for it, val in zip(item.split(1), value.split(1)):
328+
self[it] = val
329+
vals.append(val)
330+
# __setitem__ may affect the content of the input data
331+
value.update(TensorDictBase.lazy_stack(vals))
332+
return
311333
if self.write_fn is not None:
334+
# We use this block in the following context: the value written in the storage is already present,
335+
# but it needs to be updated.
336+
# We first check if the value is already there using `contains`. If so, we pass the new value and the
337+
# previous one to write_fn. The values that are not present are passed alone.
312338
if len(self):
313339
modifiable = self.contains(item)
314340
if modifiable.any():
@@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
322348
value = self.write_fn(value)
323349
else:
324350
value = self.write_fn(value)
325-
item, value = self._maybe_add_batch(item, value)
326-
index = self._to_index(item, extend=True)
327351
self.storage.set(index, value)
328352

329353
def __len__(self):

0 commit comments

Comments
 (0)