Skip to content

Commit 61e05b3

Browse files
author
Vincent Moens
committed
[BugFix, BE] Document and fix fps passing in recorder and loggers
ghstack-source-id: b3996a9 Pull Request resolved: #2694
1 parent ff1ff7e commit 61e05b3

File tree

4 files changed

+32
-9
lines changed

4 files changed

+32
-9
lines changed

torchrl/record/loggers/csv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
class CSVExperiment:
2222
"""A CSV logger experiment class."""
2323

24-
def __init__(self, log_dir: str, *, video_format="pt", video_fps=30):
24+
def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30):
2525
self.scalars = defaultdict(lambda: [])
2626
self.videos_counter = defaultdict(lambda: 0)
2727
self.text_counter = defaultdict(lambda: 0)
@@ -144,6 +144,8 @@ class CSVLogger(Logger):
144144
145145
"""
146146

147+
experiment: CSVExperiment
148+
147149
def __init__(
148150
self,
149151
exp_name: str,

torchrl/record/loggers/mlflow.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,19 @@ class MLFlowLogger(Logger):
2424
Args:
2525
exp_name (str): The name of the experiment.
2626
tracking_uri (str): A tracking URI to a datastore that supports MLFlow or a local directory.
27+
28+
Keyword Args:
29+
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
30+
2731
"""
2832

2933
def __init__(
3034
self,
3135
exp_name: str,
3236
tracking_uri: str,
3337
tags: Optional[Dict[str, Any]] = None,
38+
*,
39+
video_fps: int = 30,
3440
**kwargs,
3541
) -> None:
3642
import mlflow
@@ -43,6 +49,7 @@ def __init__(
4349
mlflow.set_tracking_uri(tracking_uri)
4450
super().__init__(exp_name=exp_name, log_dir=tracking_uri)
4551
self.video_log_counter = 0
52+
self.video_fps = video_fps
4653

4754
def _create_experiment(self) -> "mlflow.ActiveRun": # noqa
4855
import mlflow
@@ -85,7 +92,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
8592
video (Tensor): The video to be logged, expected to be in (T, C, H, W) format
8693
for consistency with other loggers.
8794
**kwargs: Other keyword arguments. By construction, log_video
88-
supports 'step' (integer indicating the step index) and 'fps' (default: 6).
95+
supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).
8996
"""
9097
import mlflow
9198
import torchvision
@@ -103,7 +110,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
103110
"The MLFlow logger only supports videos with 3 color channels."
104111
)
105112
self.video_log_counter += 1
106-
fps = kwargs.pop("fps", 6)
113+
fps = kwargs.pop("fps", self.video_fps)
107114
step = kwargs.pop("step", None)
108115
with TemporaryDirectory() as temp_dir:
109116
video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4"

torchrl/record/loggers/wandb.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class WandbLogger(Logger):
3535
project (str, optional): The name of the project where you're sending
3636
the new run. If the project is not specified, the run is put in
3737
an ``"Uncategorized"`` project.
38+
39+
Keyword Args:
40+
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
3841
**kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
3942
more info.
4043
@@ -52,6 +55,8 @@ def __init__(
5255
save_dir: str = None,
5356
id: str = None,
5457
project: str = None,
58+
*,
59+
video_fps: int = 32,
5560
**kwargs,
5661
) -> None:
5762
if not _has_wandb:
@@ -68,6 +73,7 @@ def __init__(
6873
self.save_dir = save_dir
6974
self.id = id
7075
self.project = project
76+
self.video_fps = video_fps
7177
self._wandb_kwargs = {
7278
"name": exp_name,
7379
"dir": save_dir,
@@ -127,7 +133,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
127133
video (Tensor): The video to be logged.
128134
**kwargs: Other keyword arguments. By construction, log_video
129135
supports 'step' (integer indicating the step index), 'format'
130-
(default is 'mp4') and 'fps' (default: 6). Other kwargs are
136+
(default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
131137
passed as-is to the :obj:`experiment.log` method.
132138
"""
133139
import wandb
@@ -148,7 +154,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
148154
"moviepy not found, videos cannot be logged with TensorboardLogger"
149155
)
150156
self.video_log_counter += 1
151-
fps = kwargs.pop("fps", 6)
157+
fps = kwargs.pop("fps", self.video_fps)
152158
step = kwargs.pop("step", None)
153159
format = kwargs.pop("format", "mp4")
154160
if step not in (None, self._prev_video_step, self._prev_video_step + 1):

torchrl/record/recorder.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class VideoRecorder(ObservationTransform):
4949
if not.
5050
out_keys (sequence of NestedKey, optional): destination keys. Defaults
5151
to ``in_keys`` if not provided.
52+
fps (int, optional): Frames per second of the output video. Defaults to the logger predefined ``fps``,
53+
and overrides it if provided.
54+
**kwargs (Dict[str, Any], optional): additional keyword arguments for
55+
:meth:`~torchrl.record.loggers.Logger.log_video`.
5256
5357
Examples:
5458
The following example shows how to save a rollout under a video. First a few imports:
@@ -81,10 +85,11 @@ class VideoRecorder(ObservationTransform):
8185
>>> from torchrl.data.datasets import OpenXExperienceReplay
8286
>>> from torchrl.envs import Compose
8387
>>> from torchrl.record import VideoRecorder, CSVLogger
84-
>>> # Create a logger that saves videos as mp4
85-
>>> logger = CSVLogger("./dump", video_format="mp4")
88+
>>> # Create a logger that saves videos as mp4 using 24 frames per sec
89+
>>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24)
8690
>>> # We use the VideoRecorder transform to save register the images coming from the batch.
87-
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
91+
>>> # Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged.
92+
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12)
8893
>>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
8994
>>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
9095
... download=True, strict_length=False,
@@ -108,15 +113,18 @@ def __init__(
108113
center_crop: Optional[int] = None,
109114
make_grid: bool | None = None,
110115
out_keys: Optional[Sequence[NestedKey]] = None,
116+
fps: int | None = None,
111117
**kwargs,
112118
) -> None:
113119
if in_keys is None:
114120
in_keys = ["pixels"]
115121
if out_keys is None:
116122
out_keys = copy(in_keys)
117123
super().__init__(in_keys=in_keys, out_keys=out_keys)
118-
video_kwargs = {"fps": 6}
124+
video_kwargs = {}
119125
video_kwargs.update(kwargs)
126+
if fps is not None:
127+
self.video_kwargs["fps"] = fps
120128
self.video_kwargs = video_kwargs
121129
self.iter = 0
122130
self.skip = skip

0 commit comments

Comments
 (0)