diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh
index 515ac06a50b..a93d7a02846 100755
--- a/sota-check/submitit-release-check.sh
+++ b/sota-check/submitit-release-check.sh
@@ -13,7 +13,7 @@
 EXAMPLES:
   ./submitit-release-check.sh --partition --n_runs 5
 EOF
-  return 1
+  exit 1
 }
 
 # Check if the script is called with --help or without any arguments
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index fc388399878..9acd00b1627 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -190,6 +190,7 @@ def update(data, policy_eval_start, iteration):
         with timeit("log"):
             metrics_to_log.update(timeit.todict(prefix="time"))
             metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
 
     pbar.close()
 
diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml
index a9fb9bfed0c..cb622e1e729 100644
--- a/sota-implementations/cql/discrete_cql_config.yaml
+++ b/sota-implementations/cql/discrete_cql_config.yaml
@@ -22,7 +22,7 @@ collector:
 
 # Logger
 logger:
-  backend: wandb
+  backend: csv
   project_name: torchrl_example_cql
   group_name: null
   exp_name: cql_cartpole_gym
diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml
index a14604251c0..321c8ba0b1b 100644
--- a/sota-implementations/cql/offline_config.yaml
+++ b/sota-implementations/cql/offline_config.yaml
@@ -1,6 +1,6 @@
 # env and task
 env:
-  name: Hopper-v4
+  name: Hopper-v5
   task: ""
   library: gym
   n_samples_stats: 1000
@@ -18,11 +18,11 @@ logger:
   eval_steps: 1000
   mode: online
   eval_envs: 5
-  video: False
+  video: True
 
 # replay buffer
 replay_buffer:
-  dataset: hopper-medium-v2
+  dataset: mujoco/hopper/expert-v0
   batch_size: 256
 
 # optimization
diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py
index 8bbc70a32c3..a4f412cb1cc 100644
--- a/sota-implementations/cql/utils.py
+++ b/sota-implementations/cql/utils.py
@@ -18,7 +18,7 @@
     TensorDictPrioritizedReplayBuffer,
     TensorDictReplayBuffer,
 )
-from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+from torchrl.data.datasets.minari_data import MinariExperienceReplay
 from torchrl.data.replay_buffers import SamplerWithoutReplacement
 from torchrl.envs import (
     CatTensors,
@@ -181,13 +181,13 @@ def make_replay_buffer(
 
 
 def make_offline_replay_buffer(rb_cfg):
-    data = D4RLExperienceReplay(
+    data = MinariExperienceReplay(
         dataset_id=rb_cfg.dataset,
         split_trajs=False,
         batch_size=rb_cfg.batch_size,
         sampler=SamplerWithoutReplacement(drop_last=True),
         prefetch=4,
-        direct_download=True,
+        download=True,
     )
 
     data.append_transform(DoubleToFloat())
diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py
index 5c4b9d3ffc8..41382a01198 100644
--- a/torchrl/record/loggers/wandb.py
+++ b/torchrl/record/loggers/wandb.py
@@ -234,6 +234,6 @@ def log_str(self, name: str, value: str, step: int | None = None) -> None:
         table = wandb.Table(columns=["text"], data=[[value]])
 
         if step is not None:
-            self.experiment.log({name: table, "trainer/step": step})
+            self.experiment.log({name: value}, step=step)
         else:
             self.experiment.log({name: table})