Skip to content

Commit edc284f

Browse files
roloVincent Moens
authored andcommitted
[BugFix] Tree make node fix (#2839)
(cherry picked from commit ba8be9c)
1 parent 0436851 commit edc284f

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

test/test_storage_map.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,17 @@ def test_edges(self):
350350
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
351351
assert edges == edges_check
352352

353+
def test_make_node(self):
354+
td = TensorDict({"obs": torch.tensor([0])})
355+
tree = Tree(node_data=td)
356+
assert tree.node_data is not None
357+
358+
tree = Tree.make_node(data=td)
359+
assert tree.node_data is not None
360+
361+
tree = Tree.make_node(td)
362+
assert tree.node_data is not None
363+
353364

354365
class TestMCTSForest:
355366
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:

torchrl/data/map/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def make_node(
123123
return cls(
124124
count=torch.zeros(()),
125125
wins=torch.zeros(()),
126-
node=data.exclude("action", "next"),
126+
node_data=data.exclude("action", "next"),
127127
rollout=rollout,
128128
subtree=subtree,
129129
device=device,

0 commit comments

Comments
 (0)