Skip to content

Commit 3854ea4

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent b538c66 commit 3854ea4

File tree

2 files changed

+138
-62
lines changed

2 files changed

+138
-62
lines changed

torchrl/envs/custom/pendulum.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,20 @@ def _reset(self, tensordict):
269269
batch_size = (
270270
tensordict.batch_size if tensordict is not None else self.batch_size
271271
)
272-
if tensordict is None or tensordict.is_empty():
272+
if tensordict is None or "params" not in tensordict:
273273
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
274274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275
# parameters to get started.
276276
tensordict = self.gen_params(batch_size=batch_size, device=self.device)
277+
elif "th" in tensordict and "thdot" in tensordict:
278+
# we can hard-reset the env too
279+
return tensordict
280+
out = self._reset_random_data(
281+
tensordict.shape, batch_size, tensordict["params"]
282+
)
283+
return out
284+
285+
def _reset_random_data(self, shape, batch_size, params):
277286

278287
high_th = torch.tensor(self.DEFAULT_X, device=self.device)
279288
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
@@ -284,20 +293,20 @@ def _reset(self, tensordict):
284293
# of simulators run simultaneously. In other contexts, the initial
285294
# random state's shape will depend upon the environment batch-size instead.
286295
th = (
287-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
296+
torch.rand(shape, generator=self.rng, device=self.device)
288297
* (high_th - low_th)
289298
+ low_th
290299
)
291300
thdot = (
292-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
301+
torch.rand(shape, generator=self.rng, device=self.device)
293302
* (high_thdot - low_thdot)
294303
+ low_thdot
295304
)
296305
out = TensorDict(
297306
{
298307
"th": th,
299308
"thdot": thdot,
300-
"params": tensordict["params"],
309+
"params": params,
301310
},
302311
batch_size=batch_size,
303312
)

0 commit comments

Comments
 (0)