Skip to content

Commit caf13c6

Browse files
committed
[Doc] Add docstring for MCTSForest.extend
ghstack-source-id: dbef5e4 Pull Request resolved: pytorch#2795
1 parent 6063130 commit caf13c6

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

torchrl/data/map/tree.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,33 @@ def _make_node_map(self, source, dest):
10401040
self.max_size = self.data_map.max_size
10411041

10421042
def extend(self, rollout, *, return_node: bool = False):
1043+
"""Add a rollout to the forest.
1044+
1045+
Nodes are only added to a tree at points where rollouts diverge from
1046+
each other and at the endpoints of rollouts.
1047+
1048+
If there is no existing tree that matches the first steps of the
1049+
rollout, a new tree is added. Only one node is created, for the final
1050+
step.
1051+
1052+
If there is an existing tree that matches, the rollout is added to that
1053+
tree. If the rollout diverges from all other rollouts in the tree at
1054+
some step, a new node is created before the step where the rollouts
1055+
diverge, and a leaf node is created for the final step of the rollout.
1056+
If all of the rollout's steps match with a previously added rollout,
1057+
nothing changes. If the rollout matches up to a leaf node of a tree but
1058+
continues beyond it, that node is extended to the end of the rollout,
1059+
and no new nodes are created.
1060+
1061+
Args:
1062+
rollout (TensorDict): The rollout to add to the forest.
1063+
return_node (bool, optional): If True, the method returns the added
1064+
node. Default is ``False``.
1065+
1066+
Returns:
1067+
Tree: The node that was added to the forest. This is only
1068+
returned if ``return_node`` is True.
1069+
"""
10431070
source, dest = (
10441071
rollout.exclude("next").copy(),
10451072
rollout.select("next", *self.action_keys).copy(),

0 commit comments

Comments
 (0)