|
7 | 7 | import os.path
|
8 | 8 | import re
|
9 | 9 |
|
| 10 | +import numpy as np |
10 | 11 | import pytest
|
11 | 12 | import torch
|
12 | 13 | from _utils_internal import get_available_devices
|
@@ -1288,6 +1289,50 @@ def test_setitem_nested_dict_value(self, td_name, device):
|
1288 | 1289 | td_clone2["d"] = nested_tensordict_value
|
1289 | 1290 | assert (td_clone1 == td_clone2).all()
|
1290 | 1291 |
|
| 1292 | + def test_tensordict_set(self, td_name, device): |
| 1293 | + torch.manual_seed(1) |
| 1294 | + np.random.seed(1) |
| 1295 | + td = getattr(self, td_name)(device) |
| 1296 | + |
| 1297 | + # test set |
| 1298 | + val1 = np.ones(shape=(4, 3, 2, 1, 10)) |
| 1299 | + td.set("key1", val1) |
| 1300 | + assert (td.get("key1") == 1).all() |
| 1301 | + with pytest.raises(RuntimeError): |
| 1302 | + td.set("key1", np.ones(shape=(5, 10))) |
| 1303 | + |
| 1304 | + # test set_ |
| 1305 | + val2 = np.zeros(shape=(4, 3, 2, 1, 10)) |
| 1306 | + td.set_("key1", val2) |
| 1307 | + assert (td.get("key1") == 0).all() |
| 1308 | + with pytest.raises((KeyError, AttributeError)): |
| 1309 | + td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5))) |
| 1310 | + |
| 1311 | + # test set_at_ |
| 1312 | + td.set("key2", np.random.randn(4, 3, 2, 1, 5)) |
| 1313 | + x = np.ones(shape=(2, 1, 5)) * 42 |
| 1314 | + td.set_at_("key2", x, (2, 2)) |
| 1315 | + assert (td.get("key2")[2, 2] == 42).all() |
| 1316 | + |
| 1317 | + def test_tensordict_set_dict_value(self, td_name, device): |
| 1318 | + torch.manual_seed(1) |
| 1319 | + np.random.seed(1) |
| 1320 | + td = getattr(self, td_name)(device) |
| 1321 | + |
| 1322 | + # test set |
| 1323 | + val1 = {"subkey1": torch.ones(4, 3, 2, 1, 10)} |
| 1324 | + td.set("key1", val1) |
| 1325 | + assert (td.get("key1").get("subkey1") == 1).all() |
| 1326 | + with pytest.raises(RuntimeError): |
| 1327 | + td.set("key1", torch.ones(5, 10)) |
| 1328 | + |
| 1329 | + # test set_ |
| 1330 | + val2 = {"subkey1": torch.zeros(4, 3, 2, 1, 10)} |
| 1331 | + td.set_("key1", val2) |
| 1332 | + assert (td.get("key1").get("subkey1") == 0).all() |
| 1333 | + with pytest.raises((KeyError, AttributeError)): |
| 1334 | + td.set_("smartypants", torch.ones(4, 3, 2, 1, 5)) |
| 1335 | + |
1291 | 1336 | def test_delitem(self, td_name, device):
|
1292 | 1337 | torch.manual_seed(1)
|
1293 | 1338 | td = getattr(self, td_name)(device)
|
|
0 commit comments