Skip to content

Commit a2879d0

Browse files
committed
[Test] Improve coverage of ChessEnv.all_actions
1 parent 8c9dc05 commit a2879d0

File tree

1 file changed

+41
-20
lines changed

1 file changed

+41
-20
lines changed

test/test_env.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4157,43 +4157,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
41574157
td_check = env.reset(td.select("fen_hash"))
41584158
assert (td_check == td).all()
41594159

4160-
@pytest.mark.parametrize("include_fen", [False, True])
4161-
@pytest.mark.parametrize("include_pgn", [False, True])
4160+
@pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
41624161
@pytest.mark.parametrize("stateful", [False, True])
4163-
@pytest.mark.parametrize("mask_actions", [False, True])
4164-
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
4165-
if not stateful and not include_fen and not include_pgn:
4166-
# pytest.skip("fen or pgn must be included if not stateful")
4167-
return
4168-
4162+
@pytest.mark.parametrize("include_hash", [False, True])
4163+
@pytest.mark.parametrize("include_san", [False, True])
4164+
@pytest.mark.parametrize("append_transform", [False, True])
4165+
@pytest.mark.parametrize("mask_actions", [True])
4166+
def test_all_actions(
4167+
self,
4168+
include_fen,
4169+
include_pgn,
4170+
stateful,
4171+
include_hash,
4172+
include_san,
4173+
append_transform,
4174+
mask_actions,
4175+
):
41694176
env = ChessEnv(
41704177
include_fen=include_fen,
41714178
include_pgn=include_pgn,
4179+
include_san=include_san,
4180+
include_hash=include_hash,
4181+
include_hash_inv=include_hash,
41724182
stateful=stateful,
41734183
mask_actions=mask_actions,
41744184
)
4175-
td = env.reset()
41764185

4177-
if not mask_actions:
4178-
with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
4179-
env.all_actions()
4180-
return
4186+
def transform_reward(td):
4187+
if "reward" not in td:
4188+
return td
4189+
reward = td["reward"]
4190+
if reward == 0.5:
4191+
td["reward"] = 0
4192+
elif reward == 1 and td["turn"]:
4193+
td["reward"] = -td["reward"]
4194+
return td
4195+
4196+
if append_transform:
4197+
env = env.append_transform(transform_reward)
4198+
4199+
check_env_specs(env)
4200+
4201+
td = env.reset()
41814202

41824203
# Choose random actions from the output of `all_actions`
4183-
for _ in range(100):
4184-
if stateful:
4185-
all_actions = env.all_actions()
4186-
else:
4204+
for step_idx in range(100):
4205+
if step_idx % 5 == 0:
41874206
# Reset the the initial state first, just to make sure
41884207
# `all_actions` knows how to get the board state from the input.
41894208
env.reset()
4190-
all_actions = env.all_actions(td.clone())
4209+
all_actions = env.all_actions(td.clone())
41914210

41924211
# Choose some random actions and make sure they match exactly one of
41934212
# the actions from `all_actions`. This part is not tested when
41944213
# `mask_actions == False`, because `rand_action` can pick illegal
41954214
# actions in that case.
4196-
if mask_actions:
4215+
if mask_actions and step_idx % 4 == 0:
41974216
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
41984217
# it fail to work properly for stateless mode. It doesn't know
41994218
# how to correctly reset the board state to what is given in the
@@ -4210,7 +4229,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
42104229

42114230
action_idx = torch.randint(0, all_actions.shape[0], ()).item()
42124231
chosen_action = all_actions[action_idx]
4213-
td = env.step(td.update(chosen_action))["next"]
4232+
td_new = env.step(td.update(chosen_action).clone())
4233+
assert (td == td_new.exclude("next")).all()
4234+
td = td_new["next"]
42144235

42154236
if td["done"]:
42164237
td = env.reset()

0 commit comments

Comments
 (0)