diff --git a/test/test_env.py b/test/test_env.py index 83734619ebf..5301e77a0b7 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4234,43 +4234,62 @@ def test_env_reset_with_hash(self, stateful, include_san): td_check = env.reset(td.select("fen_hash")) assert (td_check == td).all() - @pytest.mark.parametrize("include_fen", [False, True]) - @pytest.mark.parametrize("include_pgn", [False, True]) + @pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]]) @pytest.mark.parametrize("stateful", [False, True]) - @pytest.mark.parametrize("mask_actions", [False, True]) - def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions): - if not stateful and not include_fen and not include_pgn: - # pytest.skip("fen or pgn must be included if not stateful") - return - + @pytest.mark.parametrize("include_hash", [False, True]) + @pytest.mark.parametrize("include_san", [False, True]) + @pytest.mark.parametrize("append_transform", [False, True]) + @pytest.mark.parametrize("mask_actions", [True]) + def test_all_actions( + self, + include_fen, + include_pgn, + stateful, + include_hash, + include_san, + append_transform, + mask_actions, + ): env = ChessEnv( include_fen=include_fen, include_pgn=include_pgn, + include_san=include_san, + include_hash=include_hash, + include_hash_inv=include_hash, stateful=stateful, mask_actions=mask_actions, ) - td = env.reset() - if not mask_actions: - with pytest.raises(RuntimeError, match="Cannot generate legal actions"): - env.all_actions() - return + def transform_reward(td): + if "reward" not in td: + return td + reward = td["reward"] + if reward == 0.5: + td["reward"] = 0 + elif reward == 1 and td["turn"]: + td["reward"] = -td["reward"] + return td + + if append_transform: + env = env.append_transform(transform_reward) + + check_env_specs(env) + + td = env.reset() # Choose random actions from the output of `all_actions` - for _ in range(100): - if stateful: - all_actions = env.all_actions() - else: + for step_idx in range(100): + if step_idx % 5 == 0: # Reset theinitial state first, just to make sure # `all_actions` knows how to get the board state from the input. env.reset() - all_actions = env.all_actions(td.clone()) + all_actions = env.all_actions(td.clone()) # Choose some random actions and make sure they match exactly one of # the actions from `all_actions`. This part is not tested when # `mask_actions == False`, because `rand_action` can pick illegal # actions in that case. - if mask_actions: + if mask_actions and step_idx % 4 == 0: # TODO: Something is wrong in `ChessEnv.rand_action` which makes # it fail to work properly for stateless mode. It doesn't know # how to correctly reset the board state to what is given in the @@ -4287,7 +4306,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions): action_idx = torch.randint(0, all_actions.shape[0], ()).item() chosen_action = all_actions[action_idx] - td = env.step(td.update(chosen_action))["next"] + td_new = env.step(td.update(chosen_action).clone()) + assert (td == td_new.exclude("next")).all() + td = td_new["next"] if td["done"]: td = env.reset()