Skip to content

Commit 104b880

Browse files
author
Vincent Moens
committed
[Feature] Timer transform
ghstack-source-id: e42f2ae Pull Request resolved: #2806
1 parent d3dca73 commit 104b880

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ to be able to create this other composition:
10701070
TargetReturn
10711071
TensorDictPrimer
10721072
TimeMaxPool
1073+
Timer
10731074
Tokenizer
10741075
ToTensorImage
10751076
TrajCounter

test/test_transforms.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tensordict.tensordict
2525
import torch
2626
from tensordict import (
27+
LazyStackedTensorDict,
2728
NonTensorData,
2829
NonTensorStack,
2930
TensorDict,
@@ -102,6 +103,7 @@
102103
TargetReturn,
103104
TensorDictPrimer,
104105
TimeMaxPool,
106+
Timer,
105107
Tokenizer,
106108
ToTensorImage,
107109
TrajCounter,
@@ -13879,6 +13881,90 @@ def test_transform_inverse(self):
1387913881
return
1388013882

1388113883

13884+
class TestTimer(TransformBase):
13885+
def test_single_trans_env_check(self):
13886+
env = TransformedEnv(ContinuousActionVecMockEnv(), Timer())
13887+
check_env_specs(env)
13888+
env.close()
13889+
13890+
def test_serial_trans_env_check(self):
13891+
env = SerialEnv(
13892+
2, lambda: TransformedEnv(ContinuousActionVecMockEnv(), Timer())
13893+
)
13894+
check_env_specs(env)
13895+
env.close()
13896+
13897+
def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
13898+
env = maybe_fork_ParallelEnv(
13899+
2, lambda: TransformedEnv(ContinuousActionVecMockEnv(), Timer())
13900+
)
13901+
try:
13902+
check_env_specs(env)
13903+
finally:
13904+
try:
13905+
env.close()
13906+
except RuntimeError:
13907+
pass
13908+
13909+
def test_trans_serial_env_check(self):
13910+
env = TransformedEnv(
13911+
SerialEnv(2, lambda: ContinuousActionVecMockEnv()), Timer()
13912+
)
13913+
try:
13914+
check_env_specs(env)
13915+
finally:
13916+
try:
13917+
env.close()
13918+
except RuntimeError:
13919+
pass
13920+
13921+
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
13922+
env = TransformedEnv(
13923+
maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()),
13924+
Timer(),
13925+
)
13926+
try:
13927+
check_env_specs(env)
13928+
finally:
13929+
try:
13930+
env.close()
13931+
except RuntimeError:
13932+
pass
13933+
13934+
def test_transform_no_env(self):
13935+
torch.manual_seed(0)
13936+
t = Timer()
13937+
with pytest.raises(NotImplementedError):
13938+
t(TensorDict())
13939+
13940+
def test_transform_compose(self):
13941+
torch.manual_seed(0)
13942+
t = Compose(Timer())
13943+
with pytest.raises(NotImplementedError):
13944+
t(TensorDict())
13945+
13946+
def test_transform_env(self):
13947+
env = TransformedEnv(ContinuousActionVecMockEnv(), Timer())
13948+
rollout = env.rollout(3)
13949+
# The stack must be contiguous
13950+
assert not isinstance(rollout, LazyStackedTensorDict)
13951+
assert (rollout["time_policy"] >= 0).all()
13952+
assert (rollout["time_step"] > 0).all()
13953+
13954+
def test_transform_model(self):
13955+
torch.manual_seed(0)
13956+
t = nn.Sequential(Timer())
13957+
with pytest.raises(NotImplementedError):
13958+
t(TensorDict())
13959+
13960+
def test_transform_rb(self):
13961+
# NotImplemented tested elsewhere
13962+
return
13963+
13964+
def test_transform_inverse(self):
13965+
raise pytest.skip("Tested elsewhere")
13966+
13967+
1388213968
if __name__ == "__main__":
1388313969
args, unknown = argparse.ArgumentParser().parse_known_args()
1388413970
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
TargetReturn,
9898
TensorDictPrimer,
9999
TimeMaxPool,
100+
Timer,
100101
Tokenizer,
101102
ToTensorImage,
102103
TrajCounter,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
TargetReturn,
5858
TensorDictPrimer,
5959
TimeMaxPool,
60+
Timer,
6061
Tokenizer,
6162
ToTensorImage,
6263
TrajCounter,

