Skip to content

Commit f078fcd

Browse files
Vincent Moens (vmoens)
authored and committed
[BugFix] Fix offline CatFrames for pixels (#1964)
1 parent 8cb1ee1 commit f078fcd

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

test/test_transforms.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -967,17 +967,21 @@ def test_transform_no_env(self, device, d, batch_size, dim, N):
967967

968968
@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
969969
@pytest.mark.parametrize("padding", ["zeros", "constant", "same"])
970-
def test_tranform_offline_against_online(self, padding):
970+
@pytest.mark.parametrize("envtype", ["gym", "conv"])
971+
def test_tranform_offline_against_online(self, padding, envtype):
971972
torch.manual_seed(0)
973+
key = "observation" if envtype == "gym" else "pixels"
972974
env = SerialEnv(
973975
3,
974976
lambda: TransformedEnv(
975-
GymEnv("CartPole-v1"),
977+
GymEnv("CartPole-v1")
978+
if envtype == "gym"
979+
else DiscreteActionConvMockEnv(),
976980
CatFrames(
977-
dim=-1,
981+
dim=-3 if envtype == "conv" else -1,
978982
N=5,
979-
in_keys=["observation"],
980-
out_keys=["observation_cat"],
983+
in_keys=[key],
984+
out_keys=[f"{key}_cat"],
981985
padding=padding,
982986
),
983987
),
@@ -987,19 +991,17 @@ def test_tranform_offline_against_online(self, padding):
987991
r = env.rollout(100, break_when_any_done=False)
988992

989993
c = CatFrames(
990-
dim=-1,
994+
dim=-3 if envtype == "conv" else -1,
991995
N=5,
992-
in_keys=["observation", ("next", "observation")],
993-
out_keys=["observation_cat2", ("next", "observation_cat2")],
996+
in_keys=[key, ("next", key)],
997+
out_keys=[f"{key}_cat2", ("next", f"{key}_cat2")],
994998
padding=padding,
995999
)
9961000

9971001
r2 = c(r)
9981002

999-
torch.testing.assert_close(r2["observation_cat2"], r2["observation_cat"])
1000-
assert (r2["observation_cat2"] == r2["observation_cat"]).all()
1001-
1002-
assert (r2["next", "observation_cat2"] == r2["next", "observation_cat"]).all()
1003+
torch.testing.assert_close(r2[f"{key}_cat2"], r2[f"{key}_cat"])
1004+
torch.testing.assert_close(r2["next", f"{key}_cat2"], r2["next", f"{key}_cat"])
10031005

10041006
@pytest.mark.parametrize("device", get_default_devices())
10051007
@pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)])

torchrl/envs/transforms/transforms.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
)
3030
from tensordict._tensordict import _unravel_key_to_tuple
3131
from tensordict.nn import dispatch, TensorDictModuleBase
32-
from tensordict.utils import expand_as_right, NestedKey
32+
from tensordict.utils import expand_as_right, expand_right, NestedKey
3333
from torch import nn, Tensor
3434
from torch.utils._pytree import tree_map
3535
from torchrl._utils import _replace_last
@@ -2978,7 +2978,12 @@ def unfold_done(done, N):
29782978
data = data.unfold(tensordict.ndim - 1, self.N, 1)
29792979

29802980
# Place -1 dim at self.dim place before squashing
2981-
done_mask_expand = expand_as_right(done_mask, data)
2981+
done_mask_expand = done_mask.view(
2982+
*done_mask.shape[: tensordict.ndim],
2983+
*(1,) * (data.ndim - 1 - tensordict.ndim),
2984+
done_mask.shape[-1],
2985+
)
2986+
done_mask_expand = expand_as_right(done_mask_expand, data)
29822987
data = data.permute(
29832988
*range(0, data.ndim + self.dim - 1),
29842989
-1,
@@ -2994,11 +2999,13 @@ def unfold_done(done, N):
29942999
else:
29953000
# TODO: This is a pretty bad implementation, could be
29963001
# made more efficient but it works!
2997-
reset_vals = list(data_orig[reset.squeeze(-1)].unbind(0))
3002+
reset_any = reset.any(-1, False)
3003+
reset_vals = list(data_orig[reset_any].unbind(0))
29983004
j_ = float("inf")
29993005
reps = []
30003006
d = data.ndim + self.dim - 1
3001-
for j in done_mask_expand.sum(d).sum(d).view(-1) // n_feat:
3007+
n_feat = data.shape[data.ndim + self.dim :].numel()
3008+
for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
30023009
if j > j_:
30033010
reset_vals = reset_vals[1:]
30043011
reps.extend([reset_vals[0]] * int(j))
@@ -3008,8 +3015,10 @@ def unfold_done(done, N):
30083015

30093016
if first_val is not None:
30103017
# Aggregate reset along last dim
3011-
reset = reset.any(-1, True)
3012-
rexp = reset.expand(*reset.shape[:-1], n_feat)
3018+
reset_any = reset.any(-1, False)
3019+
rexp = expand_right(
3020+
reset_any, (*reset_any.shape, *data.shape[data.ndim + self.dim :])
3021+
)
30133022
rexp = torch.cat(
30143023
[
30153024
torch.zeros_like(

0 commit comments

Comments
 (0)