Skip to content

Commit b07fd35

Browse files
authored
[Feature] SAC V2 (#864)
1 parent f741deb commit b07fd35

File tree

2 files changed: 162 additions (+162) and 66 deletions (-66).

2 files changed

+162
-66
lines changed

test/test_cost.py

Lines changed: 63 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -915,6 +915,7 @@ def test_td3_batcher(
915915
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))
916916

917917

918+
@pytest.mark.parametrize("version", [1, 2])
918919
class TestSAC:
919920
seed = 0
920921

@@ -1027,7 +1028,9 @@ def _create_seq_mock_data_sac(
10271028
@pytest.mark.parametrize("delay_qvalue", (True, False))
10281029
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
10291030
@pytest.mark.parametrize("device", get_available_devices())
1030-
def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1031+
def test_sac(
1032+
self, delay_value, delay_actor, delay_qvalue, num_qvalue, device, version
1033+
):
10311034
if (delay_actor or delay_qvalue) and not delay_value:
10321035
pytest.skip("incompatible config")
10331036

@@ -1036,7 +1039,10 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
10361039

10371040
actor = self._create_mock_actor(device=device)
10381041
qvalue = self._create_mock_qvalue(device=device)
1039-
value = self._create_mock_value(device=device)
1042+
if version == 1:
1043+
value = self._create_mock_value(device=device)
1044+
else:
1045+
value = None
10401046

10411047
kwargs = {}
10421048
if delay_actor:
@@ -1066,12 +1072,13 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
10661072
continue
10671073
loss[k].sum().backward(retain_graph=True)
10681074
if k == "loss_actor":
1069-
assert all(
1070-
(p.grad is None) or (p.grad == 0).all()
1071-
for p in loss_fn.value_network_params.values(
1072-
include_nested=True, leaves_only=True
1075+
if version == 1:
1076+
assert all(
1077+
(p.grad is None) or (p.grad == 0).all()
1078+
for p in loss_fn.value_network_params.values(
1079+
include_nested=True, leaves_only=True
1080+
)
10731081
)
1074-
)
10751082
assert all(
10761083
(p.grad is None) or (p.grad == 0).all()
10771084
for p in loss_fn.qvalue_network_params.values(
@@ -1084,7 +1091,7 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
10841091
include_nested=True, leaves_only=True
10851092
)
10861093
)
1087-
elif k == "loss_value":
1094+
elif k == "loss_value" and version == 1:
10881095
assert all(
10891096
(p.grad is None) or (p.grad == 0).all()
10901097
for p in loss_fn.actor_network_params.values(
@@ -1110,12 +1117,13 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
11101117
include_nested=True, leaves_only=True
11111118
)
11121119
)
1113-
assert all(
1114-
(p.grad is None) or (p.grad == 0).all()
1115-
for p in loss_fn.value_network_params.values(
1116-
include_nested=True, leaves_only=True
1120+
if version == 1:
1121+
assert all(
1122+
(p.grad is None) or (p.grad == 0).all()
1123+
for p in loss_fn.value_network_params.values(
1124+
include_nested=True, leaves_only=True
1125+
)
11171126
)
1118-
)
11191127
assert not any(
11201128
(p.grad is None) or (p.grad == 0).all()
11211129
for p in loss_fn.qvalue_network_params.values(
@@ -1129,12 +1137,13 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
11291137
include_nested=True, leaves_only=True
11301138
)
11311139
)
1132-
assert all(
1133-
(p.grad is None) or (p.grad == 0).all()
1134-
for p in loss_fn.value_network_params.values(
1135-
include_nested=True, leaves_only=True
1140+
if version == 1:
1141+
assert all(
1142+
(p.grad is None) or (p.grad == 0).all()
1143+
for p in loss_fn.value_network_params.values(
1144+
include_nested=True, leaves_only=True
1145+
)
11361146
)
1137-
)
11381147
assert all(
11391148
(p.grad is None) or (p.grad == 0).all()
11401149
for p in loss_fn.qvalue_network_params.values(
@@ -1165,7 +1174,15 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
11651174
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
11661175
@pytest.mark.parametrize("device", get_available_devices())
11671176
def test_sac_batcher(
1168-
self, n, delay_value, delay_actor, delay_qvalue, num_qvalue, device, gamma=0.9
1177+
self,
1178+
n,
1179+
delay_value,
1180+
delay_actor,
1181+
delay_qvalue,
1182+
num_qvalue,
1183+
device,
1184+
version,
1185+
gamma=0.9,
11691186
):
11701187
if (delay_actor or delay_qvalue) and not delay_value:
11711188
pytest.skip("incompatible config")
@@ -1174,7 +1191,10 @@ def test_sac_batcher(
11741191

11751192
actor = self._create_mock_actor(device=device)
11761193
qvalue = self._create_mock_qvalue(device=device)
1177-
value = self._create_mock_value(device=device)
1194+
if version == 1:
1195+
value = self._create_mock_value(device=device)
1196+
else:
1197+
value = None
11781198

11791199
kwargs = {}
11801200
if delay_actor:
@@ -1237,12 +1257,13 @@ def test_sac_batcher(
12371257
include_nested=True, leaves_only=True
12381258
)
12391259
]
1240-
target_value = [
1241-
p.clone()
1242-
for p in loss_fn.target_value_network_params.values(
1243-
include_nested=True, leaves_only=True
1244-
)
1245-
]
1260+
if version == 1:
1261+
target_value = [
1262+
p.clone()
1263+
for p in loss_fn.target_value_network_params.values(
1264+
include_nested=True, leaves_only=True
1265+
)
1266+
]
12461267
for p in loss_fn.parameters():
12471268
p.data += torch.randn_like(p)
12481269
target_actor2 = [
@@ -1257,12 +1278,13 @@ def test_sac_batcher(
12571278
include_nested=True, leaves_only=True
12581279
)
12591280
]
1260-
target_value2 = [
1261-
p.clone()
1262-
for p in loss_fn.target_value_network_params.values(
1263-
include_nested=True, leaves_only=True
1264-
)
1265-
]
1281+
if version == 1:
1282+
target_value2 = [
1283+
p.clone()
1284+
for p in loss_fn.target_value_network_params.values(
1285+
include_nested=True, leaves_only=True
1286+
)
1287+
]
12661288
if loss_fn.delay_actor:
12671289
assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2))
12681290
else:
@@ -1277,12 +1299,15 @@ def test_sac_batcher(
12771299
assert not any(
12781300
(p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
12791301
)
1280-
if loss_fn.delay_value:
1281-
assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2))
1282-
else:
1283-
assert not any(
1284-
(p1 == p2).any() for p1, p2 in zip(target_value, target_value2)
1285-
)
1302+
if version == 1:
1303+
if loss_fn.delay_value:
1304+
assert all(
1305+
(p1 == p2).all() for p1, p2 in zip(target_value, target_value2)
1306+
)
1307+
else:
1308+
assert not any(
1309+
(p1 == p2).any() for p1, p2 in zip(target_value, target_value2)
1310+
)
12861311

12871312
# check that policy is updated after parameter update
12881313
parameters = [p.clone() for p in actor.parameters()]

0 commit comments

Comments (0)