Skip to content

Commit 8d13206

Browse files
committed
Update
[ghstack-poisoned]
2 parents ef90882 + 5bb3334 commit 8d13206

File tree

7 files changed

+90
-45
lines changed

7 files changed

+90
-45
lines changed

test/test_env.py

+28
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,34 @@ def test_parallel_env_device(
16921692
env_serial.close(raise_if_closed=False)
16931693
env0.close(raise_if_closed=False)
16941694

1695+
@pytest.mark.skipif(not _has_gym, reason="no gym")
1696+
@pytest.mark.parametrize("env_device", [None, "cpu"])
1697+
def test_parallel_env_device_vs_no_device(self, maybe_fork_ParallelEnv, env_device):
1698+
def make_env() -> GymEnv:
1699+
env = GymEnv(PENDULUM_VERSIONED(), device=env_device)
1700+
return env.append_transform(DoubleToFloat())
1701+
1702+
# Rollouts work with a regular env
1703+
parallel_env = maybe_fork_ParallelEnv(
1704+
num_workers=1, create_env_fn=make_env, device=None
1705+
)
1706+
parallel_env.reset()
1707+
parallel_env.set_seed(0)
1708+
torch.manual_seed(0)
1709+
1710+
parallel_rollout = parallel_env.rollout(max_steps=10)
1711+
1712+
# Rollout doesn't work with Parallelnv
1713+
parallel_env = maybe_fork_ParallelEnv(
1714+
num_workers=1, create_env_fn=make_env, device="cpu"
1715+
)
1716+
parallel_env.reset()
1717+
parallel_env.set_seed(0)
1718+
torch.manual_seed(0)
1719+
1720+
parallel_rollout_cpu = parallel_env.rollout(max_steps=10)
1721+
assert_allclose_td(parallel_rollout, parallel_rollout_cpu)
1722+
16951723
@pytest.mark.skipif(not _has_gym, reason="no gym")
16961724
@pytest.mark.flaky(reruns=3, reruns_delay=1)
16971725
@pytest.mark.parametrize(

test/test_storage_map.py

+11
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,17 @@ def test_edges(self):
350350
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
351351
assert edges == edges_check
352352

353+
def test_make_node(self):
354+
td = TensorDict({"obs": torch.tensor([0])})
355+
tree = Tree(node_data=td)
356+
assert tree.node_data is not None
357+
358+
tree = Tree.make_node(data=td)
359+
assert tree.node_data is not None
360+
361+
tree = Tree.make_node(td)
362+
assert tree.node_data is not None
363+
353364

354365
class TestMCTSForest:
355366
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:

torchrl/_utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import warnings
1919
from contextlib import nullcontext
2020
from copy import copy
21-
from distutils.util import strtobool
2221
from functools import wraps
2322
from importlib import import_module
2423
from typing import Any, Callable, cast, TypeVar
@@ -35,6 +34,21 @@
3534
except ImportError:
3635
from torch._dynamo import is_compiling
3736

37+
38+
def strtobool(val: Any) -> bool:
39+
"""Convert a string representation of truth to a boolean.
40+
41+
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
42+
Raises ValueError if 'val' is anything else.
43+
"""
44+
val = val.lower()
45+
if val in ("y", "yes", "t", "true", "on", "1"):
46+
return True
47+
if val in ("n", "no", "f", "false", "off", "0"):
48+
return False
49+
raise ValueError(f"Invalid truth value {val!r}")
50+
51+
3852
LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
3953
logger = logging.getLogger("torchrl")
4054
logger.setLevel(getattr(logging, LOGGING_LEVEL))

torchrl/data/map/tree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def make_node(
122122
return cls(
123123
count=torch.zeros(()),
124124
wins=torch.zeros(()),
125-
node=data.exclude("action", "next"),
125+
node_data=data.exclude("action", "next"),
126126
rollout=rollout,
127127
subtree=subtree,
128128
device=device,

torchrl/envs/batched_envs.py

+19-35
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,14 @@ def __init__(
379379

380380
is_spec_locked = EnvBase.is_spec_locked
381381

382+
def select_and_clone(self, name, tensor, selected_keys=None):
383+
if selected_keys is None:
384+
selected_keys = self._selected_step_keys
385+
if name in selected_keys:
386+
if self.device is not None and tensor.device != self.device:
387+
return tensor.to(self.device, non_blocking=self.non_blocking)
388+
return tensor.clone()
389+
382390
@property
383391
def non_blocking(self):
384392
nb = self._non_blocking
@@ -1072,12 +1080,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10721080
selected_output_keys = self._selected_reset_keys_filt
10731081

10741082
# select + clone creates 2 tds, but we can create one only
1075-
def select_and_clone(name, tensor):
1076-
if name in selected_output_keys:
1077-
return tensor.clone()
1078-
10791083
out = self.shared_tensordict_parent.named_apply(
1080-
select_and_clone,
1084+
lambda *args: self.select_and_clone(
1085+
*args, selected_keys=selected_output_keys
1086+
),
10811087
nested_keys=True,
10821088
filter_empty=True,
10831089
)
@@ -1150,14 +1156,14 @@ def _step(
11501156
# will be modified in-place at further steps
11511157
device = self.device
11521158

1153-
def select_and_clone(name, tensor):
1154-
if name in self._selected_step_keys:
1155-
return tensor.clone()
1159+
selected_keys = self._selected_step_keys
11561160

11571161
if partial_steps is not None:
11581162
next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
11591163
out = next_td.named_apply(
1160-
select_and_clone, nested_keys=True, filter_empty=True
1164+
lambda *args: self.select_and_clone(*args, selected_keys),
1165+
nested_keys=True,
1166+
filter_empty=True,
11611167
)
11621168
if out_tds is not None:
11631169
out.update(
@@ -2010,20 +2016,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
20102016
next_td = shared_tensordict_parent.get("next")
20112017
device = self.device
20122018

2013-
if next_td.device != device and device is not None:
2014-
2015-
def select_and_clone(name, tensor):
2016-
if name in self._selected_step_keys:
2017-
return tensor.to(device, non_blocking=self.non_blocking)
2018-
2019-
else:
2020-
2021-
def select_and_clone(name, tensor):
2022-
if name in self._selected_step_keys:
2023-
return tensor.clone()
2024-
20252019
out = next_td.named_apply(
2026-
select_and_clone,
2020+
self.select_and_clone,
20272021
nested_keys=True,
20282022
filter_empty=True,
20292023
device=device,
@@ -2203,20 +2197,10 @@ def tentative_update(val, other):
22032197
selected_output_keys = self._selected_reset_keys_filt
22042198
device = self.device
22052199

2206-
if self.shared_tensordict_parent.device != device and device is not None:
2207-
2208-
def select_and_clone(name, tensor):
2209-
if name in selected_output_keys:
2210-
return tensor.to(device, non_blocking=self.non_blocking)
2211-
2212-
else:
2213-
2214-
def select_and_clone(name, tensor):
2215-
if name in selected_output_keys:
2216-
return tensor.clone()
2217-
22182200
out = self.shared_tensordict_parent.named_apply(
2219-
select_and_clone,
2201+
lambda *args: self.select_and_clone(
2202+
*args, selected_keys=selected_output_keys
2203+
),
22202204
nested_keys=True,
22212205
filter_empty=True,
22222206
device=device,

torchrl/envs/custom/llm.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,17 @@ def __init__(
133133
self.vocab_size = vocab_size
134134
self.token_key = unravel_key(token_key)
135135
self.str_key = unravel_key(str_key)
136-
self.attention_key = unravel_key(attention_key)
136+
if attention_key is not None:
137+
attention_key = unravel_key(attention_key)
138+
self.attention_key = attention_key
137139
self.no_stack = no_stack
138140
self.assign_reward = assign_reward
139141
self.assign_done = assign_done
140142

141143
# self.action_key = unravel_key(action_key)
142144
if str2str:
143145
self.full_observation_spec_unbatched = Composite(
144-
{
145-
self.str_key: NonTensor(
146-
example_data="a string", batched=True, shape=()
147-
)
148-
}
146+
{self.str_key: NonTensor(example_data="a string", batched=True, shape=())}
149147
)
150148
self.full_action_spec_unbatched = Composite(
151149
{action_key: NonTensor(example_data="a string", batched=True, shape=())}

torchrl/objectives/value/advantages.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1281,8 +1281,18 @@ def __init__(
12811281
skip_existing=skip_existing,
12821282
device=device,
12831283
)
1284-
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1285-
self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
1284+
self.register_buffer(
1285+
"gamma",
1286+
gamma.to(self._device)
1287+
if isinstance(gamma, Tensor)
1288+
else torch.tensor(gamma, device=self._device),
1289+
)
1290+
self.register_buffer(
1291+
"lmbda",
1292+
lmbda.to(self._device)
1293+
if isinstance(lmbda, Tensor)
1294+
else torch.tensor(lmbda, device=self._device),
1295+
)
12861296
self.average_gae = average_gae
12871297
self.vectorized = vectorized
12881298
self.time_dim = time_dim

0 commit comments

Comments
 (0)