@@ -4157,43 +4157,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
4157
4157
td_check = env .reset (td .select ("fen_hash" ))
4158
4158
assert (td_check == td ).all ()
4159
4159
4160
- @pytest .mark .parametrize ("include_fen" , [False , True ])
4161
- @pytest .mark .parametrize ("include_pgn" , [False , True ])
4160
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[False , True ], [True , False ]])
4162
4161
@pytest .mark .parametrize ("stateful" , [False , True ])
4163
- @pytest .mark .parametrize ("mask_actions" , [False , True ])
4164
- def test_all_actions (self , include_fen , include_pgn , stateful , mask_actions ):
4165
- if not stateful and not include_fen and not include_pgn :
4166
- # pytest.skip("fen or pgn must be included if not stateful")
4167
- return
4168
-
4162
+ @pytest .mark .parametrize ("include_hash" , [False , True ])
4163
+ @pytest .mark .parametrize ("include_san" , [False , True ])
4164
+ @pytest .mark .parametrize ("append_transform" , [False , True ])
4165
+ @pytest .mark .parametrize ("mask_actions" , [True ])
4166
+ def test_all_actions (
4167
+ self ,
4168
+ include_fen ,
4169
+ include_pgn ,
4170
+ stateful ,
4171
+ include_hash ,
4172
+ include_san ,
4173
+ append_transform ,
4174
+ mask_actions ,
4175
+ ):
4169
4176
env = ChessEnv (
4170
4177
include_fen = include_fen ,
4171
4178
include_pgn = include_pgn ,
4179
+ include_san = include_san ,
4180
+ include_hash = include_hash ,
4181
+ include_hash_inv = include_hash ,
4172
4182
stateful = stateful ,
4173
4183
mask_actions = mask_actions ,
4174
4184
)
4175
- td = env .reset ()
4176
4185
4177
- if not mask_actions :
4178
- with pytest .raises (RuntimeError , match = "Cannot generate legal actions" ):
4179
- env .all_actions ()
4180
- return
4186
+ def transform_reward (td ):
4187
+ if "reward" not in td :
4188
+ return td
4189
+ reward = td ["reward" ]
4190
+ if reward == 0.5 :
4191
+ td ["reward" ] = 0
4192
+ elif reward == 1 and td ["turn" ]:
4193
+ td ["reward" ] = - td ["reward" ]
4194
+ return td
4195
+
4196
+ if append_transform :
4197
+ env = env .append_transform (transform_reward )
4198
+
4199
+ check_env_specs (env )
4200
+
4201
+ td = env .reset ()
4181
4202
4182
4203
# Choose random actions from the output of `all_actions`
4183
- for _ in range (100 ):
4184
- if stateful :
4185
- all_actions = env .all_actions ()
4186
- else :
4204
+ for step_idx in range (100 ):
4205
+ if step_idx % 5 == 0 :
4187
4206
# Reset to the initial state first, just to make sure
4188
4207
# `all_actions` knows how to get the board state from the input.
4189
4208
env .reset ()
4190
- all_actions = env .all_actions (td .clone ())
4209
+ all_actions = env .all_actions (td .clone ())
4191
4210
4192
4211
# Choose some random actions and make sure they match exactly one of
4193
4212
# the actions from `all_actions`. This part is not tested when
4194
4213
# `mask_actions == False`, because `rand_action` can pick illegal
4195
4214
# actions in that case.
4196
- if mask_actions :
4215
+ if mask_actions and step_idx % 4 == 0 :
4197
4216
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
4198
4217
# it fail to work properly for stateless mode. It doesn't know
4199
4218
# how to correctly reset the board state to what is given in the
@@ -4210,7 +4229,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
4210
4229
4211
4230
action_idx = torch .randint (0 , all_actions .shape [0 ], ()).item ()
4212
4231
chosen_action = all_actions [action_idx ]
4213
- td = env .step (td .update (chosen_action ))["next" ]
4232
+ td_new = env .step (td .update (chosen_action ).clone ())
4233
+ assert (td == td_new .exclude ("next" )).all ()
4234
+ td = td_new ["next" ]
4214
4235
4215
4236
if td ["done" ]:
4216
4237
td = env .reset ()
0 commit comments