Skip to content

Commit 24d14ad

Browse files
author
Vincent Moens
authored
[Doc, Feature] Doc improvements for video recording and CSV video formats (#1829)
1 parent c390cf6 commit 24d14ad

File tree

9 files changed

+139
-22
lines changed

9 files changed

+139
-22
lines changed

.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ dependencies:
2525
- patchelf
2626
- pyopengl==3.1.4
2727
- ray<2.8.0
28+
- av

test/test_loggers.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import pytest
1414
import torch
15+
16+
from tensordict import MemoryMappedTensor
1517
from torchrl.record.loggers.csv import CSVLogger
1618
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
1719
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
@@ -150,16 +152,22 @@ def test_log_scalar(self, steps, tmpdir):
150152
assert row == f"{step},{values[i].item()}\n"
151153

152154
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
153-
def test_log_video(self, steps, tmpdir):
155+
@pytest.mark.parametrize(
156+
"video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else []
157+
)
158+
def test_log_video(self, steps, video_format, tmpdir):
154159
torch.manual_seed(0)
155160
exp_name = "ramala"
156-
logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name)
161+
logger = CSVLogger(log_dir=tmpdir, exp_name=exp_name, video_format=video_format)
157162

158163
# creating a sample video (T, C, H, W), where T - number of frames,
159164
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
160165
# the first 64 frames are black and the next 64 are white
161166
video = torch.cat(
162-
(torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255))
167+
(
168+
torch.zeros(64, 1, 32, 32, dtype=torch.uint8),
169+
torch.full((64, 1, 32, 32), 255, dtype=torch.uint8),
170+
)
163171
)
164172
video = video[None, :]
165173
for i in range(3):
@@ -171,11 +179,31 @@ def test_log_video(self, steps, tmpdir):
171179
sleep(0.01) # wait until events are registered
172180

173181
# check that the logged videos are the same as the initial video
174-
video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + ".pt"
175-
logged_video = torch.load(
176-
os.path.join(tmpdir, exp_name, "videos", video_file_name)
182+
extention = (
183+
".pt"
184+
if video_format == "pt"
185+
else ".memmap"
186+
if video_format == "memmap"
187+
else ".mp4"
177188
)
178-
assert torch.equal(video, logged_video), logged_video
189+
video_file_name = "foo_" + ("0" if not steps else str(steps[0])) + extention
190+
path = os.path.join(tmpdir, exp_name, "videos", video_file_name)
191+
if video_format == "pt":
192+
logged_video = torch.load(path)
193+
assert torch.equal(video, logged_video), logged_video
194+
elif video_format == "memmap":
195+
logged_video = MemoryMappedTensor.from_filename(
196+
path, dtype=torch.uint8, shape=(1, 128, 1, 32, 32)
197+
)
198+
assert torch.equal(video, logged_video), logged_video
199+
elif video_format == "mp4":
200+
import torchvision
201+
202+
logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
203+
:, :1
204+
]
205+
logged_video = logged_video.unsqueeze(0)
206+
torch.testing.assert_close(video, logged_video)
179207

180208
# check that we catch the error in case the format of the tensor is wrong
181209
video_wrong_format = torch.zeros(64, 2, 32, 32)

torchrl/objectives/a2c.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,7 @@ class A2CLoss(LossModule):
130130
the expected keyword arguments are:
131131
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic.
132132
The return value is a tuple of tensors in the following order:
133-
``["loss_objective"]``
134-
+ ``["loss_critic"]`` if critic_coef is not None
135-
+ ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None
133+
``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None
136134
137135
Examples:
138136
>>> import torch

torchrl/objectives/ppo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ class PPOLoss(LossModule):
178178
the expected keyword arguments are:
179179
``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network.
180180
The return value is a tuple of tensors in the following order:
181-
``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set
182-
+ ``"loss_critic"`` if critic_coef is not None.
181+
``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not ``None``.
183182
The output keys can also be filtered using :meth:`PPOLoss.select_out_keys` method.
184183
185184
Examples:

