@@ -1804,10 +1804,8 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
1804
1804
randomforest = RandomForestRegressor (n_estimators = n_estimators ,max_depth = self .c_max_depth ,** kwargs )
1805
1805
1806
1806
x = np .empty ([y .shape [0 ],self .c_dim_features ])
1807
- if self .c_dim_continuous > 0 :
1808
- x [:,:self .c_dim_continuous ] = x_continuous
1809
- if self .c_dim_categorical > 0 :
1810
- x [:,- self .c_dim_categorical :] = x_categorical
1807
+ x [:,:self .c_dim_continuous ] = x_continuous
1808
+ x [:,self .c_dim_continuous :] = x_categorical
1811
1809
1812
1810
randomforest .fit (x ,y )
1813
1811
@@ -1828,18 +1826,9 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
1828
1826
tmp_metatree_list ,tmp_metatree_prob_vec = self ._marge_metatrees (tmp_metatree_list ,tmp_metatree_prob_vec )
1829
1827
1830
1828
log_metatree_posteriors = np .log (tmp_metatree_prob_vec )
1831
- if self .c_dim_continuous > 0 and self .c_dim_categorical > 0 :
1832
- for i ,metatree in enumerate (tmp_metatree_list ):
1833
- for j in range (y .shape [0 ]):
1834
- log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,x_continuous [j ],x_categorical [j ],y [j ]))
1835
- elif self .c_dim_continuous > 0 :
1836
- for i ,metatree in enumerate (tmp_metatree_list ):
1837
- for j in range (y .shape [0 ]):
1838
- log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,x_continuous [j ],None ,y [j ]))
1839
- else :
1840
- for i ,metatree in enumerate (tmp_metatree_list ):
1841
- for j in range (y .shape [0 ]):
1842
- log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,None ,x_categorical [j ],y [j ]))
1829
+ for i ,metatree in enumerate (tmp_metatree_list ):
1830
+ for j in range (y .shape [0 ]):
1831
+ log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,x_continuous [j ],x_categorical [j ],y [j ]))
1843
1832
tmp_metatree_prob_vec [:] = np .exp (log_metatree_posteriors - log_metatree_posteriors .max ())
1844
1833
tmp_metatree_prob_vec [:] /= tmp_metatree_prob_vec .sum ()
1845
1834
return tmp_metatree_list ,tmp_metatree_prob_vec
@@ -1868,18 +1857,9 @@ def _given_MT(self,x_continuous,x_categorical,y):
1868
1857
if not self .hn_metatree_list :
1869
1858
raise (ParameterFormatError ("given_MT is supported only when len(self.hn_metatree_list) > 0." ))
1870
1859
log_metatree_posteriors = np .log (self .hn_metatree_prob_vec )
1871
- if self .c_dim_continuous > 0 and self .c_dim_categorical > 0 :
1872
- for i ,metatree in enumerate (self .hn_metatree_list ):
1873
- for j in range (y .shape [0 ]):
1874
- log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,x_continuous [j ],x_categorical [j ],y [j ]))
1875
- elif self .c_dim_continuous > 0 :
1876
- for i ,metatree in enumerate (self .hn_metatree_list ):
1877
- for j in range (y .shape [0 ]):
1878
- log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,x_continuous [j ],None ,y [j ]))
1879
- else :
1880
- for i ,metatree in enumerate (self .hn_metatree_list ):
1881
- for j in range (y .shape [0 ]):
1882
- log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,None ,x_categorical [j ],y [j ]))
1860
+ for i ,metatree in enumerate (self .hn_metatree_list ):
1861
+ for j in range (y .shape [0 ]):
1862
+ log_metatree_posteriors [i ] += np .log (self ._update_posterior_recursion (metatree ,x_continuous [j ],x_categorical [j ],y [j ]))
1883
1863
self .hn_metatree_prob_vec [:] = np .exp (log_metatree_posteriors - log_metatree_posteriors .max ())
1884
1864
self .hn_metatree_prob_vec [:] /= self .hn_metatree_prob_vec .sum ()
1885
1865
return self .hn_metatree_list ,self .hn_metatree_prob_vec
@@ -1958,7 +1938,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1958
1938
ParameterFormatError
1959
1939
)
1960
1940
1961
- x_categorical = None
1941
+ x_categorical = np . empty ([ y . shape [ 0 ], 0 ]) # dummy
1962
1942
1963
1943
elif self .c_dim_categorical > 0 :
1964
1944
_check .nonneg_int_vecs (x_categorical ,'x_categorical' ,DataFormatError )
@@ -1981,7 +1961,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1981
1961
ParameterFormatError
1982
1962
)
1983
1963
1984
- x_continuous = None
1964
+ x_continuous = np . empty ([ y . shape [ 0 ], 0 ]) # dummy
1985
1965
1986
1966
if alg_type == 'MTRF' :
1987
1967
self .hn_metatree_list , self .hn_metatree_prob_vec = self ._MTRF (x_continuous ,x_categorical ,y ,** kwargs )
@@ -1991,37 +1971,57 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1991
1971
def _map_recursion_add_nodes(self, node: _Node):
    """Expand ``node`` in place into a complete subtree.

    A node that has reached ``self.c_max_depth`` or has no remaining
    split candidates is finalized as a leaf: its ``h_g`` is set to 0.0,
    it receives a fresh ``self.SubModel.LearnModel`` built from
    ``self.sub_hn_params``, and both ``leaf`` and ``map_leaf`` are set.

    Any other node becomes an inner node.  Its split feature ``k`` is the
    candidate with the largest entry in ``self.hn_k_weight_vec``; one
    child is created per entry of ``self.c_num_children_vec[k]`` and each
    child is expanded recursively.

    Parameters
    ----------
    node : _Node
        Node to expand.  It is modified in place.
    """
    # Guard clause: nothing left to split on, so close the node as a leaf.
    if node.depth == self.c_max_depth or not node.k_candidates:
        node.h_g = 0.0
        node.sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
        node.leaf = True
        node.map_leaf = True
        return

    # Inner node: split on the candidate feature with the largest weight.
    best_position = self.hn_k_weight_vec[node.k_candidates].argmax()
    node.k = node.k_candidates[best_position]
    split_k = node.k
    num_children = self.c_num_children_vec[split_k]
    node.children = [None for _ in range(num_children)]

    # Feature indices below c_dim_continuous (presumably the continuous
    # features) get equally spaced cut points over the node's current
    # range; all other features carry no thresholds.
    if split_k < self.c_dim_continuous:
        edges = np.linspace(
            node.ranges[split_k, 0],
            node.ranges[split_k, 1],
            num_children + 1,
        )
    else:
        edges = None
    node.thresholds = edges

    # The split feature is not offered to descendants.  Every child is
    # handed the same candidate-list object, as in the original code.
    remaining_candidates = node.k_candidates.copy()
    remaining_candidates.remove(split_k)
    node.leaf = False

    for idx in range(num_children):
        child = _Node(
            node.depth + 1,
            remaining_candidates,
            self.hn_g,
            ranges=np.array(node.ranges),  # private copy of the parent's ranges
        )
        node.children[idx] = child
        if edges is not None:
            # Narrow the child's interval on the split feature.
            child.ranges[split_k, 0] = edges[idx]
            child.ranges[split_k, 1] = edges[idx + 1]
        self._map_recursion_add_nodes(child)
