From 3792dc694c2345f931985bbd3a0617b1b50f0934 Mon Sep 17 00:00:00 2001 From: "jorge.ibinarriaga.robles.becas" Date: Tue, 1 Jul 2025 12:03:26 +0200 Subject: [PATCH 1/6] [Feature] Replace D4RLExperienceReplay with MinariExperienceReplay in CQL offline --- sota-implementations/cql/=0.9.0, | 0 sota-implementations/cql/cql_offline.py | 5 +++++ sota-implementations/cql/discrete_cql_config.yaml | 2 +- sota-implementations/cql/offline_config.yaml | 8 ++++---- sota-implementations/cql/utils.py | 5 +++-- 5 files changed, 13 insertions(+), 7 deletions(-) create mode 100644 sota-implementations/cql/=0.9.0, diff --git a/sota-implementations/cql/=0.9.0, b/sota-implementations/cql/=0.9.0, new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index fc388399878..f5734a8984e 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -190,6 +190,11 @@ def update(data, policy_eval_start, iteration): with timeit("log"): metrics_to_log.update(timeit.todict(prefix="time")) metrics_to_log["time/speed"] = pbar.format_dict["rate"] + + if i % evaluation_interval == 0: + print( + f"Step {i}: loss={loss.item():.4f}, eval_reward={eval_reward:.4f}" + ) log_metrics(logger, metrics_to_log, i) pbar.close() diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index a9fb9bfed0c..cb622e1e729 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -22,7 +22,7 @@ collector: # Logger logger: - backend: wandb + backend: csv project_name: torchrl_example_cql group_name: null exp_name: cql_cartpole_gym diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index a14604251c0..c12db9f51f9 100644 --- a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -1,6 +1,6 @@ # env and task env: - name: Hopper-v4 + name: Hopper-v5 task: "" library: gym n_samples_stats: 1000 @@ -9,7 +9,7 @@ env: # logger logger: - backend: wandb + backend: csv project_name: torchrl_example_cql group_name: null exp_name: cql_${replay_buffer.dataset} @@ -18,11 +18,11 @@ logger: eval_steps: 1000 mode: online eval_envs: 5 - video: False + video: True # replay buffer replay_buffer: - dataset: hopper-medium-v2 + dataset: mujoco/hopper/expert-v0 batch_size: 256 # optimization diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 8bbc70a32c3..cf77a158c3e 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -19,6 +19,7 @@ TensorDictReplayBuffer, ) from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( CatTensors, @@ -181,13 +182,13 @@ def make_replay_buffer( def make_offline_replay_buffer(rb_cfg): - data = D4RLExperienceReplay( + data = MinariExperienceReplay( dataset_id=rb_cfg.dataset, split_trajs=False, batch_size=rb_cfg.batch_size, sampler=SamplerWithoutReplacement(drop_last=True), prefetch=4, - direct_download=True, + download=True, ) data.append_transform(DoubleToFloat()) From 30ea320f627fabd9633d448884e099c82b6ff5a7 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Fri, 4 Jul 2025 12:41:54 +0200 Subject: [PATCH 2/6] [Refactor] Update logging backend to use wandb --- sota-implementations/cql/cql_offline.py | 4 ---- sota-implementations/cql/offline_config.yaml | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index f5734a8984e..9acd00b1627 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -191,10 +191,6 @@ def update(data, policy_eval_start, iteration): metrics_to_log.update(timeit.todict(prefix="time")) metrics_to_log["time/speed"] = pbar.format_dict["rate"] - if i % evaluation_interval == 0: - print( - f"Step {i}: loss={loss.item():.4f}, eval_reward={eval_reward:.4f}" - ) log_metrics(logger, metrics_to_log, i) pbar.close() diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index c12db9f51f9..321c8ba0b1b 100644 --- a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -9,7 +9,7 @@ env: # logger logger: - backend: csv + backend: wandb project_name: torchrl_example_cql group_name: null exp_name: cql_${replay_buffer.dataset} From 0fc3a5fbdf331f0a2b4f96bd2a2ea12d82151c7c Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Fri, 4 Jul 2025 12:42:23 +0200 Subject: [PATCH 3/6] [Fix] log WandB scalars using correct step argument --- torchrl/record/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index 5c4b9d3ffc8..41382a01198 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -234,6 +234,6 @@ def log_str(self, name: str, value: str, step: int | None = None) -> None: table = wandb.Table(columns=["text"], data=[[value]]) if step is not None: - self.experiment.log({name: table, "trainer/step": step}) + self.experiment.log({name: value}, step=step) else: self.experiment.log({name: table}) From b1238ebfb6e2da4856b98f2fad61d5718c653a43 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Fri, 4 Jul 2025 13:22:25 +0200 Subject: [PATCH 4/6] [Fix] Correct exit behavior in usage display function --- sota-check/submitit-release-check.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh index 515ac06a50b..a93d7a02846 100755 --- a/sota-check/submitit-release-check.sh +++ b/sota-check/submitit-release-check.sh @@ -13,7 +13,7 @@ EXAMPLES: ./submitit-release-check.sh --partition --n_runs 5 EOF - return 1 + exit 1 } # Check if the script is called with --help or without any arguments From 9eba3881371717589b33ca4e5d35c16f05d9e352 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Fri, 4 Jul 2025 13:39:54 +0200 Subject: [PATCH 5/6] [Refactor]: fix linting errors to pass pre-commit checks --- sota-implementations/cql/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index cf77a158c3e..a4f412cb1cc 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -18,8 +18,7 @@ TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) -from torchrl.data.datasets.d4rl import D4RLExperienceReplay -from torchrl.data.datasets.minari_data import MinariExperienceReplay +from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( CatTensors, From 4e4ff716dee1309dc69c181b8f9651477d95bb27 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 8 Jul 2025 11:15:44 +0100 Subject: [PATCH 6/6] Delete sota-implementations/cql/=0.9.0, --- sota-implementations/cql/=0.9.0, | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 sota-implementations/cql/=0.9.0, diff --git a/sota-implementations/cql/=0.9.0, b/sota-implementations/cql/=0.9.0, deleted file mode 100644 index e69de29bb2d..00000000000