Skip to content

Commit b7d148b

Browse files
author
Vincent Moens
authored
[Refactor] Put all buffers on CPU in examples (#1645)
1 parent a67b9fb commit b7d148b

File tree

10 files changed: 45 additions (+45) and 15 deletions (-15).

examples/cql/cql_online.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 51  51           batch_size=cfg.optim.batch_size,
 52  52           prb=cfg.replay_buffer.prb,
 53  53           buffer_size=cfg.replay_buffer.size,
 54     -         device=device,
     54 +         device="cpu",
 55  55       )
 56  56
 57  57       # Make Model
@@ -104,7 +104,13 @@ def main(cfg: "DictConfig"): # noqa: F821
104 104               (actor_losses, q_losses, alpha_losses, alpha_primes) = ([], [], [], [])
105 105               for _ in range(num_updates):
106 106                   # sample from replay buffer
107     -                 sampled_tensordict = replay_buffer.sample().clone()
    107 +                 sampled_tensordict = replay_buffer.sample()
    108 +                 if sampled_tensordict.device != device:
    109 +                     sampled_tensordict = sampled_tensordict.to(
    110 +                         device, non_blocking=True
    111 +                     )
    112 +                 else:
    113 +                     sampled_tensordict = sampled_tensordict.clone()
108 114
109 115                   loss_td = loss_module(sampled_tensordict)
110 116

examples/ddpg/ddpg.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 70  70           prb=cfg.replay_buffer.prb,
 71  71           buffer_size=cfg.replay_buffer.size,
 72  72           buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
 73     -         device=device,
     73 +         device="cpu",
 74  74       )
 75  75
 76  76       # Create optimizers
@@ -118,7 +118,13 @@ def main(cfg: "DictConfig"): # noqa: F821
118 118               ) = ([], [])
119 119               for _ in range(num_updates):
120 120                   # Sample from replay buffer
121     -                 sampled_tensordict = replay_buffer.sample().clone()
    121 +                 sampled_tensordict = replay_buffer.sample()
    122 +                 if sampled_tensordict.device != device:
    123 +                     sampled_tensordict = sampled_tensordict.to(
    124 +                         device, non_blocking=True
    125 +                     )
    126 +                 else:
    127 +                     sampled_tensordict = sampled_tensordict.clone()
122 128
123 129                   # Update critic
124 130                   q_loss, *_ = loss_module.loss_value(sampled_tensordict)

examples/discrete_sac/discrete_sac.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,7 @@ def env_factory(num_workers):
201 201           prb=cfg.prb,
202 202           buffer_size=cfg.buffer_size,
203 203           batch_size=cfg.batch_size,
204     -         device=device,
    204 +         device="cpu",
205 205       )
206 206
207 207       # Optimizers
@@ -255,7 +255,13 @@ def env_factory(num_workers):
255 255               ) = ([], [], [], [], [], [])
256 256               for _ in range(cfg.frames_per_batch * int(cfg.utd_ratio)):
257 257                   # sample from replay buffer
258     -                 sampled_tensordict = replay_buffer.sample().clone()
    258 +                 sampled_tensordict = replay_buffer.sample()
    259 +                 if sampled_tensordict.device != device:
    260 +                     sampled_tensordict = sampled_tensordict.to(
    261 +                         device, non_blocking=True
    262 +                     )
    263 +                 else:
    264 +                     sampled_tensordict = sampled_tensordict.clone()
259 265
260 266                   loss_td = loss_module(sampled_tensordict)
261 267

examples/dqn/dqn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ def main(cfg: "DictConfig"): # noqa: F821
115 115           cfg=cfg,
116 116       )
117 117
118     -     replay_buffer = make_replay_buffer(device, cfg)
    118 +     replay_buffer = make_replay_buffer("cpu", cfg)
119 119
120 120       recorder = transformed_env_constructor(
121 121           cfg,

examples/dreamer/dreamer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821
186 186       )
187 187       print("collector:", collector)
188 188
189     -     replay_buffer = make_replay_buffer(device, cfg)
    189 +     replay_buffer = make_replay_buffer("cpu", cfg)
190 190
191 191       record = Recorder(
192 192           record_frames=cfg.record_frames,

examples/iql/iql_online.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ def env_factory(num_workers):
218 218
219 219       # Make Replay Buffer
220 220       replay_buffer = make_replay_buffer(
221     -         buffer_size=cfg.buffer_size, device=device, batch_size=cfg.batch_size
    221 +         buffer_size=cfg.buffer_size, device="cpu", batch_size=cfg.batch_size
222 222       )
223 223
224 224       # Optimizers

examples/redq/redq.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ def main(cfg: "DictConfig"): # noqa: F821
161 161           # ],
162 162       )
163 163
164     -     replay_buffer = make_replay_buffer(device, cfg)
    164 +     replay_buffer = make_replay_buffer("cpu", cfg)
165 165
166 166       recorder = transformed_env_constructor(
167 167           cfg,

examples/sac/sac.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 70  70           prb=cfg.replay_buffer.prb,
 71  71           buffer_size=cfg.replay_buffer.size,
 72  72           buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
 73     -         device=device,
     73 +         device="cpu",
 74  74       )
 75  75
 76  76       # Create optimizers
@@ -122,7 +122,13 @@ def main(cfg: "DictConfig"): # noqa: F821
122 122               )
123 123               for i in range(num_updates):
124 124                   # Sample from replay buffer
125     -                 sampled_tensordict = replay_buffer.sample().clone()
    125 +                 sampled_tensordict = replay_buffer.sample()
    126 +                 if sampled_tensordict.device != device:
    127 +                     sampled_tensordict = sampled_tensordict.to(
    128 +                         device, non_blocking=True
    129 +                     )
    130 +                 else:
    131 +                     sampled_tensordict = sampled_tensordict.clone()
126 132
127 133                   # Compute loss
128 134                   loss_td = loss_module(sampled_tensordict)

examples/td3/td3.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def main(cfg: "DictConfig"): # noqa: F821
 70  70           prb=cfg.replay_buffer.prb,
 71  71           buffer_size=cfg.replay_buffer.size,
 72  72           buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
 73     -         device=device,
     73 +         device="cpu",
 74  74       )
 75  75
 76  76       # Create optimizers
@@ -124,7 +124,13 @@ def main(cfg: "DictConfig"): # noqa: F821
124 124                   update_actor = update_counter % delayed_updates == 0
125 125
126 126                   # Sample from replay buffer
127     -                 sampled_tensordict = replay_buffer.sample().clone()
    127 +                 sampled_tensordict = replay_buffer.sample()
    128 +                 if sampled_tensordict.device != device:
    129 +                     sampled_tensordict = sampled_tensordict.to(
    130 +                         device, non_blocking=True
    131 +                     )
    132 +                 else:
    133 +                     sampled_tensordict = sampled_tensordict.clone()
128 134
129 135                   # Compute loss
130 136                   q_loss, *_ = loss_module.value_loss(sampled_tensordict)

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -740,7 +740,7 @@ def _collate_contiguous(x):
740 740
741 741
742 742   def _collate_as_tensor(x):
743     -     return x.contiguous()
    743 +     return x.as_tensor()
744 744
745 745
746 746   def _get_default_collate(storage, _is_tensordict=False):

0 commit comments

Comments
 (0)