2009
2002
2010
2003
def _map_recursion (self ,node :_Node ):
2011
2004
if node .leaf :
2012
2005
if node .depth == self .c_max_depth or not node .k_candidates :
2013
2006
node .map_leaf = True
2014
2007
return 1.0
2015
- elif 1.0 - self .hn_g > self .hn_g ** ((self .c_num_children ** (self .c_max_depth - node .depth ) - 1 )/ (self .c_num_children - 1 )):
2016
- node .map_leaf = True
2017
- return 1.0 - self .hn_g
2018
2008
else :
2019
- self ._map_recursion_add_nodes (node )
2020
- return self .hn_g ** ((self .c_num_children ** (self .c_max_depth - node .depth ) - 1 )/ (self .c_num_children - 1 ))
2009
+ sum_nodes = 0
2010
+ num_nodes = 1
2011
+ rest_num_children_vec = np .sort (self .c_num_children_vec [node .k_candidates ])
2012
+ for i in range (min (self .c_max_depth - node .depth ,len (node .k_candidates ))):
2013
+ sum_nodes += num_nodes
2014
+ num_nodes *= rest_num_children_vec [i ]
2015
+ if 1.0 - node .h_g > node .h_g * self .hn_g ** (sum_nodes - 1 ):
2016
+ node .map_leaf = True
2017
+ return 1.0 - node .h_g
2018
+ else :
2019
+ self ._map_recursion_add_nodes (node )
2020
+ return node .h_g * self .hn_g ** (sum_nodes - 1 )
2021
2021
else :
2022
2022
tmp1 = 1.0 - node .h_g
2023
- tmp_vec = np .empty (self .c_num_children )
2024
- for i in range (self .c_num_children ):
2023
+ tmp_vec = np .empty (self .c_num_children_vec [ node . k ] )
2024
+ for i in range (self .c_num_children_vec [ node . k ] ):
2025
2025
tmp_vec [i ] = self ._map_recursion (node .children [i ])
2026
2026
if tmp1 > node .h_g * tmp_vec .prod ():
2027
2027
node .map_leaf = True
@@ -2030,22 +2030,31 @@ def _map_recursion(self,node:_Node):
2030
2030
node .map_leaf = False
2031
2031
return node .h_g * tmp_vec .prod ()
2032
2032
2033
def _copy_map_tree_recursion(self, copied_node: _Node, original_node: _Node):
    """Copy the subtree rooted at ``original_node`` into ``copied_node``.

    The copy stops descending wherever ``original_node.map_leaf`` is
    true: such a node becomes a leaf of the copy and receives a deep copy
    of the original's ``sub_model``.  Otherwise the split feature and
    thresholds are copied, fresh children are created, and the copy
    continues recursively child by child.

    Parameters
    ----------
    copied_node : _Node
        Destination node.  It is modified in place.
    original_node : _Node
        Source node.  Only read from.
    """
    copied_node.h_g = original_node.h_g

    # Nodes flagged as map_leaf end the copy: they become leaves.
    if original_node.map_leaf:
        copied_node.sub_model = copy.deepcopy(original_node.sub_model)
        copied_node.leaf = True
        return

    split_k = original_node.k
    copied_node.k = split_k
    num_children = self.c_num_children_vec[split_k]
    copied_node.children = [None for _ in range(num_children)]

    # Thresholds are copied only for feature indices below
    # c_dim_continuous (presumably the continuous features).
    if split_k < self.c_dim_continuous:
        copied_node.thresholds = np.array(original_node.thresholds)
    else:
        copied_node.thresholds = None
    edges = copied_node.thresholds

    # The split feature is not offered to descendants.  Every child is
    # handed the same candidate-list object, as in the original code.
    remaining_candidates = copied_node.k_candidates.copy()
    remaining_candidates.remove(split_k)
    copied_node.leaf = False

    for idx in range(num_children):
        child = _Node(
            copied_node.depth + 1,
            remaining_candidates,
            ranges=np.array(copied_node.ranges),  # private copy of the parent's ranges
        )
        copied_node.children[idx] = child
        if edges is not None:
            # Narrow the child's interval on the split feature.
            child.ranges[split_k, 0] = edges[idx]
            child.ranges[split_k, 1] = edges[idx + 1]
        self._copy_map_tree_recursion(child, original_node.children[idx])
