Skip to content

Commit 8efbb26

Browse files
authored
[BugFix] Loading state_dict on uninitialized CatFrames (#855)
1 parent aa971a7 commit 8efbb26

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,19 +1547,26 @@ def __init__(
15471547
if cat_dim > 0:
15481548
raise ValueError(self._CAT_DIM_ERR)
15491549
self.cat_dim = cat_dim
1550+
for in_key in self.in_keys:
1551+
buffer_name = f"_cat_buffers_{in_key}"
1552+
setattr(
1553+
self,
1554+
buffer_name,
1555+
torch.nn.parameter.UninitializedBuffer(
1556+
device=torch.device("cpu"), dtype=torch.get_default_dtype()
1557+
),
1558+
)
15501559

15511560
def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
15521561
"""Resets _buffers."""
15531562
# Non-batched environments
15541563
if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
15551564
for in_key in self.in_keys:
15561565
buffer_name = f"_cat_buffers_{in_key}"
1557-
try:
1558-
buffer = getattr(self, buffer_name)
1559-
buffer.fill_(0.0)
1560-
except AttributeError:
1561-
# we'll instantiate later, when needed
1562-
pass
1566+
buffer = getattr(self, buffer_name)
1567+
if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
1568+
continue
1569+
buffer.fill_(0.0)
15631570

15641571
# Batched environments
15651572
else:
@@ -1573,12 +1580,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
15731580
)
15741581
for in_key in self.in_keys:
15751582
buffer_name = f"_cat_buffers_{in_key}"
1576-
try:
1577-
buffer = getattr(self, buffer_name)
1578-
buffer[_reset] = 0.0
1579-
except AttributeError:
1580-
# we'll instantiate later, when needed
1581-
pass
1583+
buffer = getattr(self, buffer_name)
1584+
if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
1585+
continue
1586+
buffer[_reset] = 0.0
15821587

15831588
return tensordict
15841589

@@ -1587,15 +1592,9 @@ def _make_missing_buffer(self, data, buffer_name):
15871592
d = shape[self.cat_dim]
15881593
shape[self.cat_dim] = d * self.N
15891594
shape = torch.Size(shape)
1590-
self.register_buffer(
1591-
buffer_name,
1592-
torch.zeros(
1593-
shape,
1594-
dtype=data.dtype,
1595-
device=data.device,
1596-
),
1597-
)
1598-
buffer = getattr(self, buffer_name)
1595+
getattr(self, buffer_name).materialize(shape)
1596+
buffer = getattr(self, buffer_name).to(data.dtype).to(data.device).zero_()
1597+
setattr(self, buffer_name, buffer)
15991598
return buffer
16001599

16011600
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -1605,12 +1604,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
16051604
buffer_name = f"_cat_buffers_{in_key}"
16061605
data = tensordict[in_key]
16071606
d = data.size(self.cat_dim)
1608-
try:
1609-
buffer = getattr(self, buffer_name)
1607+
buffer = getattr(self, buffer_name)
1608+
if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
1609+
buffer = self._make_missing_buffer(data, buffer_name)
1610+
else:
16101611
# shift obs 1 position to the right
16111612
buffer.copy_(torch.roll(buffer, shifts=-d, dims=self.cat_dim))
1612-
except AttributeError:
1613-
buffer = self._make_missing_buffer(data, buffer_name)
16141613
# add new obs
16151614
idx = self.cat_dim
16161615
if idx < 0:

0 commit comments

Comments
 (0)