@@ -423,7 +423,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
423
423
if node .children [i ] is None :
424
424
node .children [i ] = _Node (
425
425
node .depth + 1 ,
426
- self .c_k ,
426
+ self .c_num_children ,
427
427
sub_model = self .SubModel .GenModel (** self .sub_h_params ),
428
428
)
429
429
self ._set_h_params_recursion (node .children [i ],original_tree_node .children [i ])
@@ -818,6 +818,22 @@ def __init__(
818
818
h0_metatree_prob_vec ,
819
819
)
820
820
821
+ def _set_h0_g_recursion (self ,node :_Node ):
822
+ if node .depth == self .c_d_max :
823
+ node .h_g = 0
824
+ else :
825
+ node .h_g = self .h0_g
826
+ for i in range (self .c_num_children ):
827
+ if node .children [i ] is not None :
828
+ self ._set_h0_g_recursion (node .children [i ])
829
+
830
+ def _set_sub_h0_params_recursion (self ,node :_Node ):
831
+ # node.sub_model.set_h0_params(**self.sub_h0_params)
832
+ node .sub_model = self .SubModel .LearnModel (** self .sub_h0_params )
833
+ for i in range (self .c_num_children ):
834
+ if node .children [i ] is not None :
835
+ self ._set_sub_h0_params_recursion (node .children [i ])
836
+
821
837
def _set_h0_params_recursion (self ,node :_Node ,original_tree_node :_Node ):
822
838
""" copy parameters from a fixed tree
823
839
@@ -851,16 +867,36 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
851
867
if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
852
868
node .leaf = True
853
869
else :
870
+ node .k = original_tree_node .k
871
+ child_k_candidates = copy .copy (node .k_candidates )
872
+ child_k_candidates .remove (node .k )
854
873
node .leaf = False
855
874
for i in range (self .c_num_children ):
856
875
if node .children [i ] is None :
857
876
node .children [i ] = _Node (
858
877
node .depth + 1 ,
859
- self .c_k ,
878
+ self .c_num_children ,
879
+ child_k_candidates ,
860
880
sub_model = self .SubModel .LearnModel (** self .sub_h0_params ),
861
881
)
862
882
self ._set_h0_params_recursion (node .children [i ],original_tree_node .children [i ])
863
883
884
+ def _set_hn_g_recursion (self ,node :_Node ):
885
+ if node .depth == self .c_d_max :
886
+ node .h_g = 0
887
+ else :
888
+ node .h_g = self .hn_g
889
+ for i in range (self .c_num_children ):
890
+ if node .children [i ] is not None :
891
+ self ._set_hn_g_recursion (node .children [i ])
892
+
893
+ def _set_sub_hn_params_recursion (self ,node :_Node ):
894
+ # node.sub_model.set_hn_params(**self.sub_hn_params)
895
+ node .sub_model = self .SubModel .LearnModel (** self .sub_hn_params )
896
+ for i in range (self .c_num_children ):
897
+ if node .children [i ] is not None :
898
+ self ._set_sub_hn_params_recursion (node .children [i ])
899
+
864
900
def _set_hn_params_recursion (self ,node :_Node ,original_tree_node :_Node ):
865
901
""" copy parameters from a fixed tree
866
902
@@ -894,12 +930,16 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
894
930
if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
895
931
node .leaf = True
896
932
else :
933
+ node .k = original_tree_node .k
934
+ child_k_candidates = copy .copy (node .k_candidates )
935
+ child_k_candidates .remove (node .k )
897
936
node .leaf = False
898
937
for i in range (self .c_num_children ):
899
938
if node .children [i ] is None :
900
939
node .children [i ] = _Node (
901
940
node .depth + 1 ,
902
- self .c_k ,
941
+ self .c_num_children ,
942
+ child_k_candidates ,
903
943
sub_model = self .SubModel .LearnModel (** self .sub_hn_params ),
904
944
)
905
945
self ._set_hn_params_recursion (node .children [i ],original_tree_node .children [i ])
@@ -944,14 +984,14 @@ def set_h0_params(self,
944
984
self .h0_g = _check .float_in_closed01 (h0_g ,'h0_g' ,ParameterFormatError )
945
985
if self .h0_metatree_list :
946
986
for h0_root in self .h0_metatree_list :
947
- self ._set_h0_params_recursion (h0_root , None )
987
+ self ._set_h0_g_recursion (h0_root )
948
988
949
989
if sub_h0_params is not None :
950
990
self .SubModel .LearnModel (** sub_h0_params )
951
991
self .sub_h0_params = copy .deepcopy (sub_h0_params )
952
992
if self .h0_metatree_list :
953
993
for h0_root in self .h0_metatree_list :
954
- self ._set_h0_params_recursion (h0_root , None )
994
+ self ._set_sub_h0_params_recursion (h0_root )
955
995
956
996
if h0_metatree_list is not None :
957
997
if not isinstance (h0_metatree_list ,list ):
@@ -964,7 +1004,22 @@ def set_h0_params(self,
964
1004
raise (ParameterFormatError (
965
1005
"all elements of h0_metatree_list must be instances of metatree._Node or empty"
966
1006
))
967
- self .h0_metatree_list = copy .deepcopy (h0_metatree_list )
1007
+ diff = len (h0_metatree_list ) - len (self .h0_metatree_list )
1008
+ if diff < 0 :
1009
+ del self .h0_metatree_list [diff :]
1010
+ elif diff > 0 :
1011
+ for i in range (diff ):
1012
+ self .h0_metatree_list .append (
1013
+ _Node (
1014
+ 0 ,
1015
+ self .c_num_children ,
1016
+ list (range (self .c_k )),
1017
+ self .h0_g ,
1018
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ),
1019
+ )
1020
+ )
1021
+ for i in range (len (self .h0_metatree_list )):
1022
+ self ._set_h0_params_recursion (self .h0_metatree_list [i ],h0_metatree_list [i ])
968
1023
if h0_metatree_prob_vec is not None :
969
1024
self .h0_metatree_prob_vec = np .copy (
970
1025
_check .float_vec_sum_1 (
@@ -1063,14 +1118,14 @@ def set_hn_params(self,
1063
1118
self .hn_g = _check .float_in_closed01 (hn_g ,'hn_g' ,ParameterFormatError )
1064
1119
if self .hn_metatree_list :
1065
1120
for hn_root in self .hn_metatree_list :
1066
- self ._set_hn_params_recursion (hn_root , None )
1121
+ self ._set_hn_g_recursion (hn_root )
1067
1122
1068
1123
if sub_hn_params is not None :
1069
1124
self .SubModel .LearnModel (** sub_hn_params )
1070
1125
self .sub_hn_params = copy .deepcopy (sub_hn_params )
1071
1126
if self .hn_metatree_list :
1072
1127
for hn_root in self .hn_metatree_list :
1073
- self ._set_hn_params_recursion (hn_root , None )
1128
+ self ._set_sub_hn_params_recursion (hn_root )
1074
1129
1075
1130
if hn_metatree_list is not None :
1076
1131
if not isinstance (hn_metatree_list ,list ):
@@ -1083,7 +1138,22 @@ def set_hn_params(self,
1083
1138
raise (ParameterFormatError (
1084
1139
"all elements of hn_metatree_list must be instances of metatree._Node or empty"
1085
1140
))
1086
- self .hn_metatree_list = copy .deepcopy (hn_metatree_list )
1141
+ diff = len (hn_metatree_list ) - len (self .hn_metatree_list )
1142
+ if diff < 0 :
1143
+ del self .hn_metatree_list [diff :]
1144
+ elif diff > 0 :
1145
+ for i in range (diff ):
1146
+ self .hn_metatree_list .append (
1147
+ _Node (
1148
+ 0 ,
1149
+ self .c_num_children ,
1150
+ list (range (self .c_k )),
1151
+ self .hn_g ,
1152
+ sub_model = self .SubModel .LearnModel (** self .sub_hn_params ),
1153
+ )
1154
+ )
1155
+ for i in range (len (self .hn_metatree_list )):
1156
+ self ._set_hn_params_recursion (self .hn_metatree_list [i ],hn_metatree_list [i ])
1087
1157
if hn_metatree_prob_vec is not None :
1088
1158
self .hn_metatree_prob_vec = np .copy (
1089
1159
_check .float_vec_sum_1 (
@@ -1610,13 +1680,14 @@ def _make_prediction_leaf_01(self,node:_Node):
1610
1680
pred_dist = node .sub_model .make_prediction (loss = 'KL' )
1611
1681
if type (pred_dist ) is np .ndarray :
1612
1682
mode_prob = pred_dist [mode ]
1613
- try :
1614
- mode_prob = pred_dist .pdf (mode )
1615
- except :
1683
+ else :
1616
1684
try :
1617
- mode_prob = pred_dist .pmf (mode )
1685
+ mode_prob = pred_dist .pdf (mode )
1618
1686
except :
1619
- mode_prob = None
1687
+ try :
1688
+ mode_prob = pred_dist .pmf (mode )
1689
+ except :
1690
+ mode_prob = None
1620
1691
# elif hasattr(pred_dist,'pdf'):
1621
1692
# mode_prob = pred_dist.pdf(mode)
1622
1693
# elif hasattr(pred_dist,'pmf'):
0 commit comments