Skip to content

Commit 218d5bf

Browse files
authored
[Feature] Add LossModule.reset_parameters_recursive (#2546)
1 parent 35a7813 commit 218d5bf

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed

test/test_cost.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,12 @@ def get_devices():
195195

196196

197197
class LossModuleTestBase:
198+
def __init_subclass__(cls, **kwargs):
199+
super().__init_subclass__(**kwargs)
200+
assert hasattr(
201+
cls, "test_reset_parameters_recursive"
202+
), "Please add a test_reset_parameters_recursive test for this class"
203+
198204
def _flatten_in_keys(self, in_keys):
199205
return [
200206
in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key)))
@@ -252,6 +258,42 @@ def set_advantage_keys_through_loss_test(
252258
getattr(test_fn.value_estimator.tensor_keys, advantage_key) == new_key
253259
)
254260

261+
@classmethod
262+
def reset_parameters_recursive_test(cls, loss_fn):
263+
def get_params(loss_fn):
264+
for key, item in loss_fn.__dict__.items():
265+
if isinstance(item, nn.Module):
266+
module_name = key
267+
params_name = f"{module_name}_params"
268+
target_name = f"target_{module_name}_params"
269+
params = loss_fn._modules.get(params_name, None)
270+
target = loss_fn._modules.get(target_name, None)
271+
272+
if params is not None:
273+
yield params_name, params._param_td
274+
275+
else:
276+
for subparam_name, subparam in loss_fn.named_parameters():
277+
if module_name in subparam_name:
278+
yield subparam_name, subparam
279+
280+
if target is not None:
281+
yield target_name, target
282+
283+
old_params = {}
284+
285+
for param_name, param in get_params(loss_fn):
286+
with torch.no_grad():
287+
# Change the parameter to ensure that reset will change it again
288+
param += 1000
289+
old_params[param_name] = param.clone()
290+
291+
loss_fn.reset_parameters_recursive()
292+
293+
for param_name, param in get_params(loss_fn):
294+
old_param = old_params[param_name]
295+
assert (param != old_param).any()
296+
255297

256298
@pytest.mark.parametrize("device", get_default_devices())
257299
@pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error"))
@@ -494,6 +536,11 @@ def _create_seq_mock_data_dqn(
494536
)
495537
return td
496538

