[Feature] Migrate CQL from D4RLExperienceReplay to MinariExperienceReplay + fix W&B logging and SLURM usage #3035

Merged
merged 6 commits into from
Jul 8, 2025
2 changes: 1 addition & 1 deletion sota-check/submitit-release-check.sh
@@ -13,7 +13,7 @@ EXAMPLES:
./submitit-release-check.sh --partition <PARTITION_NAME> --n_runs 5

EOF
-return 1
+exit 1
}

# Check if the script is called with --help or without any arguments
1 change: 1 addition & 0 deletions sota-implementations/cql/cql_offline.py
@@ -190,6 +190,7 @@ def update(data, policy_eval_start, iteration):
with timeit("log"):
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]

log_metrics(logger, metrics_to_log, i)

pbar.close()
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_config.yaml
@@ -22,7 +22,7 @@ collector:

# Logger
logger:
-backend: wandb
+backend: csv
project_name: torchrl_example_cql
group_name: null
exp_name: cql_cartpole_gym
6 changes: 3 additions & 3 deletions sota-implementations/cql/offline_config.yaml
@@ -1,6 +1,6 @@
# env and task
env:
-name: Hopper-v4
+name: Hopper-v5
task: ""
library: gym
n_samples_stats: 1000
@@ -18,11 +18,11 @@ logger:
eval_steps: 1000
mode: online
eval_envs: 5
-video: False
+video: True

# replay buffer
replay_buffer:
-dataset: hopper-medium-v2
+dataset: mujoco/hopper/expert-v0
batch_size: 256

# optimization
6 changes: 3 additions & 3 deletions sota-implementations/cql/utils.py
@@ -18,7 +18,7 @@
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
-from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+from torchrl.data.datasets.minari_data import MinariExperienceReplay
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.envs import (
CatTensors,
@@ -181,13 +181,13 @@ def make_replay_buffer(


def make_offline_replay_buffer(rb_cfg):
-data = D4RLExperienceReplay(
+data = MinariExperienceReplay(
Contributor Author

Replaced D4RLExperienceReplay with MinariExperienceReplay following the official deprecation of the D4RL datasets.
See: https://github.com/Farama-Foundation/d4rl

dataset_id=rb_cfg.dataset,
split_trajs=False,
batch_size=rb_cfg.batch_size,
sampler=SamplerWithoutReplacement(drop_last=True),
prefetch=4,
-direct_download=True,
+download=True,
)

data.append_transform(DoubleToFloat())
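
The constructor swap in this hunk also renames the download flag (`direct_download` becomes `download`) and expects a Minari-style dataset id. A minimal sketch of that mapping follows. It assumes a hypothetical helper, `minari_replay_kwargs`, which is not part of this PR, and a rough heuristic that the Minari ids used in this config contain a `/` while the old D4RL names do not. The sampler argument is left out so the snippet runs without torchrl installed.

```python
from types import SimpleNamespace


def minari_replay_kwargs(rb_cfg):
    """Hypothetical helper (not in the PR): build constructor kwargs from the
    `replay_buffer` section of the config."""
    # Heuristic assumption: ids such as "mujoco/hopper/expert-v0" carry a
    # namespace with "/", while D4RL names such as "hopper-medium-v2" do not.
    if "/" not in rb_cfg.dataset:
        raise ValueError(
            f"{rb_cfg.dataset!r} looks like a D4RL name; expected a Minari id "
            "such as 'mujoco/hopper/expert-v0'"
        )
    return {
        "dataset_id": rb_cfg.dataset,
        "split_trajs": False,
        "batch_size": rb_cfg.batch_size,
        "prefetch": 4,
        # The D4RL wrapper took `direct_download=True`; the Minari call in
        # this diff passes `download=True` instead.
        "download": True,
    }


# Stand-in for the Hydra config node, using the values from offline_config.yaml.
rb_cfg = SimpleNamespace(dataset="mujoco/hopper/expert-v0", batch_size=256)
kwargs = minari_replay_kwargs(rb_cfg)
```

With torchrl available, the dict could be passed along as `MinariExperienceReplay(**kwargs, sampler=SamplerWithoutReplacement(drop_last=True))`, mirroring the call in the diff.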
2 changes: 1 addition & 1 deletion torchrl/record/loggers/wandb.py
@@ -234,6 +234,6 @@ def log_str(self, name: str, value: str, step: int | None = None) -> None:
table = wandb.Table(columns=["text"], data=[[value]])

if step is not None:
-self.experiment.log({name: table, "trainer/step": step})
+self.experiment.log({name: value}, step=step)
else:
self.experiment.log({name: table})
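
The `wandb.py` change moves the step out of the logged payload and into the `step` argument of `log`. The sketch below illustrates only that difference in call shape. `RecordingRun` is a hypothetical stand-in that records calls instead of contacting W&B, and both variants log the raw string here, so `wandb.Table` is deliberately left out.

```python
class RecordingRun:
    """Hypothetical stand-in for a W&B run: stores log() calls locally."""

    def __init__(self):
        self.calls = []

    def log(self, data, step=None):
        # Keep a copy of the payload together with the step argument.
        self.calls.append((dict(data), step))


def log_str_before(run, name, value, step=None):
    # Old call shape: the step travels inside the payload as a metric.
    if step is not None:
        run.log({name: value, "trainer/step": step})
    else:
        run.log({name: value})


def log_str_after(run, name, value, step=None):
    # New call shape: the step is passed through the `step` argument.
    if step is not None:
        run.log({name: value}, step=step)
    else:
        run.log({name: value})


old_run, new_run = RecordingRun(), RecordingRun()
log_str_before(old_run, "eval/note", "done", step=10)
log_str_after(new_run, "eval/note", "done", step=10)
```

In the first variant the run sees an extra `trainer/step` key and no explicit step; in the second the payload holds only the named value and the step is supplied separately.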