Skip to content

Commit 3cf1df0

Browse files
Ibinarriaga8jorge.ibinarriaga.robles.becasvmoens
authored
[Feature] Migrate CQL from D4RLExperienceReplay to MinariExperienceReplay + fix W&B logging and SLURM usage (#3035)
Co-authored-by: jorge.ibinarriaga.robles.becas <jorge.ibinarriaga.robles.becas@bbva.com> Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 2c45bde commit 3cf1df0

File tree

6 files changed

+10
-9
lines changed

6 files changed

+10
-9
lines changed

sota-check/submitit-release-check.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ EXAMPLES:
1313
./submitit-release-check.sh --partition <PARTITION_NAME> --n_runs 5
1414
1515
EOF
16-
return 1
16+
exit 1
1717
}
1818

1919
# Check if the script is called with --help or without any arguments

sota-implementations/cql/cql_offline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def update(data, policy_eval_start, iteration):
190190
with timeit("log"):
191191
metrics_to_log.update(timeit.todict(prefix="time"))
192192
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
193+
193194
log_metrics(logger, metrics_to_log, i)
194195

195196
pbar.close()

sota-implementations/cql/discrete_cql_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ collector:
2222

2323
# Logger
2424
logger:
25-
backend: wandb
25+
backend: csv
2626
project_name: torchrl_example_cql
2727
group_name: null
2828
exp_name: cql_cartpole_gym

sota-implementations/cql/offline_config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# env and task
22
env:
3-
name: Hopper-v4
3+
name: Hopper-v5
44
task: ""
55
library: gym
66
n_samples_stats: 1000
@@ -18,11 +18,11 @@ logger:
1818
eval_steps: 1000
1919
mode: online
2020
eval_envs: 5
21-
video: False
21+
video: True
2222

2323
# replay buffer
2424
replay_buffer:
25-
dataset: hopper-medium-v2
25+
dataset: mujoco/hopper/expert-v0
2626
batch_size: 256
2727

2828
# optimization

sota-implementations/cql/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TensorDictPrioritizedReplayBuffer,
1919
TensorDictReplayBuffer,
2020
)
21-
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
21+
from torchrl.data.datasets.minari_data import MinariExperienceReplay
2222
from torchrl.data.replay_buffers import SamplerWithoutReplacement
2323
from torchrl.envs import (
2424
CatTensors,
@@ -181,13 +181,13 @@ def make_replay_buffer(
181181

182182

183183
def make_offline_replay_buffer(rb_cfg):
184-
data = D4RLExperienceReplay(
184+
data = MinariExperienceReplay(
185185
dataset_id=rb_cfg.dataset,
186186
split_trajs=False,
187187
batch_size=rb_cfg.batch_size,
188188
sampler=SamplerWithoutReplacement(drop_last=True),
189189
prefetch=4,
190-
direct_download=True,
190+
download=True,
191191
)
192192

193193
data.append_transform(DoubleToFloat())

torchrl/record/loggers/wandb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,6 @@ def log_str(self, name: str, value: str, step: int | None = None) -> None:
234234
table = wandb.Table(columns=["text"], data=[[value]])
235235

236236
if step is not None:
237-
self.experiment.log({name: table, "trainer/step": step})
237+
self.experiment.log({name: value}, step=step)
238238
else:
239239
self.experiment.log({name: table})

0 commit comments

Comments
 (0)