Skip to content

Commit 001cf33

Browse files
Author: Vincent Moens
[Refactor] Refactor DDPG loss in standalone methods (#1603)
1 parent 5501d4a commit 001cf33

File tree

3 files changed: +32 / −31 lines

examples/ddpg/ddpg.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,14 @@ def main(cfg: "DictConfig"): # noqa: F821
120120
# Sample from replay buffer
121121
sampled_tensordict = replay_buffer.sample().clone()
122122

123-
# Compute loss
124-
loss_td = loss_module(sampled_tensordict)
125-
126-
actor_loss = loss_td["loss_actor"]
127-
q_loss = loss_td["loss_value"]
128-
129123
# Update critic
124+
q_loss, *_ = loss_module.loss_value(sampled_tensordict)
130125
optimizer_critic.zero_grad()
131126
q_loss.backward()
132127
optimizer_critic.step()
133128

134129
# Update actor
130+
actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
135131
optimizer_actor.zero_grad()
136132
actor_loss.backward()
137133
optimizer_actor.step()

test/test_cost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,7 @@ def test_ddpg_notensordict(self):
17651765
with pytest.warns(UserWarning, match="No target network updater has been"):
17661766
loss_val_td = loss(td)
17671767
loss_val = loss(**kwargs)
1768-
for i, key in enumerate(loss_val_td.keys()):
1768+
for i, key in enumerate(loss.out_keys):
17691769
torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
17701770
# test select
17711771
loss.select_out_keys("loss_actor", "target_value")

torchrl/objectives/ddpg.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -280,32 +280,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
280280
a tuple of 2 tensors containing the DDPG loss.
281281
282282
"""
283-
loss_value, td_error, pred_val, target_value = self._loss_value(tensordict)
284-
td_error = td_error.detach()
285-
if tensordict.device is not None:
286-
td_error = td_error.to(tensordict.device)
287-
tensordict.set(
288-
self.tensor_keys.priority,
289-
td_error,
290-
inplace=True,
291-
)
292-
loss_actor = self._loss_actor(tensordict)
283+
loss_value, metadata = self.loss_value(tensordict)
284+
loss_actor, metadata_actor = self.loss_actor(tensordict)
285+
metadata.update(metadata_actor)
293286
return TensorDict(
294-
source={
295-
"loss_actor": loss_actor.mean(),
296-
"loss_value": loss_value.mean(),
297-
"pred_value": pred_val.mean().detach(),
298-
"target_value": target_value.mean().detach(),
299-
"pred_value_max": pred_val.max().detach(),
300-
"target_value_max": target_value.max().detach(),
301-
},
287+
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
302288
batch_size=[],
303289
)
304290

305-
def _loss_actor(
291+
def loss_actor(
306292
self,
307293
tensordict: TensorDictBase,
308-
) -> torch.Tensor:
294+
) -> [torch.Tensor, dict]:
309295
td_copy = tensordict.select(
310296
*self.actor_in_keys, *self.value_exclusive_keys
311297
).detach()
@@ -317,12 +303,14 @@ def _loss_actor(
317303
td_copy,
318304
params=self._cached_detached_value_params,
319305
)
320-
return -td_copy.get(self.tensor_keys.state_action_value)
306+
loss_actor = -td_copy.get(self.tensor_keys.state_action_value)
307+
metadata = {}
308+
return loss_actor.mean(), metadata
321309

322-
def _loss_value(
310+
def loss_value(
323311
self,
324312
tensordict: TensorDictBase,
325-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
313+
) -> Tuple[torch.Tensor, dict]:
326314
# value loss
327315
td_copy = tensordict.select(*self.value_network.in_keys).detach()
328316
self.value_network(
@@ -340,7 +328,24 @@ def _loss_value(
340328
pred_val, target_value, loss_function=self.loss_function
341329
)
342330

343-
return loss_value, (pred_val - target_value).pow(2), pred_val, target_value
331+
td_error = (pred_val - target_value).pow(2)
332+
td_error = td_error.detach()
333+
if tensordict.device is not None:
334+
td_error = td_error.to(tensordict.device)
335+
tensordict.set(
336+
self.tensor_keys.priority,
337+
td_error,
338+
inplace=True,
339+
)
340+
with torch.no_grad():
341+
metadata = {
342+
"td_error": td_error.mean(),
343+
"pred_value": pred_val.mean(),
344+
"target_value": target_value.mean(),
345+
"target_value_max": target_value.max(),
346+
"pred_value_max": pred_val.max(),
347+
}
348+
return loss_value.mean(), metadata
344349

345350
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
346351
if value_type is None:

0 commit comments

Comments (0)