@@ -1040,6 +1040,33 @@ def _make_node_map(self, source, dest):
1040
1040
self .max_size = self .data_map .max_size
1041
1041
1042
1042
def extend (self , rollout , * , return_node : bool = False ):
1043
+ """Add a rollout to the forest.
1044
+
1045
+ Nodes are only added to a tree at points where rollouts diverge from
1046
+ each other and at the endpoints of rollouts.
1047
+
1048
+ If there is no existing tree that matches the first steps of the
1049
+ rollout, a new tree is added. Only one node is created, for the final
1050
+ step.
1051
+
1052
+ If there is an existing tree that matches, the rollout is added to that
1053
+ tree. If the rollout diverges from all other rollouts in the tree at
1054
+ some step, a new node is created before the step where the rollouts
1055
+ diverge, and a leaf node is created for the final step of the rollout.
1056
+ If all of the rollout's steps match with a previously added rollout,
1057
+ nothing changes. If the rollout matches up to a leaf node of a tree but
1058
+ continues beyond it, that node is extended to the end of the rollout,
1059
+ and no new nodes are created.
1060
+
1061
+ Args:
1062
+ rollout (TensorDict): The rollout to add to the forest.
1063
+ return_node (bool, optional): If True, the method returns the added
1064
+ node. Default is ``False``.
1065
+
1066
+ Returns:
1067
+ Tree: The node that was added to the forest. This is only
1068
+ returned if ``return_node`` is True.
1069
+ """
1043
1070
source , dest = (
1044
1071
rollout .exclude ("next" ).copy (),
1045
1072
rollout .select ("next" , * self .action_keys ).copy (),
0 commit comments