Skip to content

Commit bb9440b

Browse files
kurtamohlerVincent Moens
authored andcommitted
[Test] Add tests for Tree
ghstack-source-id: 8f7aa07 Pull Request resolved: #2738
1 parent 4262ab9 commit bb9440b

File tree

1 file changed

+104
-1
lines changed

1 file changed

+104
-1
lines changed

test/test_storage_map.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313

1414
from tensordict import assert_close, TensorDict
15-
from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest
15+
from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest, Tree
1616
from torchrl.data.map import (
1717
BinaryToDecimal,
1818
QueryModule,
@@ -248,6 +248,109 @@ def test_map_rollout(self):
248248
assert not contains[rollout.shape[-1] :].any()
249249

250250

251+
# Tests Tree independent of MCTSForest
252+
class TestTree:
253+
def dummy_tree(self):
254+
"""Creates a tree with the following node IDs:
255+
256+
0
257+
├── 1
258+
| ├── 3
259+
| └── 4
260+
└── 2
261+
├── 5
262+
└── 6
263+
"""
264+
265+
class IDGen:
266+
def __init__(self):
267+
self.next_id = 0
268+
269+
def __call__(self):
270+
res = self.next_id
271+
self.next_id += 1
272+
return res
273+
274+
gen_id = IDGen()
275+
gen_hash = lambda: hash(torch.rand(1).item())
276+
277+
def dummy_node_stack(obervations):
278+
return TensorDict.lazy_stack(
279+
[
280+
Tree(
281+
node_data=TensorDict({"obs": torch.tensor(obs)}),
282+
hash=gen_hash(),
283+
node_id=gen_id(),
284+
)
285+
for obs in obervations
286+
]
287+
)
288+
289+
tree = dummy_node_stack([0])[0]
290+
tree.subtree = dummy_node_stack([1, 2])
291+
tree.subtree[0].subtree = dummy_node_stack([3, 4])
292+
tree.subtree[1].subtree = dummy_node_stack([6, 7])
293+
return tree
294+
295+
# Checks that when adding nodes to a tree, the `parent` property is set
296+
# correctly
297+
def test_parents(self):
298+
tree = self.dummy_tree()
299+
300+
def check_parents_recursive(tree, parent):
301+
if parent is None:
302+
if tree.parent is not None:
303+
return False
304+
elif tree.parent.node_data is not parent.node_data:
305+
return False
306+
307+
if tree.subtree is not None:
308+
for subtree in tree.subtree:
309+
if not check_parents_recursive(subtree, tree):
310+
return False
311+
312+
return True
313+
314+
assert check_parents_recursive(tree, None)
315+
316+
def test_vertices(self):
317+
tree = self.dummy_tree()
318+
N = 7
319+
assert tree.num_vertices(count_repeat=False) == N
320+
assert tree.num_vertices(count_repeat=True) == N
321+
assert len(tree.vertices(key_type="hash")) == N
322+
assert len(tree.vertices(key_type="id")) == N
323+
assert len(tree.vertices(key_type="path")) == N
324+
325+
for path, vertex in tree.vertices(key_type="path").items():
326+
vertex_check = tree
327+
for i in path:
328+
vertex_check = vertex_check.subtree[i]
329+
assert vertex.node_data is vertex_check.node_data
330+
331+
def test_in(self):
332+
for tree in self.dummy_tree().vertices().values():
333+
for path, subtree in tree.vertices(key_type="path").items():
334+
assert subtree in tree
335+
336+
if len(path) == 0:
337+
assert tree in subtree
338+
else:
339+
assert tree not in subtree
340+
341+
def test_valid_paths(self):
342+
tree = self.dummy_tree()
343+
paths = set(tree.valid_paths())
344+
paths_check = {(0, 0), (0, 1), (1, 0), (1, 1)}
345+
assert paths == paths_check
346+
347+
def test_edges(self):
348+
tree = self.dummy_tree()
349+
edges = set(tree.edges())
350+
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
351+
assert edges == edges_check
352+
353+
251354
class TestMCTSForest:
252355
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:
253356
"""

0 commit comments

Comments
 (0)