Skip to content

Commit 8cb1ee1

Browse files
Vincent Moens (vmoens)
authored and committed
[BugFix] Fix offline CatFrames (#1953)
1 parent 9987d92 commit 8cb1ee1

File tree

2 files changed

+183
-81
lines changed

2 files changed

+183
-81
lines changed

test/test_transforms.py

Lines changed: 51 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -762,16 +762,16 @@ def test_transform_env_clone(self):
762762
).all()
763763
assert cloned is not env.transform
764764

765-
@pytest.mark.parametrize("dim", [-2, -1])
765+
@pytest.mark.parametrize("dim", [-1])
766766
@pytest.mark.parametrize("N", [3, 4])
767-
@pytest.mark.parametrize("padding", ["same", "zeros", "constant"])
767+
@pytest.mark.parametrize("padding", ["zeros", "constant", "same"])
768768
def test_transform_model(self, dim, N, padding):
769769
# test equivalence between transforms within an env and within a rb
770770
key1 = "observation"
771771
keys = [key1]
772772
out_keys = ["out_" + key1]
773773
cat_frames = CatFrames(
774-
N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding
774+
N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding
775775
)
776776
cat_frames2 = CatFrames(
777777
N=N,
@@ -781,23 +781,22 @@ def test_transform_model(self, dim, N, padding):
781781
padding=padding,
782782
)
783783
envbase = ContinuousActionVecMockEnv()
784-
env = TransformedEnv(
785-
envbase,
786-
Compose(
787-
UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames
788-
),
789-
)
784+
env = TransformedEnv(envbase, cat_frames)
785+
790786
torch.manual_seed(10)
791787
env.set_seed(10)
792788
td = env.rollout(10)
789+
793790
torch.manual_seed(10)
794791
envbase.set_seed(10)
795792
tdbase = envbase.rollout(10)
793+
796794
tdbase0 = tdbase.clone()
797795

798796
model = nn.Sequential(cat_frames2, nn.Identity())
799797
model(tdbase)
800-
assert (td == tdbase).all()
798+
assert assert_allclose_td(td, tdbase)
799+
801800
with pytest.warns(UserWarning):
802801
tdbase0.names = None
803802
model(tdbase0)
@@ -816,7 +815,7 @@ def test_transform_model(self, dim, N, padding):
816815
# check that swapping dims and names leads to same result
817816
assert_allclose_td(v1, v2.transpose(0, 1))
818817

819-
@pytest.mark.parametrize("dim", [-2, -1])
818+
@pytest.mark.parametrize("dim", [-1])
820819
@pytest.mark.parametrize("N", [3, 4])
821820
@pytest.mark.parametrize("padding", ["same", "zeros", "constant"])
822821
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@@ -826,7 +825,7 @@ def test_transform_rb(self, dim, N, padding, rbclass):
826825
keys = [key1]
827826
out_keys = ["out_" + key1]
828827
cat_frames = CatFrames(
829-
N=N, in_keys=out_keys, out_keys=out_keys, dim=dim, padding=padding
828+
N=N, in_keys=keys, out_keys=out_keys, dim=dim, padding=padding
830829
)
831830
cat_frames2 = CatFrames(
832831
N=N,
@@ -836,12 +835,7 @@ def test_transform_rb(self, dim, N, padding, rbclass):
836835
padding=padding,
837836
)
838837

839-
env = TransformedEnv(
840-
ContinuousActionVecMockEnv(),
841-
Compose(
842-
UnsqueezeTransform(dim, in_keys=keys, out_keys=out_keys), cat_frames
843-
),
844-
)
838+
env = TransformedEnv(ContinuousActionVecMockEnv(), cat_frames)
845839
td = env.rollout(10)
846840

847841
rb = rbclass(storage=LazyTensorStorage(20))
@@ -875,8 +869,8 @@ def test_transform_as_inverse(self, dim, N, padding):
875869
td = env1.rollout(rollout_length)
876870

