Skip to content

Commit ca5878f

Browse files
author
Vincent Moens
committed
[BugFix] Test and fix life cycle of env with dynamic non-tensor spec
ghstack-source-id: 77da3a6 Pull Request resolved: #2812 (cherry picked from commit b538c66)
1 parent 02e6493 commit ca5878f

File tree

9 files changed

+359
-42
lines changed

9 files changed

+359
-42
lines changed

test/mocking_classes.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import string
99
from typing import Dict, List, Optional
1010

11+
import numpy as np
12+
1113
import torch
1214
import torch.nn as nn
1315
from tensordict import TensorDict, TensorDictBase
@@ -26,6 +28,7 @@
2628
Unbounded,
2729
)
2830
from torchrl.data.utils import consolidate_spec
31+
from torchrl.envs import Transform
2932
from torchrl.envs.common import EnvBase
3033
from torchrl.envs.model_based.common import ModelBasedEnvBase
3134
from torchrl.envs.utils import (
@@ -34,7 +37,6 @@
3437
MarlGroupMapType,
3538
)
3639

37-
3840
spec_dict = {
3941
"bounded": Bounded,
4042
"one_hot": OneHot,
@@ -2268,3 +2270,108 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
22682270

22692271
def _set_seed(self, seed: Optional[int]):
22702272
...
2273+
2274+
2275+
@tensorclass()
class TC:
    """Two-field tensorclass used as the ``"tc"`` observation entry of ``EnvWithTensorClass``."""

    # String payload; EnvWithTensorClass stores a stringified integer counter here.
    field0: str
    # Tensor payload; EnvWithTensorClass stores a tensor counter here.
    field1: torch.Tensor
2279+
2280+
2281+
class EnvWithTensorClass(CountingEnv):
    """Counting env whose observations carry an extra :class:`TC` entry under ``"tc"``.

    ``field0`` holds a stringified counter and ``field1`` a tensor counter.
    Both are set to zero on reset and incremented by one at every step.
    """

    tc_cls = TC

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        batch_size = self.batch_size
        # The "tc" entry mixes a non-tensor (str) leaf and a tensor leaf.
        tc_spec = Composite(
            field0=NonTensor(example_data="an observation!", shape=batch_size),
            field1=Unbounded(shape=batch_size),
            shape=batch_size,
            data_cls=TC,
        )
        self.observation_spec["tc"] = tc_spec

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Reset the parent env and write a zeroed ``"tc"`` entry."""
        out = super()._reset(tensordict, **kwargs)
        out["tc"] = TC("0", torch.zeros(self.batch_size))
        return out

    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Step the parent env and increment both fields of the ``"tc"`` entry."""
        out = super()._step(tensordict, **kwargs)
        # Fall back to a zero-valued instance when the input carries no "tc" entry.
        previous = tensordict.get("tc", TC("0", 0))
        counter_str = previous.field0
        if counter_str is None:
            counter_str = "0"
        counter = previous.field1
        if counter is None:
            counter = torch.zeros(self.batch_size)
        out["tc"] = TC(str(int(counter_str) + 1), counter + 1)
        return out
2312+
2313+
2314+
@tensorclass
class History:
    """One history item, a ``role``/``content`` pair of strings, as stacked by ``HistoryTransform``."""

    # Who produced the item; HistoryTransform uses "system", "user" and "assistant".
    role: str
    # Payload of the item; HistoryTransform stores a stringified integer here.
    content: str
2318+
2319+
2320+
class HistoryTransform(Transform):
    """A mocking class to record history.

    Writes a ``"history"`` entry, a stack of :class:`History` items, into the
    data produced by the parent env. On reset the stack is seeded with two
    items whose contents are ``"0"`` and ``"1"``. Each step appends one item
    whose content is the previous last content plus one and whose role is
    drawn at random. Every hook asserts that the data it handles lives on
    the parent env's device.
    """

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        """Register the ``"history"`` spec on ``observation_spec`` and return it."""
        # NOTE(review): shape (-1,) presumably marks a dynamic leading
        # dimension, since the history grows by one item per step.
        defaults = {
            "role": NonTensor(
                example_data="a role!",
                shape=(-1,),
            ),
            "content": NonTensor(
                example_data="a content!",
                shape=(-1,),
            ),
        }
        observation_spec["history"] = Composite(
            defaults,
            shape=(-1,),
            data_cls=History,
        )
        assert observation_spec.device == self.parent.device
        assert observation_spec["history"].device == self.parent.device
        return observation_spec

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Seed ``tensordict_reset["history"]`` with a two-item stack."""
        assert tensordict_reset.device == self.parent.device
        tensordict_reset["history"] = torch.stack(
            [
                History(role="system", content="0"),
                History(role="user", content="1"),
            ]
        )
        assert tensordict_reset["history"].device == self.parent.device
        return tensordict_reset

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Append one new item to the incoming history and store it in ``next_tensordict``."""
        assert next_tensordict.device == self.parent.device
        history = tensordict["history"]
        local_history = History(
            role=np.random.choice(["user", "system", "assistant"]),
            content=str(int(history.content[-1]) + 1),
            device=history.device,
        )
        # Rebuild the stack from its items plus the new one; the result is one
        # item longer than the input history.
        history = torch.stack(list(history.unbind(0)) + [local_history])
        assert isinstance(history, History)
        next_tensordict["history"] = history
        assert next_tensordict["history"].device == self.parent.device, (
            next_tensordict["history"],
            self.parent.device,
        )
        return next_tensordict

test/test_env.py

Lines changed: 143 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
CatFrames,
4343
CatTensors,
4444
ChessEnv,
45+
ConditionalSkip,
4546
DoubleToFloat,
4647
EnvBase,
4748
EnvCreator,
@@ -70,6 +71,7 @@
7071
check_marl_grouping,
7172
make_composite_from_td,
7273
MarlGroupMapType,
74+
RandomPolicy,
7375
step_mdp,
7476
)
7577
from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
@@ -131,6 +133,7 @@
131133
EnvWithMetadata,
132134
HeterogeneousCountingEnv,
133135
HeterogeneousCountingEnvPolicy,
136+
HistoryTransform,
134137
MockBatchedLockedEnv,
135138
MockBatchedUnLockedEnv,
136139
MockSerialEnv,
@@ -170,6 +173,7 @@
170173
EnvWithMetadata,
171174
HeterogeneousCountingEnv,
172175
HeterogeneousCountingEnvPolicy,
176+
HistoryTransform,
173177
MockBatchedLockedEnv,
174178
MockBatchedUnLockedEnv,
175179
MockSerialEnv,
@@ -3629,8 +3633,11 @@ def test_serial(self, bwad, use_buffers):
36293633
def test_parallel(self, bwad, use_buffers):
36303634
N = 50
36313635
env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
3632-
r = env.rollout(N, break_when_any_done=bwad)
3633-
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
3636+
try:
3637+
r = env.rollout(N, break_when_any_done=bwad)
3638+
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
3639+
finally:
3640+
env.close(raise_if_closed=False)
36343641

