Skip to content

Commit 611f3ee

Browse files
author
jorge.ibinarriaga.robles.becas
committed
[Feature] Replace D4RLExperienceReplay with MinariExperienceReplay in CQL offline
1 parent 773c366 commit 611f3ee

File tree

5 files changed

+13
-7
lines changed

5 files changed

+13
-7
lines changed

sota-implementations/cql/=0.9.0,

Whitespace-only changes.

sota-implementations/cql/cql_offline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ def update(data, policy_eval_start, iteration):
190190
with timeit("log"):
191191
metrics_to_log.update(timeit.todict(prefix="time"))
192192
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
193+
194+
if i % evaluation_interval == 0:
195+
print(
196+
f"Step {i}: loss={loss.item():.4f}, eval_reward={eval_reward:.4f}"
197+
)
193198
log_metrics(logger, metrics_to_log, i)
194199

195200
pbar.close()

sota-implementations/cql/discrete_cql_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ collector:
2222

2323
# Logger
2424
logger:
25-
backend: wandb
25+
backend: csv
2626
project_name: torchrl_example_cql
2727
group_name: null
2828
exp_name: cql_cartpole_gym

sota-implementations/cql/offline_config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# env and task
22
env:
3-
name: Hopper-v4
3+
name: Hopper-v5
44
task: ""
55
library: gym
66
n_samples_stats: 1000
@@ -9,7 +9,7 @@ env:
99

1010
# logger
1111
logger:
12-
backend: wandb
12+
backend: csv
1313
project_name: torchrl_example_cql
1414
group_name: null
1515
exp_name: cql_${replay_buffer.dataset}
@@ -18,11 +18,11 @@ logger:
1818
eval_steps: 1000
1919
mode: online
2020
eval_envs: 5
21-
video: False
21+
video: True
2222

2323
# replay buffer
2424
replay_buffer:
25-
dataset: hopper-medium-v2
25+
dataset: mujoco/hopper/expert-v0
2626
batch_size: 256
2727

2828
# optimization

sota-implementations/cql/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
TensorDictReplayBuffer,
2020
)
2121
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
22+
from torchrl.data.datasets.minari_data import MinariExperienceReplay
2223
from torchrl.data.replay_buffers import SamplerWithoutReplacement
2324
from torchrl.envs import (
2425
CatTensors,
@@ -181,13 +182,13 @@ def make_replay_buffer(
181182

182183

183184
def make_offline_replay_buffer(rb_cfg):
184-
data = D4RLExperienceReplay(
185+
data = MinariExperienceReplay(
185186
dataset_id=rb_cfg.dataset,
186187
split_trajs=False,
187188
batch_size=rb_cfg.batch_size,
188189
sampler=SamplerWithoutReplacement(drop_last=True),
189190
prefetch=4,
190-
direct_download=True,
191+
download=True,
191192
)
192193

193194
data.append_transform(DoubleToFloat())

0 commit comments

Comments
 (0)