539+
def test_reset_parameters_recursive(self):
540+
actor = self._create_mock_actor(action_spec_type="one_hot")
541+
loss_fn = DQNLoss(actor)
542+
self.reset_parameters_recursive_test(loss_fn)
543+
497544
@pytest.mark.parametrize(
498545
"delay_value,double_dqn", ([False, False], [True, False], [True, True])
499546
)
@@ -1066,6 +1113,12 @@ def _create_mock_data_dqn(
10661113
td.refine_names(None, "time")
10671114
return td
10681115

1116+
def test_reset_parameters_recursive(self):
1117+
actor = self._create_mock_actor(action_spec_type="one_hot")
1118+
mixer = self._create_mock_mixer()
1119+
loss_fn = QMixerLoss(actor, mixer)
1120+
self.reset_parameters_recursive_test(loss_fn)
1121+
10691122
@pytest.mark.parametrize("delay_value", (False, True))
10701123
@pytest.mark.parametrize("device", get_default_devices())
10711124
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@@ -1570,6 +1623,12 @@ def _create_seq_mock_data_ddpg(
15701623
)
15711624
return td
15721625

1626+
def test_reset_parameters_recursive(self):
1627+
actor = self._create_mock_actor()
1628+
value = self._create_mock_value()
1629+
loss_fn = DDPGLoss(actor, value)
1630+
self.reset_parameters_recursive_test(loss_fn)
1631+
15731632
@pytest.mark.parametrize("device", get_default_devices())
15741633
@pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
15751634
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@@ -2210,6 +2269,16 @@ def _create_seq_mock_data_td3(
22102269
)
22112270
return td
22122271

2272+
def test_reset_parameters_recursive(self):
2273+
actor = self._create_mock_actor()
2274+
value = self._create_mock_value()
2275+
loss_fn = TD3Loss(
2276+
actor,
2277+
value,
2278+
bounds=(-1, 1),
2279+
)
2280+
self.reset_parameters_recursive_test(loss_fn)
2281+
22132282
@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
22142283
@pytest.mark.parametrize("device", get_default_devices())
22152284
@pytest.mark.parametrize(
@@ -2916,6 +2985,16 @@ def _create_seq_mock_data_td3bc(
29162985
)
29172986
return td
29182987

2988+
def test_reset_parameters_recursive(self):
2989+
actor = self._create_mock_actor()
2990+
value = self._create_mock_value()
2991+
loss_fn = TD3BCLoss(
2992+
actor,
2993+
value,
2994+
bounds=(-1, 1),
2995+
)
2996+
self.reset_parameters_recursive_test(loss_fn)
2997+
29192998
@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
29202999
@pytest.mark.parametrize("device", get_default_devices())
29213000
@pytest.mark.parametrize(
@@ -3720,6 +3799,20 @@ def _create_seq_mock_data_sac(
37203799
)
37213800
return td
37223801

3802+
def test_reset_parameters_recursive(self, version):
3803+
actor = self._create_mock_actor()
3804+
qvalue = self._create_mock_qvalue()
3805+
if version == 1:
3806+
value = self._create_mock_value()
3807+
else:
3808+
value = None
3809+
loss_fn = SACLoss(
3810+
actor_network=actor,
3811+
qvalue_network=qvalue,
3812+
value_network=value,
3813+
)
3814+
self.reset_parameters_recursive_test(loss_fn)
3815+
37233816
@pytest.mark.parametrize("delay_value", (True, False))
37243817
@pytest.mark.parametrize("delay_actor", (True, False))
37253818
@pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -4591,6 +4684,17 @@ def _create_seq_mock_data_sac(
45914684
)
45924685
return td
45934686

4687+
def test_reset_parameters_recursive(self):
4688+
actor = self._create_mock_actor()
4689+
qvalue = self._create_mock_qvalue()
4690+
loss_fn = DiscreteSACLoss(
4691+
actor_network=actor,
4692+
qvalue_network=qvalue,
4693+
num_actions=actor.spec["action"].space.n,
4694+
action_space="one-hot",
4695+
)
4696+
self.reset_parameters_recursive_test(loss_fn)
4697+
45944698
@pytest.mark.parametrize("delay_qvalue", (True, False))
45954699
@pytest.mark.parametrize("num_qvalue", [2])
45964700
@pytest.mark.parametrize("device", get_default_devices())
@@ -5227,6 +5331,15 @@ def _create_seq_mock_data_crossq(
52275331
)
52285332
return td
52295333

5334+
def test_reset_parameters_recursive(self):
5335+
actor = self._create_mock_actor()
5336+
qvalue = self._create_mock_qvalue()
5337+
loss_fn = CrossQLoss(
5338+
actor_network=actor,
5339+
qvalue_network=qvalue,
5340+
)
5341+
self.reset_parameters_recursive_test(loss_fn)
5342+
52305343
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
52315344
@pytest.mark.parametrize("device", get_default_devices())
52325345
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@@ -5962,6 +6075,15 @@ def _create_seq_mock_data_redq(
59626075
)
59636076
return td
59646077

6078+
def test_reset_parameters_recursive(self):
6079+
actor = self._create_mock_actor()
6080+
qvalue = self._create_mock_qvalue()
6081+
loss_fn = REDQLoss(
6082+
actor_network=actor,
6083+
qvalue_network=qvalue,
6084+
)
6085+
self.reset_parameters_recursive_test(loss_fn)
6086+
59656087
@pytest.mark.parametrize("delay_qvalue", (True, False))
59666088
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
59676089
@pytest.mark.parametrize("device", get_default_devices())
@@ -6792,6 +6914,15 @@ def _create_seq_mock_data_cql(
67926914
)
67936915
return td
67946916

6917+
def test_reset_parameters_recursive(self):
6918+
actor = self._create_mock_actor()
6919+
qvalue = self._create_mock_qvalue()
6920+
loss_fn = CQLLoss(
6921+
actor_network=actor,
6922+
qvalue_network=qvalue,
6923+
)
6924+
self.reset_parameters_recursive_test(loss_fn)
6925+
67956926
@pytest.mark.parametrize("delay_actor", (True, False))
67966927
@pytest.mark.parametrize("delay_qvalue", (True, True))
67976928
@pytest.mark.parametrize("max_q_backup", [True, False])
@@ -7367,6 +7498,13 @@ def _create_seq_mock_data_dcql(
73677498
)
73687499
return td
73697500

7501+
def test_reset_parameters_recursive(self):
7502+
actor = self._create_mock_actor(
7503+
action_spec_type="one_hot",
7504+
)
7505+
loss_fn = DiscreteCQLLoss(actor)
7506+
self.reset_parameters_recursive_test(loss_fn)
7507+
73707508
@pytest.mark.parametrize("delay_value", (False, True))
73717509
@pytest.mark.parametrize("device", get_default_devices())
73727510
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@@ -7938,6 +8076,13 @@ def _create_seq_mock_data_ppo(
79388076

79398077
return td
79408078

8079+
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
8080+
def test_reset_parameters_recursive(self, loss_class):
8081+
actor = self._create_mock_actor()
8082+
value = self._create_mock_value()
8083+
loss_fn = loss_class(actor, value)
8084+
self.reset_parameters_recursive_test(loss_fn)
8085+
79418086
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
79428087
@pytest.mark.parametrize("gradient_mode", (True, False))
79438088
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@@ -9016,6 +9161,12 @@ def _create_seq_mock_data_a2c(
90169161
td["scale"] = scale
90179162
return td
90189163

9164+
def test_reset_parameters_recursive(self):
9165+
actor = self._create_mock_actor()
9166+
value = self._create_mock_value()
9167+
loss_fn = A2CLoss(actor, value)
9168+
self.reset_parameters_recursive_test(loss_fn)
9169+
90199170
@pytest.mark.parametrize("gradient_mode", (True, False))
90209171
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
90219172
@pytest.mark.parametrize("device", get_default_devices())
@@ -9624,6 +9775,27 @@ def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
96249775
class TestReinforce(LossModuleTestBase):
96259776
seed = 0
96269777

9778+
def test_reset_parameters_recursive(self):
9779+
n_obs = 3
9780+
n_act = 5
9781+
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
9782+
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
9783+
module = TensorDictModule(
9784+
net, in_keys=["observation"], out_keys=["loc", "scale"]
9785+
)
9786+
actor_net = ProbabilisticActor(
9787+
module,
9788+
distribution_class=TanhNormal,
9789+
return_log_prob=True,
9790+
in_keys=["loc", "scale"],
9791+
spec=Unbounded(n_act),
9792+
)
9793+
loss_fn = ReinforceLoss(
9794+
actor_net,
9795+
critic_network=value_net,
9796+
)
9797+
self.reset_parameters_recursive_test(loss_fn)
9798+
96279799
@pytest.mark.parametrize("gradient_mode", [True, False])
96289800
@pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
96299801
@pytest.mark.parametrize(
@@ -10323,6 +10495,11 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
1032310495
value_model(td)
1032410496
return value_model
1032510497

10498+
def test_reset_parameters_recursive(self, device):
10499+
world_model = self._create_world_model_model(10, 5).to(device)
10500+
loss_fn = DreamerModelLoss(world_model)
10501+
self.reset_parameters_recursive_test(loss_fn)
10502+
1032610503
@pytest.mark.parametrize("lambda_kl", [0, 1.0])
1032710504
@pytest.mark.parametrize("lambda_reco", [0, 1.0])
1032810505
@pytest.mark.parametrize("lambda_reward", [0, 1.0])
@@ -10604,6 +10781,11 @@ def _create_seq_mock_data_odt(
1060410781
)
1060510782
return td
1060610783

10784+
def test_reset_parameters_recursive(self):
10785+
actor = self._create_mock_actor()
10786+
loss_fn = OnlineDTLoss(actor)
10787+
self.reset_parameters_recursive_test(loss_fn)
10788+
1060710789
@pytest.mark.parametrize("device", get_available_devices())
1060810790
def test_odt(self, device):
1060910791
torch.manual_seed(self.seed)
@@ -10831,6 +11013,11 @@ def _create_seq_mock_data_dt(
1083111013
)
1083211014
return td
1083311015

11016+
def test_reset_parameters_recursive(self):
11017+
actor = self._create_mock_actor()
11018+
loss_fn = DTLoss(actor)
11019+
self.reset_parameters_recursive_test(loss_fn)
11020+
1083411021
def test_dt_tensordict_keys(self):
1083511022
actor = self._create_mock_actor()
1083611023
loss_fn = DTLoss(actor)
@@ -11034,6 +11221,11 @@ def _create_seq_mock_data_gail(
1103411221
)
1103511222
return td
1103611223

11224+
def test_reset_parameters_recursive(self):
11225+
discriminator = self._create_mock_discriminator()
11226+
loss_fn = GAILLoss(discriminator)
11227+
self.reset_parameters_recursive_test(loss_fn)
11228+
1103711229
def test_gail_tensordict_keys(self):
1103811230
discriminator = self._create_mock_discriminator()
1103911231
loss_fn = GAILLoss(discriminator)
@@ -11406,6 +11598,17 @@ def _create_seq_mock_data_iql(
1140611598
)
1140711599
return td
1140811600

11601+
def test_reset_parameters_recursive(self):
11602+
actor = self._create_mock_actor()
11603+
qvalue = self._create_mock_qvalue()
11604+
value = self._create_mock_value()
11605+
loss_fn = IQLLoss(
11606+
actor_network=actor,
11607+
qvalue_network=qvalue,
11608+
value_network=value,
11609+
)
11610+
self.reset_parameters_recursive_test(loss_fn)
11611+
1140911612
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
1141011613
@pytest.mark.parametrize("device", get_default_devices())
1141111614
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@@ -12214,6 +12417,18 @@ def _create_seq_mock_data_discrete_iql(
1221412417
)
1221512418
return td
1221612419

12420+
def test_reset_parameters_recursive(self):
12421+
actor = self._create_mock_actor()
12422+
qvalue = self._create_mock_qvalue()
12423+
value = self._create_mock_value()
12424+
loss_fn = DiscreteIQLLoss(
12425+
actor_network=actor,
12426+
qvalue_network=qvalue,
12427+
value_network=value,
12428+
action_space="one-hot",
12429+
)
12430+
self.reset_parameters_recursive_test(loss_fn)
12431+
1221712432
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
1221812433
@pytest.mark.parametrize("device", get_default_devices())
1221912434
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@@ -12842,6 +13057,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
1284213057
)
1284313058
loss = MyLoss(actor_module)
1284413059

13060+
LossModuleTestBase.reset_parameters_recursive_test(loss)
13061+
1284513062
if create_target_params:
1284613063
SoftUpdate(loss, eps=0.5)
1284713064

0 commit comments

Comments
 (0)