2049
2058
2050
2059
def estimate_params (self ,loss = "0-1" ,visualize = True ,filename = None ,format = None ):
2051
2060
"""Estimate the parameter under the given criterion.
@@ -2088,13 +2097,18 @@ def estimate_params(self,loss="0-1",visualize=True,filename=None,format=None):
2088
2097
if prob > map_prob :
2089
2098
map_index = i
2090
2099
map_prob = prob
2091
- map_root = _Node (0 ,self .c_num_children ,list (range (self .c_k )))
2100
+ map_root = _Node (
2101
+ 0 ,
2102
+ self ._root_k_candidates ,
2103
+ self .hn_g ,
2104
+ ranges = self .c_ranges ,
2105
+ )
2092
2106
self ._copy_map_tree_recursion (map_root ,self .hn_metatree_list [map_index ])
2093
2107
if visualize :
2094
2108
import graphviz
2095
2109
tree_graph = graphviz .Digraph (filename = filename ,format = format )
2096
2110
tree_graph .attr ("node" ,shape = "box" ,fontname = "helvetica" ,style = "rounded,filled" )
2097
- self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , 1.0 )
2111
+ self ._visualize_model_recursion (tree_graph , map_root , 0 , None , None , None , 1.0 )
2098
2112
tree_graph .view ()
2099
2113
return {'root' :map_root }
2100
2114
else :
@@ -2119,17 +2133,20 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,pare
2119
2133
label_string += f',{ node .thresholds [i + 1 ]:.2f} '
2120
2134
label_string += '}\\ l'
2121
2135
label_string += f'hn_g={ node .h_g :.2f} \\ lp_v={ tmp_p_v :.2f} \\ lsub_params={{'
2122
- try :
2123
- sub_params = node .sub_model .estimate_params (loss = '0-1' ,dict_out = True )
2124
- except :
2125
- sub_params = node .sub_model .estimate_params (dict_out = True )
2126
-
2127
- for key ,value in sub_params .items ():
2136
+ if node .sub_model is not None :
2128
2137
try :
2129
- label_string += f' \\ l { key } : { value :.2f } '
2138
+ sub_params = node . sub_model . estimate_params ( loss = '0-1' , dict_out = True )
2130
2139
except :
2131
- label_string += f'\\ l{ key } :{ value } '
2132
- label_string += '}'
2140
+ sub_params = node .sub_model .estimate_params (dict_out = True )
2141
+
2142
+ for key ,value in sub_params .items ():
2143
+ try :
2144
+ label_string += f'\\ l{ key } :{ value :.2f} '
2145
+ except :
2146
+ label_string += f'\\ l{ key } :{ value } '
2147
+ label_string += '}'
2148
+ else :
2149
+ label_string += '\\ lNone}'
2133
2150
2134
2151
tree_graph .node (name = f'{ tmp_id } ' ,label = label_string ,fillcolor = f'{ rgb2hex (_CMAP (tmp_p_v ))} ' )
2135
2152
if tmp_p_v > 0.65 :
0 commit comments