@@ -3709,6 +3709,74 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
3709
3709
if include_san :
3710
3710
assert "san_hash" in env .observation_spec .keys ()
3711
3711
3712
+ # Test that `include_hash_inv=True` allows us to specify the board state
3713
+ # with just the "fen_hash" or "pgn_hash", not "fen" or "pgn", when taking a
3714
+ # step in the env.
3715
+ @pytest .mark .parametrize (
3716
+ "include_fen,include_pgn" ,
3717
+ [[True , False ], [False , True ]],
3718
+ )
3719
+ @pytest .mark .parametrize ("stateful" , [True , False ])
3720
+ def test_env_hash_inv (self , include_fen , include_pgn , stateful ):
3721
+ env = ChessEnv (
3722
+ include_fen = include_fen ,
3723
+ include_pgn = include_pgn ,
3724
+ include_hash = True ,
3725
+ include_hash_inv = True ,
3726
+ stateful = stateful ,
3727
+ )
3728
+ env .check_env_specs ()
3729
+
3730
+ def exclude_fen_and_pgn (td ):
3731
+ td = td .exclude ("fen" )
3732
+ td = td .exclude ("pgn" )
3733
+ return td
3734
+
3735
+ td0 = env .reset ()
3736
+
3737
+ if include_fen :
3738
+ env_check_fen = ChessEnv (
3739
+ include_fen = True ,
3740
+ stateful = stateful ,
3741
+ )
3742
+
3743
+ if include_pgn :
3744
+ env_check_pgn = ChessEnv (
3745
+ include_pgn = True ,
3746
+ stateful = stateful ,
3747
+ )
3748
+
3749
+ for _ in range (8 ):
3750
+ td1 = env .rand_step (exclude_fen_and_pgn (td0 .clone ()))
3751
+
3752
+ # Confirm that fen/pgn was not used to determine the board state
3753
+ assert "fen" not in td1 .keys ()
3754
+ assert "pgn" not in td1 .keys ()
3755
+
3756
+ if include_fen :
3757
+ assert (td1 ["fen_hash" ] == td0 ["fen_hash" ]).all ()
3758
+ assert "fen" in td1 ["next" ]
3759
+
3760
+ # Check that if we start in the same board state and perform the
3761
+ # same action in an env that does not use hashes, we obtain the
3762
+ # same next board state. This confirms that we really can
3763
+ # successfully specify the board state with a hash.
3764
+ td0_check = td1 .clone ().exclude ("next" ).update ({"fen" : td0 ["fen" ]})
3765
+ assert (
3766
+ env_check_fen .step (td0_check )["next" , "fen" ] == td1 ["next" , "fen" ]
3767
+ )
3768
+
3769
+ if include_pgn :
3770
+ assert (td1 ["pgn_hash" ] == td0 ["pgn_hash" ]).all ()
3771
+ assert "pgn" in td1 ["next" ]
3772
+
3773
+ td0_check = td1 .clone ().exclude ("next" ).update ({"pgn" : td0 ["pgn" ]})
3774
+ assert (
3775
+ env_check_pgn .step (td0_check )["next" , "pgn" ] == td1 ["next" , "pgn" ]
3776
+ )
3777
+
3778
+ td0 = td1 ["next" ]
3779
+
3712
3780
@pytest .mark .skipif (not _has_tv , reason = "torchvision not found." )
3713
3781
@pytest .mark .skipif (not _has_cairosvg , reason = "cairosvg not found." )
3714
3782
@pytest .mark .parametrize ("stateful" , [False , True ])
0 commit comments