Skip to content

Commit 3be85c6

Browse files
committed
[Feature] Add include_hash_inv arg to ChessEnv
ghstack-source-id: f6920d7
Pull Request resolved: #2766
1 parent 32c4623 commit 3be85c6

File tree

2 files changed

+97
-11
lines changed

2 files changed

+97
-11
lines changed

test/test_env.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,6 +3709,74 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
37093709
if include_san:
37103710
assert "san_hash" in env.observation_spec.keys()
37113711

3712+
# Test that `include_hash_inv=True` allows us to specify the board state
3713+
# with just the "fen_hash" or "pgn_hash", not "fen" or "pgn", when taking a
3714+
# step in the env.
3715+
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True]]
)
@pytest.mark.parametrize("stateful", [True, False])
def test_env_hash_inv(self, include_fen, include_pgn, stateful):
    """Step a hash-inverting ``ChessEnv`` from hashes alone and cross-check it.

    The env under test is built with ``include_hash=True`` and
    ``include_hash_inv=True``. Before every step the "fen" and "pgn"
    entries are removed from the input, so only the "fen_hash" /
    "pgn_hash" entries are left to describe the board. The same action is
    then replayed in a reference env created without hashes, which is fed
    the plain "fen" / "pgn" entry instead, and both envs must report the
    same next board state.
    """
    hashed_env = ChessEnv(
        include_fen=include_fen,
        include_pgn=include_pgn,
        include_hash=True,
        include_hash_inv=True,
        stateful=stateful,
    )
    hashed_env.check_env_specs()

    def drop_board_strings(tensordict):
        # Remove both string representations so only hashes remain.
        return tensordict.exclude("fen").exclude("pgn")

    current = hashed_env.reset()

    # One (string key, hash key, hash-free reference env) entry per
    # enabled board representation.
    references = []
    if include_fen:
        references.append(
            ("fen", "fen_hash", ChessEnv(include_fen=True, stateful=stateful))
        )
    if include_pgn:
        references.append(
            ("pgn", "pgn_hash", ChessEnv(include_pgn=True, stateful=stateful))
        )

    for _ in range(8):
        stepped = hashed_env.rand_step(drop_board_strings(current.clone()))

        # The step input must not have carried fen/pgn at the root level.
        assert "fen" not in stepped.keys()
        assert "pgn" not in stepped.keys()

        for text_key, hash_key, reference_env in references:
            assert (stepped[hash_key] == current[hash_key]).all()
            assert text_key in stepped["next"]

            # Replay the same action from the same board in the reference
            # env, this time identifying the board by its string form.
            replay_input = stepped.clone().exclude("next")
            replay_input = replay_input.update({text_key: current[text_key]})
            replayed = reference_env.step(replay_input)["next", text_key]
            assert replayed == stepped["next", text_key]

        current = stepped["next"]
3779+
37123780
@pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
37133781
@pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
37143782
@pytest.mark.parametrize("stateful", [False, True])

torchrl/envs/custom/chess.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,38 @@
2222
class _ChessMeta(_EnvPostInit):
2323
def __call__(cls, *args, **kwargs):
2424
instance = super().__call__(*args, **kwargs)
25-
if kwargs.get("include_hash"):
25+
include_hash = kwargs.get("include_hash")
26+
include_hash_inv = kwargs.get("include_hash_inv")
27+
if include_hash:
2628
from torchrl.envs import Hash
2729

2830
in_keys = []
2931
out_keys = []
30-
if instance.include_san:
31-
in_keys.append("san")
32-
out_keys.append("san_hash")
33-
if instance.include_fen:
34-
in_keys.append("fen")
35-
out_keys.append("fen_hash")
36-
if instance.include_pgn:
37-
in_keys.append("pgn")
38-
out_keys.append("pgn_hash")
39-
instance = instance.append_transform(Hash(in_keys, out_keys))
32+
in_keys_inv = [] if include_hash_inv else None
33+
out_keys_inv = [] if include_hash_inv else None
34+
35+
def maybe_add_keys(condition, in_key, out_key):
36+
if condition:
37+
in_keys.append(in_key)
38+
out_keys.append(out_key)
39+
if include_hash_inv:
40+
in_keys_inv.append(in_key)
41+
out_keys_inv.append(out_key)
42+
43+
maybe_add_keys(instance.include_san, "san", "san_hash")
44+
maybe_add_keys(instance.include_fen, "fen", "fen_hash")
45+
maybe_add_keys(instance.include_pgn, "pgn", "pgn_hash")
46+
47+
instance = instance.append_transform(
48+
Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
49+
)
50+
elif include_hash_inv:
51+
raise ValueError(
52+
(
53+
"'include_hash_inv=True' can only be set if"
54+
f"'include_hash=True', but got 'include_hash={include_hash}'."
55+
)
56+
)
4057
if kwargs.get("mask_actions", True):
4158
from torchrl.envs import ActionMask
4259

@@ -265,6 +282,7 @@ def __init__(
265282
include_pgn: bool = False,
266283
include_legal_moves: bool = False,
267284
include_hash: bool = False,
285+
include_hash_inv: bool = False,
268286
mask_actions: bool = True,
269287
pixels: bool = False,
270288
):

0 commit comments

Comments
 (0)