@@ -915,6 +915,7 @@ def test_td3_batcher(
915
915
assert all ((p1 != p2 ).all () for p1 , p2 in zip (parameters , actor .parameters ()))
916
916
917
917
918
+ @pytest .mark .parametrize ("version" , [1 , 2 ])
918
919
class TestSAC :
919
920
seed = 0
920
921
@@ -1027,7 +1028,9 @@ def _create_seq_mock_data_sac(
1027
1028
@pytest .mark .parametrize ("delay_qvalue" , (True , False ))
1028
1029
@pytest .mark .parametrize ("num_qvalue" , [1 , 2 , 4 , 8 ])
1029
1030
@pytest .mark .parametrize ("device" , get_available_devices ())
1030
- def test_sac (self , delay_value , delay_actor , delay_qvalue , num_qvalue , device ):
1031
+ def test_sac (
1032
+ self , delay_value , delay_actor , delay_qvalue , num_qvalue , device , version
1033
+ ):
1031
1034
if (delay_actor or delay_qvalue ) and not delay_value :
1032
1035
pytest .skip ("incompatible config" )
1033
1036
@@ -1036,7 +1039,10 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1036
1039
1037
1040
actor = self ._create_mock_actor (device = device )
1038
1041
qvalue = self ._create_mock_qvalue (device = device )
1039
- value = self ._create_mock_value (device = device )
1042
+ if version == 1 :
1043
+ value = self ._create_mock_value (device = device )
1044
+ else :
1045
+ value = None
1040
1046
1041
1047
kwargs = {}
1042
1048
if delay_actor :
@@ -1066,12 +1072,13 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1066
1072
continue
1067
1073
loss [k ].sum ().backward (retain_graph = True )
1068
1074
if k == "loss_actor" :
1069
- assert all (
1070
- (p .grad is None ) or (p .grad == 0 ).all ()
1071
- for p in loss_fn .value_network_params .values (
1072
- include_nested = True , leaves_only = True
1075
+ if version == 1 :
1076
+ assert all (
1077
+ (p .grad is None ) or (p .grad == 0 ).all ()
1078
+ for p in loss_fn .value_network_params .values (
1079
+ include_nested = True , leaves_only = True
1080
+ )
1073
1081
)
1074
- )
1075
1082
assert all (
1076
1083
(p .grad is None ) or (p .grad == 0 ).all ()
1077
1084
for p in loss_fn .qvalue_network_params .values (
@@ -1084,7 +1091,7 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1084
1091
include_nested = True , leaves_only = True
1085
1092
)
1086
1093
)
1087
- elif k == "loss_value" :
1094
+ elif k == "loss_value" and version == 1 :
1088
1095
assert all (
1089
1096
(p .grad is None ) or (p .grad == 0 ).all ()
1090
1097
for p in loss_fn .actor_network_params .values (
@@ -1110,12 +1117,13 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1110
1117
include_nested = True , leaves_only = True
1111
1118
)
1112
1119
)
1113
- assert all (
1114
- (p .grad is None ) or (p .grad == 0 ).all ()
1115
- for p in loss_fn .value_network_params .values (
1116
- include_nested = True , leaves_only = True
1120
+ if version == 1 :
1121
+ assert all (
1122
+ (p .grad is None ) or (p .grad == 0 ).all ()
1123
+ for p in loss_fn .value_network_params .values (
1124
+ include_nested = True , leaves_only = True
1125
+ )
1117
1126
)
1118
- )
1119
1127
assert not any (
1120
1128
(p .grad is None ) or (p .grad == 0 ).all ()
1121
1129
for p in loss_fn .qvalue_network_params .values (
@@ -1129,12 +1137,13 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1129
1137
include_nested = True , leaves_only = True
1130
1138
)
1131
1139
)
1132
- assert all (
1133
- (p .grad is None ) or (p .grad == 0 ).all ()
1134
- for p in loss_fn .value_network_params .values (
1135
- include_nested = True , leaves_only = True
1140
+ if version == 1 :
1141
+ assert all (
1142
+ (p .grad is None ) or (p .grad == 0 ).all ()
1143
+ for p in loss_fn .value_network_params .values (
1144
+ include_nested = True , leaves_only = True
1145
+ )
1136
1146
)
1137
- )
1138
1147
assert all (
1139
1148
(p .grad is None ) or (p .grad == 0 ).all ()
1140
1149
for p in loss_fn .qvalue_network_params .values (
@@ -1165,7 +1174,15 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
1165
1174
@pytest .mark .parametrize ("num_qvalue" , [1 , 2 , 4 , 8 ])
1166
1175
@pytest .mark .parametrize ("device" , get_available_devices ())
1167
1176
def test_sac_batcher (
1168
- self , n , delay_value , delay_actor , delay_qvalue , num_qvalue , device , gamma = 0.9
1177
+ self ,
1178
+ n ,
1179
+ delay_value ,
1180
+ delay_actor ,
1181
+ delay_qvalue ,
1182
+ num_qvalue ,
1183
+ device ,
1184
+ version ,
1185
+ gamma = 0.9 ,
1169
1186
):
1170
1187
if (delay_actor or delay_qvalue ) and not delay_value :
1171
1188
pytest .skip ("incompatible config" )
@@ -1174,7 +1191,10 @@ def test_sac_batcher(
1174
1191
1175
1192
actor = self ._create_mock_actor (device = device )
1176
1193
qvalue = self ._create_mock_qvalue (device = device )
1177
- value = self ._create_mock_value (device = device )
1194
+ if version == 1 :
1195
+ value = self ._create_mock_value (device = device )
1196
+ else :
1197
+ value = None
1178
1198
1179
1199
kwargs = {}
1180
1200
if delay_actor :
@@ -1237,12 +1257,13 @@ def test_sac_batcher(
1237
1257
include_nested = True , leaves_only = True
1238
1258
)
1239
1259
]
1240
- target_value = [
1241
- p .clone ()
1242
- for p in loss_fn .target_value_network_params .values (
1243
- include_nested = True , leaves_only = True
1244
- )
1245
- ]
1260
+ if version == 1 :
1261
+ target_value = [
1262
+ p .clone ()
1263
+ for p in loss_fn .target_value_network_params .values (
1264
+ include_nested = True , leaves_only = True
1265
+ )
1266
+ ]
1246
1267
for p in loss_fn .parameters ():
1247
1268
p .data += torch .randn_like (p )
1248
1269
target_actor2 = [
@@ -1257,12 +1278,13 @@ def test_sac_batcher(
1257
1278
include_nested = True , leaves_only = True
1258
1279
)
1259
1280
]
1260
- target_value2 = [
1261
- p .clone ()
1262
- for p in loss_fn .target_value_network_params .values (
1263
- include_nested = True , leaves_only = True
1264
- )
1265
- ]
1281
+ if version == 1 :
1282
+ target_value2 = [
1283
+ p .clone ()
1284
+ for p in loss_fn .target_value_network_params .values (
1285
+ include_nested = True , leaves_only = True
1286
+ )
1287
+ ]
1266
1288
if loss_fn .delay_actor :
1267
1289
assert all ((p1 == p2 ).all () for p1 , p2 in zip (target_actor , target_actor2 ))
1268
1290
else :
@@ -1277,12 +1299,15 @@ def test_sac_batcher(
1277
1299
assert not any (
1278
1300
(p1 == p2 ).any () for p1 , p2 in zip (target_qvalue , target_qvalue2 )
1279
1301
)
1280
- if loss_fn .delay_value :
1281
- assert all ((p1 == p2 ).all () for p1 , p2 in zip (target_value , target_value2 ))
1282
- else :
1283
- assert not any (
1284
- (p1 == p2 ).any () for p1 , p2 in zip (target_value , target_value2 )
1285
- )
1302
+ if version == 1 :
1303
+ if loss_fn .delay_value :
1304
+ assert all (
1305
+ (p1 == p2 ).all () for p1 , p2 in zip (target_value , target_value2 )
1306
+ )
1307
+ else :
1308
+ assert not any (
1309
+ (p1 == p2 ).any () for p1 , p2 in zip (target_value , target_value2 )
1310
+ )
1286
1311
1287
1312
# check that policy is updated after parameter update
1288
1313
parameters = [p .clone () for p in actor .parameters ()]
0 commit comments