torchrl/objectives/redq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,7 @@ class REDQLoss(LossModule):
138138
the expected keyword arguments are:
139139
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network
140140
The return value is a tuple of tensors in the following order:
141-
``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy",
142-
"state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``.
141+
``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy", "state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``.
143142
144143
Examples:
145144
>>> import torch

torchrl/record/loggers/csv.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import Dict, Optional, Sequence, Union
99

1010
import torch
11+
12+
from tensordict import MemoryMappedTensor
1113
from torch import Tensor
1214

1315
from .common import Logger
@@ -16,11 +18,13 @@
1618
class CSVExperiment:
1719
"""A CSV logger experiment class."""
1820

19-
def __init__(self, log_dir: str):
21+
def __init__(self, log_dir: str, *, video_format="pt", video_fps=30):
2022
self.scalars = defaultdict(lambda: [])
2123
self.videos_counter = defaultdict(lambda: 0)
2224
self.text_counter = defaultdict(lambda: 0)
2325
self.log_dir = log_dir
26+
self.video_format = video_format
27+
self.video_fps = video_fps
2428
os.makedirs(self.log_dir, exist_ok=True)
2529
os.makedirs(os.path.join(self.log_dir, "scalars"), exist_ok=True)
2630
os.makedirs(os.path.join(self.log_dir, "videos"), exist_ok=True)
@@ -44,12 +48,43 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs
4448
if global_step is None:
4549
global_step = self.videos_counter[tag]
4650
self.videos_counter[tag] += 1
51+
if self.video_format == "pt":
52+
extension = ".pt"
53+
elif self.video_format == "memmap":
54+
extension = ".memmap"
55+
elif self.video_format == "mp4":
56+
extension = ".mp4"
57+
else:
58+
raise ValueError(
59+
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
60+
)
61+
4762
filepath = os.path.join(
48-
self.log_dir, "videos", "_".join([tag, str(global_step)]) + ".pt"
63+
self.log_dir, "videos", "_".join([tag, str(global_step)]) + extension
4964
)
5065
path_to_create = Path(str(filepath)).parent
5166
os.makedirs(path_to_create, exist_ok=True)
52-
torch.save(vid_tensor, filepath)
67+
if self.video_format == "pt":
68+
torch.save(vid_tensor, filepath)
69+
elif self.video_format == "memmap":
70+
MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
71+
elif self.video_format == "mp4":
72+
import torchvision
73+
74+
if vid_tensor.shape[-3] not in (3, 1):
75+
raise RuntimeError(
76+
"expected the video tensor to be of format [T, C, H, W] but the third channel "
77+
f"starting from the end isn't in (1, 3) but is {vid_tensor.shape[-3]}."
78+
)
79+
if vid_tensor.ndim > 4:
80+
vid_tensor = vid_tensor.flatten(0, vid_tensor.ndim - 4)
81+
vid_tensor = vid_tensor.permute((0, 2, 3, 1))
82+
vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
83+
torchvision.io.write_video(filepath, vid_tensor, fps=self.video_fps)
84+
else:
85+
raise ValueError(
86+
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
87+
)
5388

5489
def add_text(self, tag, text, global_step: Optional[int] = None):
5590
if global_step is None:
@@ -77,20 +112,37 @@ class CSVLogger(Logger):
77112
78113
Args:
79114
exp_name (str): The name of the experiment.
115+
log_dir (str or Path, optional): where the experiment should be saved.
116+
Defaults to ``<cur_dir>/csv_logs``.
117+
video_format (str, optional): how videos should be saved. Must be one of
118+
``"pt"`` (video saved as a `video_<tag>_<step>.pt` file with torch.save),
119+
``"memmap"`` (video saved as a `video_<tag>_<step>.memmap` file with :class:`~tensordict.MemoryMappedTensor`),
120+
``"mp4"`` (video saved as a `video_<tag>_<step>.mp4` file, requires torchvision to be installed).
121+
Defaults to ``"pt"``.
122+
video_fps (int, optional): the video frames-per-seconds if `video_format="mp4"`. Defaults to 30.
80123
81124
"""
82125

83-
def __init__(self, exp_name: str, log_dir: Optional[str] = None) -> None:
126+
def __init__(
127+
self,
128+
exp_name: str,
129+
log_dir: Optional[str] = None,
130+
video_format: str = "pt",
131+
video_fps: int = 30,
132+
) -> None:
84133
if log_dir is None:
85134
log_dir = "csv_logs"
135+
self.video_format = video_format
136+
self.video_fps = video_fps
86137
super().__init__(exp_name=exp_name, log_dir=log_dir)
87-
88138
self._has_imported_moviepy = False
89139

90140
def _create_experiment(self) -> "CSVExperiment":
91141
"""Creates a CSV experiment."""
92142
log_dir = str(os.path.join(self.log_dir, self.exp_name))
93-
return CSVExperiment(log_dir)
143+
return CSVExperiment(
144+
log_dir, video_format=self.video_format, video_fps=self.video_fps
145+
)
94146

95147
def log_scalar(self, name: str, value: float, step: int = None) -> None:
96148
"""Logs a scalar value to the tensorboard.

