@@ -1547,19 +1547,26 @@ def __init__(
1547
1547
if cat_dim > 0 :
1548
1548
raise ValueError (self ._CAT_DIM_ERR )
1549
1549
self .cat_dim = cat_dim
1550
+ for in_key in self .in_keys :
1551
+ buffer_name = f"_cat_buffers_{ in_key } "
1552
+ setattr (
1553
+ self ,
1554
+ buffer_name ,
1555
+ torch .nn .parameter .UninitializedBuffer (
1556
+ device = torch .device ("cpu" ), dtype = torch .get_default_dtype ()
1557
+ ),
1558
+ )
1550
1559
1551
1560
def reset (self , tensordict : TensorDictBase ) -> TensorDictBase :
1552
1561
"""Resets _buffers."""
1553
1562
# Non-batched environments
1554
1563
if len (tensordict .batch_size ) < 1 or tensordict .batch_size [0 ] == 1 :
1555
1564
for in_key in self .in_keys :
1556
1565
buffer_name = f"_cat_buffers_{ in_key } "
1557
- try :
1558
- buffer = getattr (self , buffer_name )
1559
- buffer .fill_ (0.0 )
1560
- except AttributeError :
1561
- # we'll instantiate later, when needed
1562
- pass
1566
+ buffer = getattr (self , buffer_name )
1567
+ if isinstance (buffer , torch .nn .parameter .UninitializedBuffer ):
1568
+ continue
1569
+ buffer .fill_ (0.0 )
1563
1570
1564
1571
# Batched environments
1565
1572
else :
@@ -1573,12 +1580,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
1573
1580
)
1574
1581
for in_key in self .in_keys :
1575
1582
buffer_name = f"_cat_buffers_{ in_key } "
1576
- try :
1577
- buffer = getattr (self , buffer_name )
1578
- buffer [_reset ] = 0.0
1579
- except AttributeError :
1580
- # we'll instantiate later, when needed
1581
- pass
1583
+ buffer = getattr (self , buffer_name )
1584
+ if isinstance (buffer , torch .nn .parameter .UninitializedBuffer ):
1585
+ continue
1586
+ buffer [_reset ] = 0.0
1582
1587
1583
1588
return tensordict
1584
1589
@@ -1587,15 +1592,9 @@ def _make_missing_buffer(self, data, buffer_name):
1587
1592
d = shape [self .cat_dim ]
1588
1593
shape [self .cat_dim ] = d * self .N
1589
1594
shape = torch .Size (shape )
1590
- self .register_buffer (
1591
- buffer_name ,
1592
- torch .zeros (
1593
- shape ,
1594
- dtype = data .dtype ,
1595
- device = data .device ,
1596
- ),
1597
- )
1598
- buffer = getattr (self , buffer_name )
1595
+ getattr (self , buffer_name ).materialize (shape )
1596
+ buffer = getattr (self , buffer_name ).to (data .dtype ).to (data .device ).zero_ ()
1597
+ setattr (self , buffer_name , buffer )
1599
1598
return buffer
1600
1599
1601
1600
def _call (self , tensordict : TensorDictBase ) -> TensorDictBase :
@@ -1605,12 +1604,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
1605
1604
buffer_name = f"_cat_buffers_{ in_key } "
1606
1605
data = tensordict [in_key ]
1607
1606
d = data .size (self .cat_dim )
1608
- try :
1609
- buffer = getattr (self , buffer_name )
1607
+ buffer = getattr (self , buffer_name )
1608
+ if isinstance (buffer , torch .nn .parameter .UninitializedBuffer ):
1609
+ buffer = self ._make_missing_buffer (data , buffer_name )
1610
+ else :
1610
1611
# shift obs 1 position to the right
1611
1612
buffer .copy_ (torch .roll (buffer , shifts = - d , dims = self .cat_dim ))
1612
- except AttributeError :
1613
- buffer = self ._make_missing_buffer (data , buffer_name )
1614
1613
# add new obs
1615
1614
idx = self .cat_dim
1616
1615
if idx < 0 :
0 commit comments