Skip to content

Commit ff5f78d

Browse files
authored
[Refactor] Refactoring set*() methods for TensorDictBase class (#438)
* Refactoring set*() functions. Initial commit * Lint fix * Adding tests, additional refactoring
1 parent 902a393 commit ff5f78d

File tree

2 files changed

+104
-86
lines changed

2 files changed

+104
-86
lines changed

test/test_tensordict.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os.path
88
import re
99

10+
import numpy as np
1011
import pytest
1112
import torch
1213
from _utils_internal import get_available_devices
@@ -1288,6 +1289,50 @@ def test_setitem_nested_dict_value(self, td_name, device):
12881289
td_clone2["d"] = nested_tensordict_value
12891290
assert (td_clone1 == td_clone2).all()
12901291

1292+
def test_tensordict_set(self, td_name, device):
1293+
torch.manual_seed(1)
1294+
np.random.seed(1)
1295+
td = getattr(self, td_name)(device)
1296+
1297+
# test set
1298+
val1 = np.ones(shape=(4, 3, 2, 1, 10))
1299+
td.set("key1", val1)
1300+
assert (td.get("key1") == 1).all()
1301+
with pytest.raises(RuntimeError):
1302+
td.set("key1", np.ones(shape=(5, 10)))
1303+
1304+
# test set_
1305+
val2 = np.zeros(shape=(4, 3, 2, 1, 10))
1306+
td.set_("key1", val2)
1307+
assert (td.get("key1") == 0).all()
1308+
with pytest.raises((KeyError, AttributeError)):
1309+
td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5)))
1310+
1311+
# test set_at_
1312+
td.set("key2", np.random.randn(4, 3, 2, 1, 5))
1313+
x = np.ones(shape=(2, 1, 5)) * 42
1314+
td.set_at_("key2", x, (2, 2))
1315+
assert (td.get("key2")[2, 2] == 42).all()
1316+
1317+
def test_tensordict_set_dict_value(self, td_name, device):
1318+
torch.manual_seed(1)
1319+
np.random.seed(1)
1320+
td = getattr(self, td_name)(device)
1321+
1322+
# test set
1323+
val1 = {"subkey1": torch.ones(4, 3, 2, 1, 10)}
1324+
td.set("key1", val1)
1325+
assert (td.get("key1").get("subkey1") == 1).all()
1326+
with pytest.raises(RuntimeError):
1327+
td.set("key1", torch.ones(5, 10))
1328+
1329+
# test set_
1330+
val2 = {"subkey1": torch.zeros(4, 3, 2, 1, 10)}
1331+
td.set_("key1", val2)
1332+
assert (td.get("key1").get("subkey1") == 0).all()
1333+
with pytest.raises((KeyError, AttributeError)):
1334+
td.set_("smartypants", torch.ones(4, 3, 2, 1, 5))
1335+
12911336
def test_delitem(self, td_name, device):
12921337
torch.manual_seed(1)
12931338
td = getattr(self, td_name)(device)

0 commit comments

Comments
 (0)