Skip to content

Commit 5ed32a9

Browse files
authored
Remove current_tensordict references (#229)
1 parent ff872a2 commit 5ed32a9

File tree

18 files changed

882 additions and 989 deletions (lines changed)

18 files changed

882 additions and 989 deletions (lines changed)

test/mocking_classes.py

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -129,9 +129,14 @@ class DiscreteActionVecMockEnv(_MockEnv):
129129
)
130130
action_spec = OneHotDiscreteTensorSpec(7)
131131
reward_spec = UnboundedContinuousTensorSpec()
132+
132133
from_pixels = False
133134

134135
out_key = "observation"
136+
_out_key = "observation_orig"
137+
input_spec = CompositeSpec(
138+
**{_out_key: observation_spec["next_observation"], "action": action_spec}
139+
)
135140

136141
def _get_in_obs(self, obs):
137142
return obs
@@ -145,6 +150,7 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
145150
tensordict = tensordict.select().set(
146151
"next_" + self.out_key, self._get_out_obs(state)
147152
)
153+
tensordict = tensordict.set("next_" + self._out_key, self._get_out_obs(state))
148154
tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
149155
return tensordict
150156

@@ -157,12 +163,12 @@ def _step(
157163
assert (a.sum(-1) == 1).all()
158164
assert not self.is_done, "trying to execute step in done env"
159165

160-
obs = (
161-
self._get_in_obs(self.current_tensordict.get(self.out_key))
162-
+ a / self.maxstep
163-
)
166+
obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
164167
tensordict = tensordict.select() # empty tensordict
168+
165169
tensordict.set("next_" + self.out_key, self._get_out_obs(obs))
170+
tensordict.set("next_" + self._out_key, self._get_out_obs(obs))
171+
166172
done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
167173
reward = done.any(-1).unsqueeze(-1)
168174
# set done to False
@@ -182,6 +188,10 @@ class ContinuousActionVecMockEnv(_MockEnv):
182188
from_pixels = False
183189

184190
out_key = "observation"
191+
_out_key = "observation_orig"
192+
input_spec = CompositeSpec(
193+
**{_out_key: observation_spec["next_observation"], "action": action_spec}
194+
)
185195

186196
def _get_in_obs(self, obs):
187197
return obs
@@ -193,9 +203,9 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
193203
self.counter += 1
194204
self.step_count = 0
195205
state = torch.zeros(self.size) + self.counter
196-
tensordict = tensordict.select().set(
197-
"next_" + self.out_key, self._get_out_obs(state)
198-
)
206+
tensordict = tensordict.select()
207+
tensordict.set("next_" + self.out_key, self._get_out_obs(state))
208+
tensordict.set("next_" + self._out_key, self._get_out_obs(state))
199209
tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
200210
return tensordict
201211

@@ -208,11 +218,12 @@ def _step(
208218
a = tensordict.get("action")
209219
assert not self.is_done, "trying to execute step in done env"
210220

211-
obs = self._obs_step(
212-
self._get_in_obs(self.current_tensordict.get(self.out_key)), a
213-
)
221+
obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
214222
tensordict = tensordict.select() # empty tensordict
223+
215224
tensordict.set("next_" + self.out_key, self._get_out_obs(obs))
225+
tensordict.set("next_" + self._out_key, self._get_out_obs(obs))
226+
216227
done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
217228
reward = done.any(-1).unsqueeze(-1)
218229
done = done.all(-1).unsqueeze(-1)
@@ -251,6 +262,10 @@ class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv):
251262
from_pixels = True
252263

253264
out_key = "pixels"
265+
_out_key = "pixels_orig"
266+
input_spec = CompositeSpec(
267+
**{_out_key: observation_spec["next_pixels"], "action": action_spec}
268+
)
254269

255270
def _get_out_obs(self, obs):
256271
obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0)
@@ -287,6 +302,10 @@ class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv):
287302
from_pixels = True
288303

289304
out_key = "pixels"
305+
_out_key = "pixels_orig"
306+
input_spec = CompositeSpec(
307+
**{_out_key: observation_spec["next_pixels"], "action": action_spec}
308+
)
290309

291310
def _get_out_obs(self, obs):
292311
obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0)

test/test_collector.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ def env_fn():
274274
ccollector.shutdown()
275275

276276

277-
@pytest.mark.parametrize("num_env", [3, 1])
277+
@pytest.mark.parametrize("num_env", [1, 3])
278278
@pytest.mark.parametrize("env_name", ["conv", "vec"])
279279
def test_collector_consistency(num_env, env_name, seed=100):
280280
if num_env == 1:
@@ -320,9 +320,9 @@ def env_fn(seed):
320320
device="cpu",
321321
pin_memory=False,
322322
)
323-
collector = iter(collector)
324-
b1 = next(collector)
325-
b2 = next(collector)
323+
collector_iter = iter(collector)
324+
b1 = next(collector_iter)
325+
b2 = next(collector_iter)
326326
with pytest.raises(AssertionError):
327327
assert_allclose_td(b1, b2)
328328

@@ -334,6 +334,7 @@ def env_fn(seed):
334334
), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}"
335335

336336
assert_allclose_td(rollout1a, b1.select(*rollout1a.keys()))
337+
collector.shutdown()
337338

338339

339340
@pytest.mark.parametrize("num_env", [1, 3])

0 commit comments

Comments
 (0)