@@ -673,14 +673,12 @@ def status(self) -> chess.Status:
673
673
674
674
ThreeCheckBoardT = TypeVar ("ThreeCheckBoardT" , bound = "ThreeCheckBoard" )
675
675
676
- class _ThreeCheckBoardState (Generic [ThreeCheckBoardT ], chess ._BoardState [ThreeCheckBoardT ]):
677
- def __init__ (self , board : ThreeCheckBoardT ) -> None :
678
- super ().__init__ (board )
676
+ class _ThreeCheckBoardState :
677
+ def __init__ (self , board : ThreeCheckBoard ) -> None :
679
678
self .remaining_checks_w = board .remaining_checks [chess .WHITE ]
680
679
self .remaining_checks_b = board .remaining_checks [chess .BLACK ]
681
680
682
- def restore (self , board : ThreeCheckBoardT ) -> None :
683
- super ().restore (board )
681
+ def restore (self , board : ThreeCheckBoard ) -> None :
684
682
board .remaining_checks [chess .WHITE ] = self .remaining_checks_w
685
683
board .remaining_checks [chess .BLACK ] = self .remaining_checks_b
686
684
@@ -698,8 +696,13 @@ class ThreeCheckBoard(chess.Board):
698
696
699
697
def __init__ (self , fen : Optional [str ] = starting_fen , chess960 : bool = False ) -> None :
700
698
self .remaining_checks = [3 , 3 ]
699
+ self ._three_check_stack : List [_ThreeCheckBoardState ] = []
701
700
super ().__init__ (fen , chess960 = chess960 )
702
701
702
+ def clear_stack (self ) -> None :
703
+ super ().clear_stack ()
704
+ self ._three_check_stack .clear ()
705
+
703
706
def reset_board (self ) -> None :
704
707
super ().reset_board ()
705
708
self .remaining_checks [chess .WHITE ] = 3
@@ -710,14 +713,17 @@ def clear_board(self) -> None:
710
713
self .remaining_checks [chess .WHITE ] = 3
711
714
self .remaining_checks [chess .BLACK ] = 3
712
715
713
- def _board_state (self : ThreeCheckBoardT ) -> _ThreeCheckBoardState [ThreeCheckBoardT ]:
714
- return _ThreeCheckBoardState (self )
715
-
716
716
def push (self , move : chess .Move ) -> None :
717
+ self ._three_check_stack .append (_ThreeCheckBoardState (self ))
717
718
super ().push (move )
718
719
if self .is_check ():
719
720
self .remaining_checks [not self .turn ] -= 1
720
721
722
+ def pop (self ) -> chess .Move :
723
+ move = super ().pop ()
724
+ self ._three_check_stack .pop ().restore (self )
725
+ return move
726
+
721
727
def has_insufficient_material (self , color : chess .Color ) -> bool :
722
728
# Any remaining piece can give check.
723
729
return not (self .occupied_co [color ] & ~ self .kings )
@@ -792,8 +798,19 @@ def _transposition_key(self) -> Hashable:
792
798
def copy (self : ThreeCheckBoardT , stack : Union [bool , int ] = True ) -> ThreeCheckBoardT :
793
799
board = super ().copy (stack = stack )
794
800
board .remaining_checks = self .remaining_checks .copy ()
801
+ if stack :
802
+ stack = len (self .move_stack ) if stack is True else stack
803
+ board ._three_check_stack = self ._three_check_stack [- stack :]
795
804
return board
796
805
806
+ def root (self : ThreeCheckBoardT ) -> ThreeCheckBoardT :
807
+ if self ._three_check_stack :
808
+ board = super ().root ()
809
+ self ._three_check_stack [0 ].restore (board )
810
+ return board
811
+ else :
812
+ return self .copy (stack = False )
813
+
797
814
def mirror (self : ThreeCheckBoardT ) -> ThreeCheckBoardT :
798
815
board = super ().mirror ()
799
816
board .remaining_checks [chess .WHITE ] = self .remaining_checks [chess .BLACK ]
@@ -803,14 +820,12 @@ def mirror(self: ThreeCheckBoardT) -> ThreeCheckBoardT:
803
820
804
821
CrazyhouseBoardT = TypeVar ("CrazyhouseBoardT" , bound = "CrazyhouseBoard" )
805
822
806
- class _CrazyhouseBoardState (Generic [CrazyhouseBoardT ], chess ._BoardState [CrazyhouseBoardT ]):
807
- def __init__ (self , board : CrazyhouseBoardT ) -> None :
808
- super ().__init__ (board )
823
+ class _CrazyhouseBoardState :
824
+ def __init__ (self , board : CrazyhouseBoard ) -> None :
809
825
self .pockets_w = board .pockets [chess .WHITE ].copy ()
810
826
self .pockets_b = board .pockets [chess .BLACK ].copy ()
811
827
812
- def restore (self , board : CrazyhouseBoardT ) -> None :
813
- super ().restore (board )
828
+ def restore (self , board : CrazyhouseBoard ) -> None :
814
829
board .pockets [chess .WHITE ] = self .pockets_w
815
830
board .pockets [chess .BLACK ] = self .pockets_b
816
831
@@ -870,8 +885,13 @@ class CrazyhouseBoard(chess.Board):
870
885
871
886
def __init__ (self , fen : Optional [str ] = starting_fen , chess960 : bool = False ) -> None :
872
887
self .pockets = [CrazyhousePocket (), CrazyhousePocket ()]
888
+ self ._crazyhouse_stack : List [_CrazyhouseBoardState ] = []
873
889
super ().__init__ (fen , chess960 = chess960 )
874
890
891
+ def clear_stack (self ) -> None :
892
+ super ().clear_stack ()
893
+ self ._crazyhouse_stack .clear ()
894
+
875
895
def reset_board (self ) -> None :
876
896
super ().reset_board ()
877
897
self .pockets [chess .WHITE ].reset ()
@@ -882,10 +902,8 @@ def clear_board(self) -> None:
882
902
self .pockets [chess .WHITE ].reset ()
883
903
self .pockets [chess .BLACK ].reset ()
884
904
885
- def _board_state (self : CrazyhouseBoardT ) -> _CrazyhouseBoardState [CrazyhouseBoardT ]:
886
- return _CrazyhouseBoardState (self )
887
-
888
905
def push (self , move : chess .Move ) -> None :
906
+ self ._crazyhouse_stack .append (_CrazyhouseBoardState (self ))
889
907
super ().push (move )
890
908
if move .drop :
891
909
self .pockets [not self .turn ].remove (move .drop )
@@ -896,6 +914,11 @@ def _push_capture(self, move: chess.Move, capture_square: chess.Square, piece_ty
896
914
else :
897
915
self .pockets [self .turn ].add (piece_type )
898
916
917
+ def pop (self ) -> chess .Move :
918
+ move = super ().pop ()
919
+ self ._crazyhouse_stack .pop ().restore (self )
920
+ return move
921
+
899
922
def _is_halfmoves (self , n : int ) -> bool :
900
923
# No draw by 50-move rule or 75-move rule.
901
924
return False
@@ -1028,8 +1051,19 @@ def copy(self: CrazyhouseBoardT, stack: Union[bool, int] = True) -> CrazyhouseBo
1028
1051
board = super ().copy (stack = stack )
1029
1052
board .pockets [chess .WHITE ] = self .pockets [chess .WHITE ].copy ()
1030
1053
board .pockets [chess .BLACK ] = self .pockets [chess .BLACK ].copy ()
1054
+ if stack :
1055
+ stack = len (self .move_stack ) if stack is True else stack
1056
+ board ._crazyhouse_stack = self ._crazyhouse_stack [- stack :]
1031
1057
return board
1032
1058
1059
+ def root (self : CrazyhouseBoardT ) -> CrazyhouseBoardT :
1060
+ if self ._crazyhouse_stack :
1061
+ board = super ().root ()
1062
+ self ._crazyhouse_stack [0 ].restore (board )
1063
+ return board
1064
+ else :
1065
+ return self .copy (stack = False )
1066
+
1033
1067
def mirror (self : CrazyhouseBoardT ) -> CrazyhouseBoardT :
1034
1068
board = super ().mirror ()
1035
1069
board .pockets [chess .WHITE ] = self .pockets [chess .BLACK ].copy ()
0 commit comments