Skip to content

Commit 85fdbab

Browse files
authored
[BugFix] CatTensors: Prepended next_ to the out_key (#449)
1 parent 8dceee8 commit 85fdbab

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,7 +1544,7 @@ class CatTensors(Transform):
15441544
Args:
15451545
keys_in (Sequence of str): keys to be concatenated
15461546
out_key: key of the resulting tensor.
1547-
dim (int, optional): dimension along which the contenation will occur.
1547+
dim (int, optional): dimension along which the concatenation will occur.
15481548
Default is -1.
15491549
del_keys (bool, optional): if True, the input values will be deleted after
15501550
concatenation. Default is True.
@@ -1575,13 +1575,15 @@ class CatTensors(Transform):
15751575
def __init__(
15761576
self,
15771577
keys_in: Optional[Sequence[str]] = None,
1578-
out_key: str = "observation_vector",
1578+
out_key: str = "next_observation_vector",
15791579
dim: int = -1,
15801580
del_keys: bool = True,
15811581
unsqueeze_if_oor: bool = False,
15821582
):
15831583
if keys_in is None:
15841584
raise Exception("CatTensors requires keys to be non-empty")
1585+
if type(out_key) != str:
1586+
raise Exception("CatTensors requires out_key to be of type string")
15851587
super().__init__(keys_in=keys_in)
15861588
if not out_key.startswith("next_") and all(
15871589
key.startswith("next_") for key in keys_in

0 commit comments

Comments
 (0)