Skip to content

Commit ffa99b2

Browse files
author
Vincent Moens
committed
[BugFix] Fix compile weakrefs errors
ghstack-source-id: 3cb4c62 Pull Request resolved: #2742
1 parent bb9440b commit ffa99b2

20 files changed: +390 −56 lines changed

.github/workflows/benchmarks.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ jobs:
5050
5151
cd benchmarks/
5252
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
53+
export COMPOSITE_LP_AGGREGATE=0
5354
export TD_GET_DEFAULTS_TO_NONE=1
5455
python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
5556
- name: Store benchmark results
@@ -131,6 +132,7 @@ jobs:
131132
132133
cd benchmarks/
133134
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
135+
export COMPOSITE_LP_AGGREGATE=0
134136
export TD_GET_DEFAULTS_TO_NONE=1
135137
python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
136138
- name: Store benchmark results

.github/workflows/benchmarks_pr.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jobs:
4848
4949
cd benchmarks/
5050
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
51+
export COMPOSITE_LP_AGGREGATE=0
5152
export TD_GET_DEFAULTS_TO_NONE=1
5253
RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
5354
git checkout ${{ github.event.pull_request.base.sha }}
@@ -141,6 +142,7 @@ jobs:
141142
142143
cd benchmarks/
143144
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
145+
export COMPOSITE_LP_AGGREGATE=0
144146
export TD_GET_DEFAULTS_TO_NONE=1
145147
RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
146148
git checkout ${{ github.event.pull_request.base.sha }}

