@@ -4058,6 +4058,35 @@ def test_chess_tokenized(self):
4058
4058
assert "fen" in ftd ["next" ]
4059
4059
env .check_env_specs ()
4060
4060
4061
+ @pytest .mark .parametrize ("stateful" , [False , True ])
4062
+ @pytest .mark .parametrize ("include_san" , [False , True ])
4063
+ def test_env_reset_with_hash (self , stateful , include_san ):
4064
+ env = ChessEnv (
4065
+ include_fen = True ,
4066
+ include_hash = True ,
4067
+ include_hash_inv = True ,
4068
+ stateful = stateful ,
4069
+ include_san = include_san ,
4070
+ )
4071
+ cases = [
4072
+ # (fen, num_legal_moves)
4073
+ ("5R1k/8/8/8/6R1/8/8/5K2 b - - 0 1" , 1 ),
4074
+ ("8/8/2kq4/4K3/1R3Q2/8/8/8 w - - 0 1" , 2 ),
4075
+ ("6R1/8/8/4rq2/3pPk2/5n2/8/2B1R2K b - e3 0 1" , 2 ),
4076
+ ]
4077
+ for fen , num_legal_moves in cases :
4078
+ # Load the state by fen.
4079
+ td = env .reset (TensorDict ({"fen" : fen }))
4080
+ assert td ["fen" ] == fen
4081
+ assert td ["action_mask" ].sum () == num_legal_moves
4082
+ # Reset to initial state just to make sure that the next reset
4083
+ # actually changes the state.
4084
+ assert env .reset ()["action_mask" ].sum () == 20
4085
+ # Load the state by fen hash and make sure it gives the same output
4086
+ # as before.
4087
+ td_check = env .reset (td .select ("fen_hash" ))
4088
+ assert (td_check == td ).all ()
4089
+
4061
4090
4062
4091
class TestCustomEnvs :
4063
4092
def test_tictactoe_env (self ):
0 commit comments