877871
transformed_td = cat_frames._inv_call(td)
878-
assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim, N)
879-
assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim, N)
872+
assert transformed_td.get(in_keys[0]).shape == (rollout_length, obs_dim * N)
873+
assert transformed_td.get(in_keys[1]).shape == (rollout_length, obs_dim * N)
880874
with pytest.raises(
881875
Exception,
882876
match="CatFrames as inverse is not supported as a transform for environments, only for replay buffers.",
@@ -971,14 +965,48 @@ def test_transform_no_env(self, device, d, batch_size, dim, N):
971965
# we don't want the same tensor to be returned twice, but they're all copies of the same buffer
972966
assert v1 is not v2
973967

968+
@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
969+
@pytest.mark.parametrize("padding", ["zeros", "constant", "same"])
970+
def test_tranform_offline_against_online(self, padding):
971+
torch.manual_seed(0)
972+
env = SerialEnv(
973+
3,
974+
lambda: TransformedEnv(
975+
GymEnv("CartPole-v1"),
976+
CatFrames(
977+
dim=-1,
978+
N=5,
979+
in_keys=["observation"],
980+
out_keys=["observation_cat"],
981+
padding=padding,
982+
),
983+
),
984+
)
985+
env.set_seed(0)
986+
987+
r = env.rollout(100, break_when_any_done=False)
988+
989+
c = CatFrames(
990+
dim=-1,
991+
N=5,
992+
in_keys=["observation", ("next", "observation")],
993+
out_keys=["observation_cat2", ("next", "observation_cat2")],
994+
padding=padding,
995+
)
996+
997+
r2 = c(r)
998+
999+
torch.testing.assert_close(r2["observation_cat2"], r2["observation_cat"])
1000+
assert (r2["observation_cat2"] == r2["observation_cat"]).all()
1001+
1002+
assert (r2["next", "observation_cat2"] == r2["next", "observation_cat"]).all()
1003+
9741004
@pytest.mark.parametrize("device", get_default_devices())
9751005
@pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)])
9761006
@pytest.mark.parametrize("d", range(2, 3))
9771007
@pytest.mark.parametrize(
9781008
"dim",
979-
[
980-
-3,
981-
],
1009+
[-3],
9821010
)
9831011
@pytest.mark.parametrize("N", [2, 4])
9841012
def test_transform_compose(self, device, d, batch_size, dim, N):

torchrl/envs/transforms/transforms.py

