Skip to content

Commit c0bc12a

Browse files
authored
[Feature] Refactor CatFrames using a proper preallocated buffer (#847)
1 parent 4a81a6c commit c0bc12a

File tree

3 files changed

+109
-30
lines changed

3 files changed

+109
-30
lines changed

test/test_shared.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import argparse
6-
import sys
76
import time
87
import warnings
98

test/test_transforms.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,44 +1289,55 @@ def test_catframes_transform_observation_spec(self):
12891289
)
12901290

12911291
@pytest.mark.parametrize("device", get_available_devices())
1292-
def test_catframes_buffer_check_latest_frame(self, device):
1292+
@pytest.mark.parametrize("d", range(1, 4))
1293+
def test_catframes_buffer_check_latest_frame(self, device, d):
12931294
key1 = "first key"
12941295
key2 = "second key"
12951296
N = 4
12961297
keys = [key1, key2]
1297-
key1_tensor = torch.zeros(1, 1, 3, 3, device=device)
1298-
key2_tensor = torch.ones(1, 1, 3, 3, device=device)
1298+
key1_tensor = torch.ones(1, d, 3, 3, device=device) * 2
1299+
key2_tensor = torch.ones(1, d, 3, 3, device=device)
12991300
key_tensors = [key1_tensor, key2_tensor]
13001301
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
13011302
cat_frames = CatFrames(N=N, in_keys=keys)
13021303

1303-
cat_frames(td)
1304-
latest_frame = td.get(key2)
1304+
tdclone = cat_frames(td.clone())
1305+
latest_frame = tdclone.get(key2)
1306+
1307+
assert latest_frame.shape[1] == N * d
1308+
assert (latest_frame[0, :-d] == 0).all()
1309+
assert (latest_frame[0, -d:] == 1).all()
1310+
1311+
tdclone = cat_frames(td.clone())
1312+
latest_frame = tdclone.get(key2)
13051313

1306-
assert latest_frame.shape[1] == N
1307-
for i in range(0, N - 1):
1308-
assert torch.equal(latest_frame[0][i], key2_tensor[0][0])
1309-
assert torch.equal(latest_frame[0][N - 1], key1_tensor[0][0])
1314+
assert latest_frame.shape[1] == N * d
1315+
assert (latest_frame[0, : -2 * d] == 0).all()
1316+
assert (latest_frame[0, -2 * d :] == 1).all()
13101317

13111318
@pytest.mark.parametrize("device", get_available_devices())
13121319
def test_catframes_reset(self, device):
13131320
key1 = "first key"
13141321
key2 = "second key"
13151322
N = 4
13161323
keys = [key1, key2]
1317-
key1_tensor = torch.zeros(1, 1, 3, 3, device=device)
1318-
key2_tensor = torch.ones(1, 1, 3, 3, device=device)
1324+
key1_tensor = torch.randn(1, 1, 3, 3, device=device)
1325+
key2_tensor = torch.randn(1, 1, 3, 3, device=device)
13191326
key_tensors = [key1_tensor, key2_tensor]
13201327
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
13211328
cat_frames = CatFrames(N=N, in_keys=keys)
13221329

13231330
cat_frames(td)
1324-
buffer_length1 = len(cat_frames.buffer)
1331+
buffer = getattr(cat_frames, f"_cat_buffers_{key1}")
1332+
13251333
passed_back_td = cat_frames.reset(td)
13261334

1327-
assert buffer_length1 == 2
13281335
assert td is passed_back_td
1329-
assert 0 == len(cat_frames.buffer)
1336+
assert (0 == buffer).all()
1337+
1338+
_ = cat_frames._call(td)
1339+
assert (0 == buffer[..., :-1, :, :]).all()
1340+
assert (0 != buffer[..., -1:, :, :]).all()
13301341

13311342
@pytest.mark.parametrize("device", get_available_devices())
13321343
def test_finitetensordictcheck(self, device):

torchrl/envs/transforms/transforms.py

