Skip to content

Commit 1ed5d29

Browse files
committed
[BugFix] Apply inverse transform to input of TransformedEnv._reset
ghstack-source-id: 5f7c1fb Pull Request resolved: #2787
1 parent ab76027 commit 1ed5d29

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

test/test_env.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4058,6 +4058,35 @@ def test_chess_tokenized(self):
40584058
assert "fen" in ftd["next"]
40594059
env.check_env_specs()
40604060

4061+
@pytest.mark.parametrize("stateful", [False, True])
4062+
@pytest.mark.parametrize("include_san", [False, True])
4063+
def test_env_reset_with_hash(self, stateful, include_san):
4064+
env = ChessEnv(
4065+
include_fen=True,
4066+
include_hash=True,
4067+
include_hash_inv=True,
4068+
stateful=stateful,
4069+
include_san=include_san,
4070+
)
4071+
cases = [
4072+
# (fen, num_legal_moves)
4073+
("5R1k/8/8/8/6R1/8/8/5K2 b - - 0 1", 1),
4074+
("8/8/2kq4/4K3/1R3Q2/8/8/8 w - - 0 1", 2),
4075+
("6R1/8/8/4rq2/3pPk2/5n2/8/2B1R2K b - e3 0 1", 2),
4076+
]
4077+
for fen, num_legal_moves in cases:
4078+
# Load the state by fen.
4079+
td = env.reset(TensorDict({"fen": fen}))
4080+
assert td["fen"] == fen
4081+
assert td["action_mask"].sum() == num_legal_moves
4082+
# Reset to initial state just to make sure that the next reset
4083+
# actually changes the state.
4084+
assert env.reset()["action_mask"].sum() == 20
4085+
# Load the state by fen hash and make sure it gives the same output
4086+
# as before.
4087+
td_check = env.reset(td.select("fen_hash"))
4088+
assert (td_check == td).all()
4089+
40614090

40624091
class TestCustomEnvs:
40634092
def test_tictactoe_env(self):

torchrl/envs/transforms/transforms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,10 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
10181018
tensordict = tensordict.select(
10191019
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
10201020
)
1021+
# Inputs might be transformed, so we need to apply the inverse transform
1022+
# before passing to the env reset function.
1023+
with _set_missing_tolerance(self.transform, True):
1024+
tensordict = self.transform.inv(tensordict)
10211025
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
10221026
if tensordict is None:
10231027
# make sure all transforms see a source tensordict

0 commit comments

Comments
 (0)