@@ -195,6 +195,12 @@ def get_devices():
 
 
 class LossModuleTestBase:
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        assert hasattr(
+            cls, "test_reset_parameters_recursive"
+        ), "Please add a test_reset_parameters_recursive test for this class"
+
     def _flatten_in_keys(self, in_keys):
         return [
             in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key)))
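Note: `__init_subclass__` runs when each subclass of `LossModuleTestBase` is defined, so a loss test class that omits `test_reset_parameters_recursive` fails at import/collection time instead of being silently skipped. A minimal, self-contained sketch of the same enforcement pattern; the names `Base` and `required_test` are illustrative only and not part of this test suite:

class Base:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Fail loudly at class-definition time if the hook is missing.
        assert hasattr(cls, "required_test"), f"{cls.__name__} must define required_test"


class Ok(Base):
    def required_test(self):
        return True


# class Missing(Base):  # defining this class would raise AssertionError immediately
#     pass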
@@ -252,6 +258,42 @@ def set_advantage_keys_through_loss_test(
             getattr(test_fn.value_estimator.tensor_keys, advantage_key) == new_key
         )
 
+    @classmethod
+    def reset_parameters_recursive_test(cls, loss_fn):
+        def get_params(loss_fn):
+            for key, item in loss_fn.__dict__.items():
+                if isinstance(item, nn.Module):
+                    module_name = key
+                    params_name = f"{module_name}_params"
+                    target_name = f"target_{module_name}_params"
+                    params = loss_fn._modules.get(params_name, None)
+                    target = loss_fn._modules.get(target_name, None)
+
+                    if params is not None:
+                        yield params_name, params._param_td
+
+                    else:
+                        for subparam_name, subparam in loss_fn.named_parameters():
+                            if module_name in subparam_name:
+                                yield subparam_name, subparam
+
+                    if target is not None:
+                        yield target_name, target
+
+        old_params = {}
+
+        for param_name, param in get_params(loss_fn):
+            with torch.no_grad():
+                # Change the parameter to ensure that reset will change it again
+                param += 1000
+            old_params[param_name] = param.clone()
+
+        loss_fn.reset_parameters_recursive()
+
+        for param_name, param in get_params(loss_fn):
+            old_param = old_params[param_name]
+            assert (param != old_param).any()
+
 
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error"))
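The helper above checks `reset_parameters_recursive` indirectly: it nudges every parameter (and target parameter) by a large constant, snapshots the perturbed values, triggers the reset, and asserts that each parameter changed again. A minimal sketch of the same perturb/reset/compare idea on a bare `torch.nn.Linear`, independent of TorchRL's loss modules and shown for illustration only:

import torch
from torch import nn

layer = nn.Linear(4, 2)
with torch.no_grad():
    for p in layer.parameters():
        p += 1000.0  # push parameters far away from any plausible init
old = {name: p.clone() for name, p in layer.named_parameters()}

layer.reset_parameters()  # nn.Linear re-draws weight/bias from its default init

for name, p in layer.named_parameters():
    assert (p != old[name]).any(), f"{name} was not reset"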
@@ -494,6 +536,11 @@ def _create_seq_mock_data_dqn(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor(action_spec_type="one_hot")
+        loss_fn = DQNLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize(
         "delay_value,double_dqn", ([False, False], [True, False], [True, True])
     )
@@ -1066,6 +1113,12 @@ def _create_mock_data_dqn(
         td.refine_names(None, "time")
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor(action_spec_type="one_hot")
+        mixer = self._create_mock_mixer()
+        loss_fn = QMixerLoss(actor, mixer)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_value", (False, True))
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@@ -1570,6 +1623,12 @@ def _create_seq_mock_data_ddpg(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = DDPGLoss(actor, value)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@@ -2210,6 +2269,16 @@ def _create_seq_mock_data_td3(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = TD3Loss(
+            actor,
+            value,
+            bounds=(-1, 1),
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize(
@@ -2916,6 +2985,16 @@ def _create_seq_mock_data_td3bc(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = TD3BCLoss(
+            actor,
+            value,
+            bounds=(-1, 1),
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize(
@@ -3720,6 +3799,20 @@ def _create_seq_mock_data_sac(
         )
         return td
 
+    def test_reset_parameters_recursive(self, version):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        if version == 1:
+            value = self._create_mock_value()
+        else:
+            value = None
+        loss_fn = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_value", (True, False))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -4591,6 +4684,17 @@ def _create_seq_mock_data_sac(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=actor.spec["action"].space.n,
+            action_space="one-hot",
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [2])
     @pytest.mark.parametrize("device", get_default_devices())
@@ -5227,6 +5331,15 @@ def _create_seq_mock_data_crossq(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@@ -5962,6 +6075,15 @@ def _create_seq_mock_data_redq(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = REDQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
@@ -6792,6 +6914,15 @@ def _create_seq_mock_data_cql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = CQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, True))
     @pytest.mark.parametrize("max_q_backup", [True, False])
@@ -7367,6 +7498,13 @@ def _create_seq_mock_data_dcql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor(
+            action_spec_type="one_hot",
+        )
+        loss_fn = DiscreteCQLLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_value", (False, True))
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@@ -7938,6 +8076,13 @@ def _create_seq_mock_data_ppo(
 
         return td
 
+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    def test_reset_parameters_recursive(self, loss_class):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = loss_class(actor, value)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("gradient_mode", (True, False))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@@ -9016,6 +9161,12 @@ def _create_seq_mock_data_a2c(
         td["scale"] = scale
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = A2CLoss(actor, value)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("gradient_mode", (True, False))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("device", get_default_devices())
@@ -9624,6 +9775,27 @@ def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
 class TestReinforce(LossModuleTestBase):
     seed = 0
 
+    def test_reset_parameters_recursive(self):
+        n_obs = 3
+        n_act = 5
+        value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
+        net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
+        module = TensorDictModule(
+            net, in_keys=["observation"], out_keys=["loc", "scale"]
+        )
+        actor_net = ProbabilisticActor(
+            module,
+            distribution_class=TanhNormal,
+            return_log_prob=True,
+            in_keys=["loc", "scale"],
+            spec=Unbounded(n_act),
+        )
+        loss_fn = ReinforceLoss(
+            actor_net,
+            critic_network=value_net,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("gradient_mode", [True, False])
     @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
     @pytest.mark.parametrize(
@@ -10323,6 +10495,11 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
         value_model(td)
         return value_model
 
+    def test_reset_parameters_recursive(self, device):
+        world_model = self._create_world_model_model(10, 5).to(device)
+        loss_fn = DreamerModelLoss(world_model)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("lambda_kl", [0, 1.0])
     @pytest.mark.parametrize("lambda_reco", [0, 1.0])
     @pytest.mark.parametrize("lambda_reward", [0, 1.0])
@@ -10604,6 +10781,11 @@ def _create_seq_mock_data_odt(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        loss_fn = OnlineDTLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_odt(self, device):
         torch.manual_seed(self.seed)
@@ -10831,6 +11013,11 @@ def _create_seq_mock_data_dt(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        loss_fn = DTLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     def test_dt_tensordict_keys(self):
         actor = self._create_mock_actor()
         loss_fn = DTLoss(actor)
@@ -11034,6 +11221,11 @@ def _create_seq_mock_data_gail(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        discriminator = self._create_mock_discriminator()
+        loss_fn = GAILLoss(discriminator)
+        self.reset_parameters_recursive_test(loss_fn)
+
     def test_gail_tensordict_keys(self):
         discriminator = self._create_mock_discriminator()
         loss_fn = GAILLoss(discriminator)
@@ -11406,6 +11598,17 @@ def _create_seq_mock_data_iql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        value = self._create_mock_value()
+        loss_fn = IQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@@ -12214,6 +12417,18 @@ def _create_seq_mock_data_discrete_iql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        value = self._create_mock_value()
+        loss_fn = DiscreteIQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            action_space="one-hot",
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@@ -12842,6 +13057,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
     )
     loss = MyLoss(actor_module)
 
+    LossModuleTestBase.reset_parameters_recursive_test(loss)
+
    if create_target_params:
        SoftUpdate(loss, eps=0.5)