Lines changed: 84 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,27 +1521,107 @@ class CatFrames(ObservationTransform):
15211521
cat_dim (int, optional): dimension along which concatenate the
15221522
observations. Default is `cat_dim=-3`.
15231523
in_keys (list of int, optional): keys pointing to the frames that have
1524-
to be concatenated.
1524+
to be concatenated. Defaults to ["pixels"].
1525+
out_keys (list of int, optional): keys pointing to where the output
1526+
has to be written. Defaults to the value of `in_keys`.
15251527
15261528
"""
15271529

15281530
inplace = False
1531+
_CAT_DIM_ERR = (
1532+
"cat_dim must be > 0 to accomodate for tensordict of "
1533+
"different batch-sizes (since negative dims are batch invariant)."
1534+
)
15291535

15301536
def __init__(
15311537
self,
15321538
N: int = 4,
15331539
cat_dim: int = -3,
15341540
in_keys: Optional[Sequence[str]] = None,
1541+
out_keys: Optional[Sequence[str]] = None,
15351542
):
15361543
if in_keys is None:
15371544
in_keys = IMAGE_KEYS
1538-
super().__init__(in_keys=in_keys)
1545+
super().__init__(in_keys=in_keys, out_keys=out_keys)
15391546
self.N = N
1547+
if cat_dim > 0:
1548+
raise ValueError(self._CAT_DIM_ERR)
15401549
self.cat_dim = cat_dim
1541-
self.buffer = []
15421550

15431551
def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
1544-
self.buffer = []
1552+
"""Resets _buffers."""
1553+
# Non-batched environments
1554+
if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
1555+
for in_key in self.in_keys:
1556+
buffer_name = f"_cat_buffers_{in_key}"
1557+
try:
1558+
buffer = getattr(self, buffer_name)
1559+
buffer.fill_(0.0)
1560+
except AttributeError:
1561+
# we'll instantiate later, when needed
1562+
pass
1563+
1564+
# Batched environments
1565+
else:
1566+
_reset = tensordict.get(
1567+
"_reset",
1568+
torch.ones(
1569+
tensordict.batch_size,
1570+
dtype=torch.bool,
1571+
device=tensordict.device,
1572+
),
1573+
)
1574+
for in_key in self.in_keys:
1575+
buffer_name = f"_cat_buffers_{in_key}"
1576+
try:
1577+
buffer = getattr(self, buffer_name)
1578+
buffer[_reset] = 0.0
1579+
except AttributeError:
1580+
# we'll instantiate later, when needed
1581+
pass
1582+
1583+
return tensordict
1584+
1585+
def _make_missing_buffer(self, data, buffer_name):
1586+
shape = list(data.shape)
1587+
d = shape[self.cat_dim]
1588+
shape[self.cat_dim] = d * self.N
1589+
shape = torch.Size(shape)
1590+
self.register_buffer(
1591+
buffer_name,
1592+
torch.zeros(
1593+
shape,
1594+
dtype=data.dtype,
1595+
device=data.device,
1596+
),
1597+
)
1598+
buffer = getattr(self, buffer_name)
1599+
return buffer
1600+
1601+
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
1602+
"""Update the episode tensordict with max pooled keys."""
1603+
for in_key, out_key in zip(self.in_keys, self.out_keys):
1604+
# Lazy init of buffers
1605+
buffer_name = f"_cat_buffers_{in_key}"
1606+
data = tensordict[in_key]
1607+
d = data.size(self.cat_dim)
1608+
try:
1609+
buffer = getattr(self, buffer_name)
1610+
# shift obs 1 position to the right
1611+
buffer.copy_(torch.roll(buffer, shifts=-d, dims=self.cat_dim))
1612+
except AttributeError:
1613+
buffer = self._make_missing_buffer(data, buffer_name)
1614+
# add new obs
1615+
idx = self.cat_dim
1616+
if idx < 0:
1617+
idx = buffer.ndimension() + idx
1618+
else:
1619+
raise ValueError(self._CAT_DIM_ERR)
1620+
idx = [slice(None, None) for _ in range(idx)] + [slice(-d, None)]
1621+
buffer[idx].copy_(data)
1622+
# add to tensordict
1623+
tensordict.set(out_key, buffer.clone())
1624+
15451625
return tensordict
15461626

15471627
@_apply_to_composite
@@ -1557,17 +1637,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
15571637
observation_spec.shape = torch.Size(shape)
15581638
return observation_spec
15591639

1560-
def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
1561-
self.buffer.append(obs)
1562-
self.buffer = self.buffer[-self.N :]
1563-
buffer = list(reversed(self.buffer))
1564-
buffer = [buffer[0]] * (self.N - len(buffer)) + buffer
1565-
if len(buffer) != self.N:
1566-
raise RuntimeError(
1567-
f"actual buffer length ({buffer}) differs from expected (" f"{self.N})"
1568-
)
1569-
return torch.cat(buffer, self.cat_dim)
1570-
15711640
def __repr__(self) -> str:
15721641
return (
15731642
f"{self.__class__.__name__}(N={self.N}, cat_dim"

0 commit comments

Comments
 (0)