|
12 | 12 | import torch
|
13 | 13 |
|
14 | 14 | from tensordict import assert_close, TensorDict
|
15 |
| -from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest |
| 15 | +from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest, Tree |
16 | 16 | from torchrl.data.map import (
|
17 | 17 | BinaryToDecimal,
|
18 | 18 | QueryModule,
|
@@ -248,6 +248,109 @@ def test_map_rollout(self):
|
248 | 248 | assert not contains[rollout.shape[-1] :].any()
|
249 | 249 |
|
250 | 250 |
|
| 251 | +# Tests Tree independent of MCTSForest |
| 252 | +class TestTree: |
| 253 | + def dummy_tree(self): |
| 254 | + """Creates a tree with the following node IDs: |
| 255 | +
|
| 256 | + 0 |
| 257 | + ├── 1 |
| 258 | + | ├── 3 |
| 259 | + | └── 4 |
| 260 | + └── 2 |
| 261 | + ├── 5 |
| 262 | + └── 6 |
| 263 | + """ |
| 264 | + |
| 265 | + class IDGen: |
| 266 | + def __init__(self): |
| 267 | + self.next_id = 0 |
| 268 | + |
| 269 | + def __call__(self): |
| 270 | + res = self.next_id |
| 271 | + self.next_id += 1 |
| 272 | + return res |
| 273 | + |
| 274 | + gen_id = IDGen() |
| 275 | + gen_hash = lambda: hash(torch.rand(1).item()) |
| 276 | + |
| 277 | + def dummy_node_stack(obervations): |
| 278 | + return TensorDict.lazy_stack( |
| 279 | + [ |
| 280 | + Tree( |
| 281 | + node_data=TensorDict({"obs": torch.tensor(obs)}), |
| 282 | + hash=gen_hash(), |
| 283 | + node_id=gen_id(), |
| 284 | + ) |
| 285 | + for obs in obervations |
| 286 | + ] |
| 287 | + ) |
| 288 | + |
| 289 | + tree = dummy_node_stack([0])[0] |
| 290 | + tree.subtree = dummy_node_stack([1, 2]) |
| 291 | + tree.subtree[0].subtree = dummy_node_stack([3, 4]) |
| 292 | + tree.subtree[1].subtree = dummy_node_stack([6, 7]) |
| 293 | + return tree |
| 294 | + |
| 295 | + # Checks that when adding nodes to a tree, the `parent` property is set |
| 296 | + # correctly |
| 297 | + def test_parents(self): |
| 298 | + tree = self.dummy_tree() |
| 299 | + |
| 300 | + def check_parents_recursive(tree, parent): |
| 301 | + if parent is None: |
| 302 | + if tree.parent is not None: |
| 303 | + return False |
| 304 | + elif tree.parent.node_data is not parent.node_data: |
| 305 | + return False |
| 306 | + |
| 307 | + if tree.subtree is not None: |
| 308 | + for subtree in tree.subtree: |
| 309 | + if not check_parents_recursive(subtree, tree): |
| 310 | + return False |
| 311 | + |
| 312 | + return True |
| 313 | + |
| 314 | + assert check_parents_recursive(tree, None) |
| 315 | + |
| 316 | + def test_vertices(self): |
| 317 | + tree = self.dummy_tree() |
| 318 | + N = 7 |
| 319 | + assert tree.num_vertices(count_repeat=False) == N |
| 320 | + assert tree.num_vertices(count_repeat=True) == N |
| 321 | + assert len(tree.vertices(key_type="hash")) == N |
| 322 | + assert len(tree.vertices(key_type="id")) == N |
| 323 | + assert len(tree.vertices(key_type="path")) == N |
| 324 | + |
| 325 | + for path, vertex in tree.vertices(key_type="path").items(): |
| 326 | + vertex_check = tree |
| 327 | + for i in path: |
| 328 | + vertex_check = vertex_check.subtree[i] |
| 329 | + assert vertex.node_data is vertex_check.node_data |
| 330 | + |
| 331 | + def test_in(self): |
| 332 | + for tree in self.dummy_tree().vertices().values(): |
| 333 | + for path, subtree in tree.vertices(key_type="path").items(): |
| 334 | + assert subtree in tree |
| 335 | + |
| 336 | + if len(path) == 0: |
| 337 | + assert tree in subtree |
| 338 | + else: |
| 339 | + assert tree not in subtree |
| 340 | + |
| 341 | + def test_valid_paths(self): |
| 342 | + tree = self.dummy_tree() |
| 343 | + paths = set(tree.valid_paths()) |
| 344 | + paths_check = {(0, 0), (0, 1), (1, 0), (1, 1)} |
| 345 | + assert paths == paths_check |
| 346 | + |
| 347 | + def test_edges(self): |
| 348 | + tree = self.dummy_tree() |
| 349 | + edges = set(tree.edges()) |
| 350 | + edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)} |
| 351 | + assert edges == edges_check |
| 352 | + |
| 353 | + |
251 | 354 | class TestMCTSForest:
|
252 | 355 | def dummy_rollouts(self) -> Tuple[TensorDict, ...]:
|
253 | 356 | """
|
|
0 commit comments