36353642
class AddString(Transform):
36363643
def __init__(self):
@@ -3662,19 +3669,22 @@ def test_partial_reset(self, batched):
36623669
env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
36633670
else:
36643671
env = SerialEnv(2, [env0, env1])
3665-
s = env.reset()
3666-
i = 0
3667-
for i in range(10): # noqa: B007
3668-
s, s_ = env.step_and_maybe_reset(
3669-
s.set("action", torch.ones(2, 1, dtype=torch.int))
3670-
)
3671-
if s.get(("next", "done")).any():
3672-
break
3673-
s = s_
3674-
assert i == 5
3675-
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
3676-
assert s_["string"] == ["0", "6"]
3677-
assert s["next", "string"] == ["6", "6"]
3672+
try:
3673+
s = env.reset()
3674+
i = 0
3675+
for i in range(10): # noqa: B007
3676+
s, s_ = env.step_and_maybe_reset(
3677+
s.set("action", torch.ones(2, 1, dtype=torch.int))
3678+
)
3679+
if s.get(("next", "done")).any():
3680+
break
3681+
s = s_
3682+
assert i == 5
3683+
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
3684+
assert s_["string"] == ["0", "6"]
3685+
assert s["next", "string"] == ["6", "6"]
3686+
finally:
3687+
env.close(raise_if_closed=False)
36783688

36793689
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
36803690
def test_str2str_env_tokenizer(self):
@@ -4182,6 +4192,124 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
41824192
assert (td[3].get("next") != 0).any()
41834193

41844194

