@@ -301,6 +301,7 @@ def _state0(self) -> TensorDict:
301
301
def _make_td (state : torch .Tensor , action : torch .Tensor ) -> TensorDict :
302
302
done = torch .zeros_like (action , dtype = torch .bool ).unsqueeze (- 1 )
303
303
reward = action .clone ()
304
+ action = action + torch .arange (action .shape [- 1 ]) / action .shape [- 1 ]
304
305
305
306
return TensorDict (
306
307
{
@@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest:
326
327
forest .extend (r4 )
327
328
return forest
328
329
329
- def _make_forest_intersect (self ) -> MCTSForest :
330
+ def _make_forest_rebranching (self ) -> MCTSForest :
330
331
"""
331
332
├── 0
332
333
│ ├── 16
@@ -449,7 +450,7 @@ def test_forest_check_ids(self):
449
450
450
451
def test_forest_intersect (self ):
451
452
state0 = self ._state0 ()
452
- forest = self ._make_forest_intersect ()
453
+ forest = self ._make_forest_rebranching ()
453
454
tree = forest .get_tree (state0 )
454
455
subtree = forest .get_tree (TensorDict (observation = 19 ))
455
456
@@ -467,13 +468,110 @@ def test_forest_intersect(self):
467
468
468
469
def test_forest_intersect_vertices(self):
    """Compare vertex counts of the rebranching tree across key types.

    On the tree built from ``_make_forest_rebranching`` the test expects
    more vertices when keyed by ``"path"`` than when keyed by ``"hash"``,
    the same number when keyed by ``"id"`` and by ``"hash"``, and a
    ``ValueError`` for an unsupported ``key_type``.
    """
    state0 = self._state0()
    tree = self._make_forest_rebranching().get_tree(state0)

    def n_vertices(key_type):
        # Number of vertices reported by the tree for the given key type.
        return len(tree.vertices(key_type=key_type))

    assert n_vertices("path") > n_vertices("hash")
    assert n_vertices("id") == n_vertices("hash")

    # Unknown key types must be rejected explicitly.
    with pytest.raises(ValueError, match="key_type must be"):
        tree.vertices(key_type="another key type")
476
477
478
@pytest.mark.skipif(not _has_gym, reason="requires gym")
def test_simple_tree(self):
    """Build a forest from one Pendulum rollout and check path lengths.

    A single ``rollout(10)`` is stored in a fresh ``MCTSForest``; the
    non-compact tree rooted at the first step must report a maximum
    length of 9, and every valid path must contain 9 entries.
    """
    from torchrl.envs import GymEnv

    rollout = GymEnv("Pendulum-v1").rollout(10)
    # Take the root state before the rollout is handed to the forest.
    root_state = rollout[0]

    forest = MCTSForest()
    forest.extend(rollout)

    tree = forest.get_tree(root_state, compact=False)
    assert tree.max_length() == 9
    for path in tree.valid_paths():
        assert len(path) == 9
492
+
493
@pytest.mark.parametrize(
    "tree_type,compact",
    [
        ["simple", False],
        ["forest", False],
        # Parent lookup on rebranching trees is still buggy.
        # ["rebranching", False],
        # ["rebranching", True],
    ],
)
def test_forest_parent(self, tree_type, compact):
    """Check that nodes of a tree expose their parent.

    ``tree_type`` selects how the forest is built: ``"simple"`` uses a
    single Pendulum rollout (skipped when gym is unavailable),
    ``"forest"`` uses ``_make_forest`` and any other value uses
    ``_make_forest_rebranching``. ``compact`` is forwarded to
    ``MCTSForest.get_tree``.
    """
    if tree_type == "simple":
        if not _has_gym:
            pytest.skip("requires gym")
        from torchrl.envs import GymEnv

        rollout = GymEnv("Pendulum-v1").rollout(10)
        state0 = rollout[0]
        forest = MCTSForest()
        forest.extend(rollout)
    elif tree_type == "forest":
        state0 = self._state0()
        forest = self._make_forest()
    else:
        state0 = self._state0()
        forest = self._make_forest_rebranching()
    tree = forest.get_tree(state0, compact=compact)

    # Accessing ``parent`` at the first three depths must not raise.
    _ = tree.subtree.parent
    _ = tree.subtree.subtree.parent
    _ = tree.subtree.subtree.subtree.parent

    # The private ``_parent`` back-reference must be populated.
    assert tree.subtree[0]._parent is not None
    assert tree.subtree[0].subtree[0]._parent is not None

    # The parent of the first level is the root itself.
    assert_close(tree.subtree.parent, tree)

    # Walk every valid path and compare each child's parent with the
    # node it was reached from.
    for path in tree.valid_paths():
        current = tree
        for branch in path:
            child = current.subtree[branch]
            assert_close(child.parent, current)
            current = child
540
+
541
def test_forest_action_attr(self):
    """Check ``branching_action`` and ``prev_action`` on a forest tree.

    The root has neither attribute set (both are ``None``). At the first
    level, and below the first child, the two attributes must differ in
    at least one element.
    """
    state0 = self._state0()
    tree = self._make_forest().get_tree(state0)

    # Root: no action led here.
    assert tree.branching_action is None

    first_level = tree.subtree
    assert (first_level.branching_action != first_level.prev_action).any()

    second_level = tree.subtree[0].subtree
    assert (second_level.branching_action != second_level.prev_action).any()

    assert tree.prev_action is None
552
+
553
@pytest.mark.parametrize("intersect", [False, True])
def test_forest_check_obs_match(self, intersect):
    """Check node observations against the end of each node's rollout.

    For every node reached along every valid path, both
    ``node_data["observation"]`` and ``node_observation`` must equal the
    ``("next", "observation")`` entry of the last step of that node's
    ``rollout``. ``intersect`` selects the rebranching forest instead of
    the plain one.
    """
    state0 = self._state0()
    build_forest = self._make_forest_rebranching if intersect else self._make_forest
    tree = build_forest().get_tree(state0)

    for path in tree.valid_paths():
        node = tree
        for branch in path:
            node = node.subtree[branch]
            final_obs = node.rollout[..., -1]["next", "observation"]
            assert (node.node_data["observation"] == final_obs).all()
            assert (node.node_observation == final_obs).all()
574
+
477
575
478
576
if __name__ == "__main__" :
479
577
args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments