diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index 3d23a485458..bb7d4dadf6e 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -47,7 +47,6 @@ class WandbLogger(Logger): @classmethod def __new__(cls, *args, **kwargs): - cls._prev_video_step = -1 return super().__new__(cls) def __init__( @@ -95,7 +94,7 @@ def __init__( self.video_log_counter = 0 - def _create_experiment(self) -> WandbLogger: + def _create_experiment(self) -> "wandb.Experiment": """Creates a wandb experiment. Args: @@ -122,10 +121,7 @@ def log_scalar(self, name: str, value: float, step: int | None = None) -> None: step (int, optional): The step at which the scalar is logged. Defaults to None. """ - if step is not None: - self.experiment.log({name: value, "trainer/step": step}) - else: - self.experiment.log({name: value}) + self.experiment.log({name: value}, step=step) def log_video(self, name: str, video: Tensor, **kwargs) -> None: """Log videos inputs to wandb. @@ -139,39 +135,11 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: passed as-is to the :obj:`experiment.log` method. """ import wandb - - # check for correct format of the video tensor ((N), T, C, H, W) - # check that the color channel (C) is either 1 or 3 - if video.dim() != 5 or video.size(dim=2) not in {1, 3}: - raise Exception( - "Wrong format of the video tensor. Should be ((N), T, C, H, W)" - ) - if not self._has_imported_moviepy: - try: - import moviepy # noqa - - self._has_imported_moviepy = True - except ImportError: - raise Exception( - "moviepy not found, videos cannot be logged with TensorboardLogger" - ) self.video_log_counter += 1 fps = kwargs.pop("fps", self.video_fps) - step = kwargs.pop("step", None) format = kwargs.pop("format", "mp4") - if step not in (None, self._prev_video_step, self._prev_video_step + 1): - warnings.warn( - "when using step with wandb_logger.log_video, it is expected " - "that the step is equal to the previous step or that value incremented " - f"by one. Got step={step} but previous value was {self._prev_video_step}. " - f"The step value will be set to {self._prev_video_step+1}. This warning will " - f"be silenced from now on but the values will keep being incremented." - ) - step = self._prev_video_step + 1 - self._prev_video_step = step if step is not None else self._prev_video_step + 1 self.experiment.log( {name: wandb.Video(video, fps=fps, format=format)}, - # step=step, **kwargs, )