torchrl/record/loggers/tensorboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class TensorboardLogger(Logger):
2020
2121
Args:
2222
exp_name (str): The name of the experiment.
23-
log_dir (str): the tensorboard log_dir.
23+
log_dir (str): the tensorboard log_dir. Defaults to ``td_logs``.
2424
2525
"""
2626

torchrl/record/loggers/wandb.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,24 @@
1919
class WandbLogger(Logger):
2020
"""Wrapper for the wandb logger.
2121
22+
The keyword arguments are mainly based on the :func:`wandb.init` kwargs.
23+
See the doc `here <https://docs.wandb.ai/ref/python/init>`__.
24+
2225
Args:
2326
exp_name (str): The name of the experiment.
27+
offline (bool, optional): if ``True``, the logs will be stored locally
28+
only. Defaults to ``False``.
29+
save_dir (path, optional): the directory where to save data. Exclusive with
30+
``log_dir``.
31+
log_dir (path, optional): the directory where to save data. Exclusive with
32+
``save_dir``.
33+
id (str, optional): A unique ID for this run, used for resuming.
34+
It must be unique in the project, and if you delete a run you can't reuse the ID.
35+
project (str, optional): The name of the project where you're sending
36+
the new run. If the project is not specified, the run is put in
37+
an ``"Uncategorized"`` project.
38+
**kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
39+
more info.
2440
2541
"""
2642

torchrl/record/recorder.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class VideoRecorder(ObservationTransform):
3030
3131
Args:
3232
logger (Logger): a Logger instance where the video
33-
should be written.
33+
should be written. To save the video under a memmap tensor or an mp4 file, use
34+
the :class:`~torchrl.record.loggers.CSVLogger` class.
3435
tag (str): the video tag in the logger.
3536
in_keys (Sequence of NestedKey, optional): keys to be read to produce the video.
3637
Default is :obj:`"pixels"`.
@@ -43,6 +44,29 @@ class VideoRecorder(ObservationTransform):
4344
out_keys (sequence of NestedKey, optional): destination keys. Defaults
4445
to ``in_keys`` if not provided.
4546
47+
Examples:
48+
The following example shows how to save a rollout under a video. First a few imports:
49+
>>> from torchrl.record import VideoRecorder
50+
>>> from torchrl.record.loggers.csv import CSVLogger
51+
>>> from torchrl.envs import TransformedEnv, DMControlEnv
52+
53+
The video format is chosen in the logger. Wandb and tensorboard will take care of that
54+
on their own, CSV accepts various video formats.
55+
>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")
56+
57+
Some envs (eg, Atari games) natively return images, some require the user to ask for them.
58+
Check :class:`~torchrl.env.GymEnv` or :class:`~torchrl.envs.DMControlEnv` to see how to render images
59+
in these contexts.
60+
>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True)
61+
>>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video"))
62+
>>> env.rollout(100)
63+
64+
All transforms have a dump function, mostly a no-op except for ``VideoRecorder``, and :class:`~torchrl.envs.transforms.Composite`
65+
which will dispatch the `dumps` to all its members.
66+
>>> env.transform.dump()
67+
68+
Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``!
69+
4670
"""
4771

4872
def __init__(

0 commit comments

Comments
 (0)