4195+
class TestEnvWithHistory:
    """Life-cycle tests for envs whose observations include a growing ``History`` stack.

    Covers the plain env, the env combined with a ``ConditionalSkip``
    transform, their batched (serial / parallel) variants, and data collection.
    """

    @pytest.fixture(autouse=True, scope="class")
    def set_capture(self):
        """Run every test of the class with both context managers set to ``False``."""
        with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env(
            False
        ):
            yield

    def _make_env(self, device, max_steps=10):
        """Build a ``CountingEnv`` with a ``HistoryTransform`` appended."""
        return CountingEnv(device=device, max_steps=max_steps).append_transform(
            HistoryTransform()
        )

    def _make_skipping_env(self, device, max_steps=10):
        """Build the history env with a ``ConditionalSkip`` and an outer ``StepCounter``."""
        env = self._make_env(device=device, max_steps=max_steps)
        # skip every 3 steps
        env = env.append_transform(
            ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2))
        )
        env = TransformedEnv(env, StepCounter())
        return env

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_env_history_base(self, device):
        env = self._make_env(device)
        env.check_env_specs()

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_skipping_history_env(self, device):
        env = self._make_skipping_env(device)
        env.check_env_specs()
        # Only checks that a long rollout runs; the result is not inspected.
        env.rollout(100)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_env_history_base_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            assert not env._use_buffers
            env.check_env_specs(break_when_any_done="both")
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_skipping_history_env_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_skipping_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_env_history_base_collector(self, device_env, collector_cls):
        env = self._make_env(device_env)
        collector = collector_cls(
            env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
        )
        try:
            for d in collector:
                # Within a batch, the root history of frame i + 1 must start
                # with the same content as the "next" history of frame i.
                for i in range(d.shape[0] - 1):
                    assert (
                        d[i + 1]["history"].content[0]
                        == d[i]["next", "history"].content[0]
                    )
        finally:
            # Release the collector even when an assertion fails.
            collector.shutdown()

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_skipping_history_env_collector(self, device_env, collector_cls):
        env = self._make_skipping_env(device_env, max_steps=10)
        collector = collector_cls(
            env,
            lambda td: td.update(env.full_action_spec.one()),
            total_frames=35,
            frames_per_batch=5,
        )
        # Running step index since the last reset, carried across batches.
        count = 1
        try:
            for d in collector:
                for k in range(1, 5):
                    if len(d[k]["history"].content) == 2:
                        # A two-item history is what reset writes: restart the count.
                        count = 1
                        continue
                    if count % 3 == 2:
                        # Skipped step: the history must be unchanged.
                        assert (
                            d[k]["next", "history"].content
                            == d[k - 1]["next", "history"].content
                        ), (d["next", "history"].content, k, count)
                    else:
                        # Regular step: the last content grows by one.
                        assert d[k]["next", "history"].content[-1] == str(
                            int(d[k - 1]["next", "history"].content[-1]) + 1
                        ), (d["next", "history"].content, k, count)
                    count += 1
                # Account for frame 0 of the batch, which the inner loop skips.
                count += 1
        finally:
            # Release the collector even when an assertion fails.
            collector.shutdown()
4311+
4312+
41854313
if __name__ == "__main__":
41864314
args, unknown = argparse.ArgumentParser().parse_known_args()
41874315
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_specs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3912,6 +3912,13 @@ def test_example_data_ineq(self):
39123912
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
39133913
assert nts0 != nts1
39143914

3915+
def test_device_cast(self):
    """A ``NonTensor`` leaf set in a cpu ``Composite`` reports the cpu device.

    Checked for a leaf built with ``device=None`` and one built with
    ``device="cpu"``.
    """
    cpu = torch.device("cpu")
    comp = Composite(device="cpu")
    for leaf_device in (None, "cpu"):
        comp["nontensor"] = NonTensor(device=leaf_device)
        assert comp["nontensor"].device == cpu
3921+
39153922

39163923
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
39173924
def test_device_ordinal():

torchrl/_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,20 @@ def erase():
162162
def _check_for_faulty_process(processes):
163163
terminate = False
164164
for p in processes:
165-
if not p.is_alive():
165+
if not p._closed and not p.is_alive():
166166
terminate = True
167167
for _p in processes:
168-
if _p.is_alive():
169-
_p.terminate()
170-
_p.close()
171-
if terminate:
172-
break
168+
_p: mp.Process
169+
if not _p._closed and _p.is_alive():
170+
try:
171+
_p.terminate()
172+
except Exception:
173+
_p.kill()
174+
finally:
175+
time.sleep(0.1)
176+
_p.close()
177+
if terminate:
178+
break
173179
if terminate:
174180
raise RuntimeError(
175181
"At least one process failed. Check for more infos in the log."

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ def cuda_check(tensor: torch.Tensor):
10571057
# This may be a bit dangerous as `torch.device("cuda")` may not have a precise
10581058
# device associated, whereas `tensor.device` always has
10591059
for spec in self.env.specs.values(True, True):
1060-
if spec.device.type == "cuda":
1060+
if spec.device is not None and spec.device.type == "cuda":
10611061
if ":" not in str(spec.device):
10621062
raise RuntimeError(
10631063
"A cuda spec did not have a device associated. Make sure to "

0 commit comments

Comments
 (0)