torchrl/envs/transforms/transforms.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import hashlib
1010
import importlib.util
1111
import multiprocessing as mp
12+
import time
1213
import warnings
1314
import weakref
1415
from copy import copy
@@ -10823,3 +10824,112 @@ def _transform_observation_spec(
1082310824
)
1082410825
)
1082510826
return observation_spec
10827+
10828+
10829+
class Timer(Transform):
10830+
"""A transform that measures the time intervals between `inv` and `call` operations in an environment.
10831+
10832+
The `Timer` transform is used to track the time elapsed between the `inv` call and the `call`,
10833+
and between the `call` and the `inv` call. This is useful for performance monitoring and debugging
10834+
within an environment. The time is measured in seconds and stored as a tensor with the default
10835+
dtype from PyTorch. If the tensordict has a batch size (e.g., in batched environments), the time will be expended
10836+
to the size of the input tensordict.
10837+
10838+
Attributes:
10839+
out_keys: The keys of the output tensordict for the inverse transform. Defaults to
10840+
`out_keys = [f"{time_key}_step", f"{time_key}_policy"]`, where the first key represents
10841+
the time it takes to make a step in the environment, and the second key represents the
10842+
time it takes to execute the policy.
10843+
time_key: A prefix for the keys where the time intervals will be stored in the tensordict.
10844+
Defaults to `"time"`.
10845+
10846+
Examples:
10847+
>>> from torchrl.envs import Timer, GymEnv
10848+
>>>
10849+
>>> env = GymEnv("Pendulum-v1").append_transform(Timer())
10850+
>>> r = env.rollout(10)
10851+
>>> print("time for policy", r["time_policy"])
10852+
time for policy tensor([0.0000, 0.0882, 0.0004, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002,
10853+
0.0002])
10854+
>>> print("time for step", r["time_step"])
10855+
time for step tensor([9.5797e-04, 1.6289e-03, 9.7990e-05, 8.0824e-05, 9.0837e-05, 7.6056e-05,
10856+
8.2016e-05, 7.6056e-05, 8.1062e-05, 7.7009e-05])
10857+
"""
10858+
10859+
def __init__(self, out_keys: Sequence[NestedKey] = None, time_key: str = "time"):
10860+
if out_keys is None:
10861+
out_keys = [f"{time_key}_step", f"{time_key}_policy"]
10862+
elif len(out_keys) != 2:
10863+
raise TypeError(f"Expected two out_keys. Got out_keys={out_keys}.")
10864+
super().__init__([], out_keys)
10865+
self.time_key = time_key
10866+
self.last_inv_time = None
10867+
self.last_call_time = None
10868+
10869+
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
10870+
self.last_inv_time = time.time()
10871+
return tensordict
10872+
10873+
def _maybe_expand_and_set(self, key, time_elapsed, tensordict):
10874+
if isinstance(key, tuple):
10875+
parent_td = tensordict.get(key[:-1])
10876+
key = key[-1]
10877+
else:
10878+
parent_td = tensordict
10879+
batch_size = parent_td.batch_size
10880+
if batch_size:
10881+
# Get the parent shape
10882+
time_elapsed_expand = time_elapsed.expand(parent_td.batch_size)
10883+
else:
10884+
time_elapsed_expand = time_elapsed
10885+
parent_td.set(key, time_elapsed_expand)
10886+
10887+
def _reset(
10888+
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
10889+
) -> TensorDictBase:
10890+
current_time = time.time()
10891+
if self.last_inv_time is not None:
10892+
time_elapsed = torch.tensor(
10893+
current_time - self.last_inv_time, device=tensordict.device
10894+
)
10895+
self._maybe_expand_and_set(self.out_keys[0], time_elapsed, tensordict_reset)
10896+
self.last_call_time = current_time
10897+
# Placeholder
10898+
self._maybe_expand_and_set(self.out_keys[1], time_elapsed * 0, tensordict_reset)
10899+
return tensordict_reset
10900+
10901+
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
10902+
current_time = time.time()
10903+
if self.last_call_time is not None:
10904+
time_elapsed = torch.tensor(
10905+
current_time - self.last_call_time, device=tensordict.device
10906+
)
10907+
self._maybe_expand_and_set(self.out_keys[1], time_elapsed, tensordict)
10908+
self.last_inv_time = current_time
10909+
return tensordict
10910+
10911+
def _step(
10912+
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
10913+
) -> TensorDictBase:
10914+
current_time = time.time()
10915+
if self.last_inv_time is not None:
10916+
time_elapsed = torch.tensor(
10917+
current_time - self.last_inv_time, device=tensordict.device
10918+
)
10919+
self._maybe_expand_and_set(self.out_keys[0], time_elapsed, next_tensordict)
10920+
self.last_call_time = current_time
10921+
# presumbly no need to worry about batch size incongruencies here
10922+
next_tensordict.set(self.out_keys[1], tensordict.get(self.out_keys[1]))
10923+
return next_tensordict
10924+
10925+
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
10926+
observation_spec[self.out_keys[0]] = Unbounded(
10927+
shape=observation_spec.shape, device=observation_spec.device
10928+
)
10929+
observation_spec[self.out_keys[1]] = Unbounded(
10930+
shape=observation_spec.shape, device=observation_spec.device
10931+
)
10932+
return observation_spec
10933+
10934+
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
10935+
raise NotImplementedError(FORWARD_NOT_IMPLEMENTED)

0 commit comments

Comments
 (0)