benchmarks/test_objectives_benchmarks.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from tensordict import TensorDict
1212
from tensordict.nn import (
13+
composite_lp_aggregate,
1314
InteractionType,
1415
NormalParamExtractor,
1516
ProbabilisticTensorDictModule as ProbMod,
@@ -785,11 +786,15 @@ def test_a2c_speed(
785786
device=device,
786787
)
787788
batch = [batch, T]
789+
if composite_lp_aggregate():
790+
raise RuntimeError(
791+
"Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
792+
)
788793
td = TensorDict(
789794
{
790795
"obs": torch.randn(*batch, n_obs),
791796
"action": torch.randn(*batch, n_act),
792-
"sample_log_prob": torch.randn(*batch),
797+
"action_log_prob": torch.randn(*batch),
793798
"done": torch.zeros(*batch, 1, dtype=torch.bool),
794799
"next": {
795800
"obs": torch.randn(*batch, n_obs),
@@ -884,11 +889,15 @@ def test_ppo_speed(
884889
device=device,
885890
)
886891
batch = [batch, T]
892+
if composite_lp_aggregate():
893+
raise RuntimeError(
894+
"Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
895+
)
887896
td = TensorDict(
888897
{
889898
"obs": torch.randn(*batch, n_obs),
890899
"action": torch.randn(*batch, n_act),
891-
"sample_log_prob": torch.randn(*batch),
900+
"action_log_prob": torch.randn(*batch),
892901
"done": torch.zeros(*batch, 1, dtype=torch.bool),
893902
"next": {
894903
"obs": torch.randn(*batch, n_obs),
@@ -983,11 +992,15 @@ def test_reinforce_speed(
983992
device=device,
984993
)
985994
batch = [batch, T]
995+
if composite_lp_aggregate():
996+
raise RuntimeError(
997+
"Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
998+
)
986999
td = TensorDict(
9871000
{
9881001
"obs": torch.randn(*batch, n_obs),
9891002
"action": torch.randn(*batch, n_act),
990-
"sample_log_prob": torch.randn(*batch),
1003+
"action_log_prob": torch.randn(*batch),
9911004
"done": torch.zeros(*batch, 1, dtype=torch.bool),
9921005
"next": {
9931006
"obs": torch.randn(*batch, n_obs),
@@ -1089,11 +1102,15 @@ def test_iql_speed(
10891102
device=device,
10901103
)
10911104
batch = [batch, T]
1105+
if composite_lp_aggregate():
1106+
raise RuntimeError(
1107+
"Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
1108+
)
10921109
td = TensorDict(
10931110
{
10941111
"obs": torch.randn(*batch, n_obs),
10951112
"action": torch.randn(*batch, n_act),
1096-
"sample_log_prob": torch.randn(*batch),
1113+
"action_log_prob": torch.randn(*batch),
10971114
"done": torch.zeros(*batch, 1, dtype=torch.bool),
10981115
"next": {
10991116
"obs": torch.randn(*batch, n_obs),

torchrl/objectives/a2c.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,10 @@ def __post_init__(self):
262262

263263
actor_network: TensorDictModule
264264
critic_network: TensorDictModule
265-
actor_network_params: TensorDictParams
266-
critic_network_params: TensorDictParams
267-
target_actor_network_params: TensorDictParams
268-
target_critic_network_params: TensorDictParams
265+
actor_network_params: TensorDictParams | None
266+
critic_network_params: TensorDictParams | None
267+
target_actor_network_params: TensorDictParams | None
268+
target_critic_network_params: TensorDictParams | None
269269

270270
def __init__(
271271
self,
@@ -521,6 +521,13 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]:
521521
loss_value,
522522
self.loss_critic_type,
523523
)
524+
self._clear_weakrefs(
525+
tensordict,
526+
"actor_network_params",
527+
"critic_network_params",
528+
"target_actor_network_params",
529+
"target_critic_network_params",
530+
)
524531
if self.critic_coef is not None:
525532
return self.critic_coef * loss_value, clip_fraction
526533
return loss_value, clip_fraction
@@ -559,7 +566,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
559566
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
560567
if name.startswith("loss_")
561568
else value,
562-
batch_size=[],
569+
)
570+
self._clear_weakrefs(
571+
tensordict,
572+
td_out,
573+
"actor_network_params",
574+
"critic_network_params",
575+
"target_actor_network_params",
576+
"target_critic_network_params",
563577
)
564578
return td_out
565579

torchrl/objectives/common.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
from torchrl.objectives.value import ValueEstimatorBase
2828

2929
try:
30-
from torch.compiler import is_dynamo_compiling
30+
from torch.compiler import is_compiling
3131
except ImportError:
32-
from torch._dynamo import is_compiling as is_dynamo_compiling
32+
from torch._dynamo import is_compiling
3333

3434

3535
def _updater_check_forward_prehook(module, *args, **kwargs):
3636
if (
3737
not all(module._has_update_associated.values())
3838
and RL_WARNINGS
39-
and not is_dynamo_compiling()
39+
and not is_compiling()
4040
):
4141
warnings.warn(
4242
module.TARGET_NET_WARNING,
@@ -415,6 +415,7 @@ def _compare_and_expand(param):
415415
params.set(key, parameter.data)
416416

417417
setattr(self, param_name, params)
418+
assert getattr(self, param_name) is params, getattr(self, param_name)
418419

419420
# Set the module in the __dict__ directly to avoid listing its params
420421
# A deepcopy with meta device could be used but that assumes that the model is copyable!
@@ -433,6 +434,16 @@ def _compare_and_expand(param):
433434
setattr(self, name_params_target + "_params", target_params)
434435
self._has_update_associated[module_name] = not create_target_params
435436

437+
def _clear_weakrefs(self, *tds):
438+
if is_compiling():
439+
# Waiting for weakrefs reconstruct to be supported by compile
440+
for td in tds:
441+
if isinstance(td, str):
442+
td = getattr(self, td, None)
443+
if not is_tensor_collection(td):
444+
continue
445+
td.clear_refs_for_compile_()
446+
436447
def __getattr__(self, item):
437448
if item.startswith("target_") and item.endswith("_params"):
438449
params = self._modules.get(item, None)
@@ -443,7 +454,7 @@ def __getattr__(self, item):
443454
elif (
444455
not self._has_update_associated[item[7:-7]]
445456
and RL_WARNINGS
446-
and not is_dynamo_compiling()
457+
and not is_compiling()
447458
):
448459
# no updater associated
449460
warnings.warn(

torchrl/objectives/cql.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
542542
}
543543
if self.with_lagrange:
544544
out["loss_alpha_prime"] = alpha_prime_loss.mean()
545-
return TensorDict(out, [])
545+
td_loss = TensorDict(out)
546+
self._clear_weakrefs(
547+
tensordict,
548+
td_loss,
549+
"actor_network_params",
550+
"qvalue_network_params",
551+
"target_actor_network_params",
552+
"target_qvalue_network_params",
553+
)
554+
return td_loss
546555

547556
@property
548557
@_cache_values
@@ -563,6 +572,13 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor:
563572
bc_actor_loss = self._alpha * log_prob - bc_log_prob
564573
bc_actor_loss = _reduce(bc_actor_loss, reduction=self.reduction)
565574
metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
575+
self._clear_weakrefs(
576+
tensordict,
577+
"actor_network_params",
578+
"qvalue_network_params",
579+
"target_actor_network_params",
580+
"target_qvalue_network_params",
581+
)
566582
return bc_actor_loss, metadata
567583

568584
def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
@@ -596,7 +612,13 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
596612
metadata[self.tensor_keys.log_prob] = log_prob.detach()
597613
actor_loss = self._alpha * log_prob - min_q_logprob
598614
actor_loss = _reduce(actor_loss, reduction=self.reduction)
599-
615+
self._clear_weakrefs(
616+
tensordict,
617+
"actor_network_params",
618+
"qvalue_network_params",
619+
"target_actor_network_params",
620+
"target_qvalue_network_params",
621+
)
600622
return actor_loss, metadata
601623

602624
def _get_policy_actions(self, data, actor_params, num_actions=10):
@@ -712,6 +734,13 @@ def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
712734
loss_qval = _reduce(loss_qval, reduction=self.reduction)
713735
td_error = (q_pred - target_value).pow(2)
714736
metadata = {"td_error": td_error.detach()}
737+
self._clear_weakrefs(
738+
tensordict,
739+
"actor_network_params",
740+
"qvalue_network_params",
741+
"target_actor_network_params",
742+
"target_qvalue_network_params",
743+
)
715744
return loss_qval, metadata
716745

717746
def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
@@ -855,6 +884,13 @@ def filter_and_repeat(name, x):
855884
cql_q_loss = (cql_q1_loss + cql_q2_loss).mean(-1)
856885
cql_q_loss = _reduce(cql_q_loss, reduction=self.reduction)
857886

887+
self._clear_weakrefs(
888+
tensordict,
889+
"actor_network_params",
890+
"qvalue_network_params",
891+
"target_actor_network_params",
892+
"target_qvalue_network_params",
893+
)
858894
return cql_q_loss, {}
859895

860896
def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
@@ -878,6 +914,13 @@ def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
878914

879915
alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
880916
alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
917+
self._clear_weakrefs(
918+
tensordict,
919+
"actor_network_params",
920+
"qvalue_network_params",
921+
"target_actor_network_params",
922+
"target_qvalue_network_params",
923+
)
881924
return alpha_prime_loss, {}
882925

883926
def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
@@ -889,6 +932,13 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
889932
# placeholder
890933
alpha_loss = torch.zeros_like(log_pi)
891934
alpha_loss = _reduce(alpha_loss, reduction=self.reduction)
935+
self._clear_weakrefs(
936+
tensordict,
937+
"actor_network_params",
938+
"qvalue_network_params",
939+
"target_actor_network_params",
940+
"target_qvalue_network_params",
941+
)
892942
return alpha_loss, {}
893943

894944
@property

torchrl/objectives/crossq.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
542542
**value_metadata,
543543
}
544544
td_out = TensorDict(out)
545+
self._clear_weakrefs(
546+
tensordict,
547+
td_out,
548+
"actor_network_params",
549+
"qvalue_network_params",
550+
"target_actor_network_params",
551+
"target_qvalue_network_params",
552+
)
545553
return td_out
546554

547555
@property

torchrl/objectives/ddpg.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
303303
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
304304
batch_size=[],
305305
)
306+
self._clear_weakrefs(
307+
tensordict,
308+
td_out,
309+
"value_network_params",
310+
"target_value_network_params",
311+
"target_actor_network_params",
312+
"actor_network_params",
313+
)
306314
return td_out
307315

308316
def loss_actor(
@@ -319,6 +327,14 @@ def loss_actor(
319327
loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
320328
metadata = {}
321329
loss_actor = _reduce(loss_actor, self.reduction)
330+
self._clear_weakrefs(
331+
tensordict,
332+
loss_actor,
333+
"value_network_params",
334+
"target_value_network_params",
335+
"target_actor_network_params",
336+
"actor_network_params",
337+
)
322338
return loss_actor, metadata
323339

324340
def loss_value(
@@ -358,6 +374,13 @@ def loss_value(
358374
"pred_value_max": pred_val.max(),
359375
}
360376
loss_value = _reduce(loss_value, self.reduction)
377+
self._clear_weakrefs(
378+
tensordict,
379+
"value_network_params",
380+
"target_value_network_params",
381+
"target_actor_network_params",
382+
"actor_network_params",
383+
)
361384
return loss_value, metadata
362385

363386
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):

torchrl/objectives/decision_transformer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
241241
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
242242
if name.startswith("loss_")
243243
else value,
244-
batch_size=[],
244+
)
245+
self._clear_weakrefs(
246+
tensordict,
247+
td_out,
248+
"actor_network_params",
249+
"target_actor_network_params",
245250
)
246251
return td_out
247252

@@ -360,4 +365,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
360365
)
361366
loss = _reduce(loss, reduction=self.reduction)
362367
td_out = TensorDict(loss=loss)
368+
self._clear_weakrefs(
369+
tensordict,
370+
td_out,
371+
"actor_network_params",
372+
"target_actor_network_params",
373+
)
363374
return td_out

0 commit comments

Comments (0)