Lines changed: 132 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -2620,6 +2620,8 @@ class CatFrames(ObservationTransform):
26202620
reset indicator. Must be unique. If not provided, defaults to the
26212621
only reset key of the parent environment (if it has only one)
26222622
and raises an exception otherwise.
2623+
done_key (NestedKey, optional): the done key to be used as partial
2624+
done indicator. Must be unique. If not provided, defaults to ``"done"``.
26232625
26242626
Examples:
26252627
>>> from torchrl.envs.libs.gym import GymEnv
@@ -2700,6 +2702,7 @@ def __init__(
27002702
padding_value=0,
27012703
as_inverse=False,
27022704
reset_key: NestedKey | None = None,
2705+
done_key: NestedKey | None = None,
27032706
):
27042707
if in_keys is None:
27052708
in_keys = IMAGE_KEYS
@@ -2733,6 +2736,19 @@ def __init__(
27332736
# keeps track of calls to _reset since it's only _call that will populate the buffer
27342737
self.as_inverse = as_inverse
27352738
self.reset_key = reset_key
2739+
self.done_key = done_key
2740+
2741+
@property
2742+
def done_key(self):
2743+
done_key = self.__dict__.get("_done_key", None)
2744+
if done_key is None:
2745+
done_key = "done"
2746+
self._done_key = done_key
2747+
return done_key
2748+
2749+
@done_key.setter
2750+
def done_key(self, value):
2751+
self._done_key = value
27362752

27372753
@property
27382754
def reset_key(self):
@@ -2829,15 +2845,6 @@ def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
28292845
# make linter happy. An exception has already been raised
28302846
raise NotImplementedError
28312847

2832-
# # this duplicates the code below, but only for _reset values
2833-
# if _all:
2834-
# buffer.copy_(torch.roll(buffer_reset, shifts=-d, dims=dim))
2835-
# buffer_reset = buffer
2836-
# else:
2837-
# buffer_reset = buffer[_reset] = torch.roll(
2838-
# buffer_reset, shifts=-d, dims=dim
2839-
# )
2840-
# add new obs
28412848
if self.dim < 0:
28422849
n = buffer_reset.ndimension() + self.dim
28432850
else:
@@ -2906,69 +2913,136 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
29062913
if i != tensordict.ndim - 1:
29072914
tensordict = tensordict.transpose(tensordict.ndim - 1, i)
29082915
# first sort the in_keys with strings and non-strings
2909-
in_keys = list(
2910-
zip(
2911-
(in_key, out_key)
2912-
for in_key, out_key in zip(self.in_keys, self.out_keys)
2913-
if isinstance(in_key, str) or len(in_key) == 1
2914-
)
2915-
)
2916-
in_keys += list(
2917-
zip(
2918-
(in_key, out_key)
2919-
for in_key, out_key in zip(self.in_keys, self.out_keys)
2920-
if not isinstance(in_key, str) and not len(in_key) == 1
2916+
keys = [
2917+
(in_key, out_key)
2918+
for in_key, out_key in zip(self.in_keys, self.out_keys)
2919+
if isinstance(in_key, str)
2920+
]
2921+
keys += [
2922+
(in_key, out_key)
2923+
for in_key, out_key in zip(self.in_keys, self.out_keys)
2924+
if not isinstance(in_key, str)
2925+
]
2926+
2927+
def unfold_done(done, N):
2928+
prefix = (slice(None),) * (tensordict.ndim - 1)
2929+
reset = torch.cat(
2930+
[
2931+
torch.zeros_like(done[prefix + (slice(self.N - 1),)]),
2932+
torch.ones_like(done[prefix + (slice(1),)]),
2933+
done[prefix + (slice(None, -1),)],
2934+
],
2935+
tensordict.ndim - 1,
29212936
)
2922-
)
2923-
for in_key, out_key in zip(self.in_keys, self.out_keys):
2937+
reset_unfold = reset.unfold(tensordict.ndim - 1, self.N, 1)
2938+
reset_unfold_slice = reset_unfold[..., -1]
2939+
reset_unfold_list = [torch.zeros_like(reset_unfold_slice)]
2940+
for r in reversed(reset_unfold.unbind(-1)):
2941+
reset_unfold_list.append(r | reset_unfold_list[-1])
2942+
reset_unfold_slice = reset_unfold_list[-1]
2943+
reset_unfold = torch.stack(list(reversed(reset_unfold_list))[1:], -1)
2944+
reset = reset[prefix + (slice(self.N - 1, None),)]
2945+
reset[prefix + (0,)] = 1
2946+
return reset_unfold, reset
2947+
2948+
done = tensordict.get(("next", self.done_key))
2949+
done_mask, reset = unfold_done(done, self.N)
2950+
2951+
for in_key, out_key in keys:
29242952
# check if we have an obs in "next" that has already been processed.
29252953
# If so, we must add an offset
2926-
data = tensordict.get(in_key)
2954+
data_orig = data = tensordict.get(in_key)
2955+
n_feat = data_orig.shape[data.ndim + self.dim]
2956+
first_val = None
29272957
if isinstance(in_key, tuple) and in_key[0] == "next":
29282958
# let's get the out_key we have already processed
2929-
prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]]
2930-
prev_val = tensordict.get(prev_out_key)
2931-
# the first item is located along `dim+1` at the last index of the
2932-
# first time index
2933-
idx = (
2934-
[slice(None)] * (tensordict.ndim - 1)
2935-
+ [0]
2936-
+ [..., -1]
2937-
+ [slice(None)] * (abs(self.dim) - 1)
2959+
prev_out_key = dict(zip(self.in_keys, self.out_keys)).get(
2960+
in_key[1], None
29382961
)
2939-
first_val = prev_val[tuple(idx)].unsqueeze(tensordict.ndim - 1)
2940-
data0 = [first_val] * (self.N - 1)
2941-
if self.padding == "constant":
2942-
data0 = [
2943-
torch.full_like(elt, self.padding_value) for elt in data0[:-1]
2944-
] + data0[-1:]
2945-
elif self.padding == "same":
2946-
pass
2947-
else:
2948-
# make linter happy. An exception has already been raised
2949-
raise NotImplementedError
2950-
elif self.padding == "same":
2951-
idx = [slice(None)] * (tensordict.ndim - 1) + [0]
2952-
data0 = [data[tuple(idx)].unsqueeze(tensordict.ndim - 1)] * (self.N - 1)
2953-
elif self.padding == "constant":
2954-
idx = [slice(None)] * (tensordict.ndim - 1) + [0]
2955-
data0 = [
2956-
torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze(
2957-
tensordict.ndim - 1
2962+
if prev_out_key is not None:
2963+
prev_val = tensordict.get(prev_out_key)
2964+
# n_feat = prev_val.shape[data.ndim + self.dim] // self.N
2965+
first_val = prev_val.unflatten(
2966+
data.ndim + self.dim, (self.N, n_feat)
29582967
)
2959-
] * (self.N - 1)
2960-
else:
2961-
# make linter happy. An exception has already been raised
2962-
raise NotImplementedError
2968+
2969+
idx = [slice(None)] * (tensordict.ndim - 1) + [0]
2970+
data0 = [
2971+
torch.full_like(data[tuple(idx)], self.padding_value).unsqueeze(
2972+
tensordict.ndim - 1
2973+
)
2974+
] * (self.N - 1)
29632975

29642976
data = torch.cat(data0 + [data], tensordict.ndim - 1)
29652977

29662978
data = data.unfold(tensordict.ndim - 1, self.N, 1)
2979+
2980+
# Place -1 dim at self.dim place before squashing
2981+
done_mask_expand = expand_as_right(done_mask, data)
29672982
data = data.permute(
2968-
*range(0, data.ndim + self.dim),
2983+
*range(0, data.ndim + self.dim - 1),
2984+
-1,
2985+
*range(data.ndim + self.dim - 1, data.ndim - 1),
2986+
)
2987+
done_mask_expand = done_mask_expand.permute(
2988+
*range(0, done_mask_expand.ndim + self.dim - 1),
29692989
-1,
2970-
*range(data.ndim + self.dim, data.ndim - 1),
2990+
*range(done_mask_expand.ndim + self.dim - 1, done_mask_expand.ndim - 1),
29712991
)
2992+
if self.padding != "same":
2993+
data = torch.where(done_mask_expand, self.padding_value, data)
2994+
else:
2995+
# TODO: This is a pretty bad implementation, could be
2996+
# made more efficient but it works!
2997+
reset_vals = list(data_orig[reset.squeeze(-1)].unbind(0))
2998+
j_ = float("inf")
2999+
reps = []
3000+
d = data.ndim + self.dim - 1
3001+
for j in done_mask_expand.sum(d).sum(d).view(-1) // n_feat:
3002+
if j > j_:
3003+
reset_vals = reset_vals[1:]
3004+
reps.extend([reset_vals[0]] * int(j))
3005+
j_ = j
3006+
reps = torch.stack(reps)
3007+
data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1))
3008+
3009+
if first_val is not None:
3010+
# Aggregate reset along last dim
3011+
reset = reset.any(-1, True)
3012+
rexp = reset.expand(*reset.shape[:-1], n_feat)
3013+
rexp = torch.cat(
3014+
[
3015+
torch.zeros_like(
3016+
data0[0].repeat_interleave(
3017+
len(data0), dim=tensordict.ndim - 1
3018+
),
3019+
dtype=torch.bool,
3020+
),
3021+
rexp,
3022+
],
3023+
tensordict.ndim - 1,
3024+
)
3025+
rexp = rexp.unfold(tensordict.ndim - 1, self.N, 1)
3026+
rexp_orig = rexp
3027+
rexp = torch.cat([rexp[..., 1:], torch.zeros_like(rexp[..., -1:])], -1)
3028+
if self.padding == "same":
3029+
rexp_orig = rexp_orig.flip(-1).cumsum(-1).flip(-1).bool()
3030+
rexp = rexp.flip(-1).cumsum(-1).flip(-1).bool()
3031+
rexp_orig = torch.cat(
3032+
[torch.zeros_like(rexp_orig[..., -1:]), rexp_orig[..., 1:]], -1
3033+
)
3034+
rexp = rexp.permute(
3035+
*range(0, rexp.ndim + self.dim - 1),
3036+
-1,
3037+
*range(rexp.ndim + self.dim - 1, rexp.ndim - 1),
3038+
)
3039+
rexp_orig = rexp_orig.permute(
3040+
*range(0, rexp_orig.ndim + self.dim - 1),
3041+
-1,
3042+
*range(rexp_orig.ndim + self.dim - 1, rexp_orig.ndim - 1),
3043+
)
3044+
data[rexp] = first_val[rexp_orig]
3045+
data = data.flatten(data.ndim + self.dim - 1, data.ndim + self.dim)
29723046
tensordict.set(out_key, data)
29733047
if tensordict_orig is not tensordict:
29743048
tensordict_orig = tensordict.transpose(tensordict.ndim - 1, i)

0 commit comments

Comments
 (0)