39
39
# linearregression,
40
40
exponential ,
41
41
}
42
+ CLF_MODELS = {
43
+ bernoulli ,
44
+ # categorical,
45
+ }
46
+ REG_MODELS = {
47
+ normal ,
48
+ # multivariate_normal,
49
+ # linearregression,
50
+ exponential ,
51
+ poisson ,
52
+ }
42
53
43
54
class _Node :
44
55
def __init__ (self ,
@@ -252,8 +263,7 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node,feature_fix,threshold_fix
252
263
node .h_g = 0
253
264
else :
254
265
node .h_g = self .h_g
255
- # node.sub_model.set_h_params(**self.sub_h_params)
256
- node .sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params )
266
+ node .sub_model .set_h_params (** self .sub_h_params )
257
267
if node .depth == self .c_max_depth or not node .k_candidates or self .rng .random () > self .h_g : # leaf node
258
268
node .sub_model .gen_params ()
259
269
node .leaf = True
@@ -332,8 +342,7 @@ def _gen_params_recursion_feature_and_tree_fix(self,node:_Node,threshold_fix,thr
332
342
node .h_g = 0
333
343
else :
334
344
node .h_g = self .h_g
335
- # node.sub_model.set_h_params(**self.sub_h_params)
336
- node .sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params )
345
+ node .sub_model .set_h_params (** self .sub_h_params )
337
346
if node .leaf : # leaf node
338
347
node .sub_model .gen_params ()
339
348
node .leaf = True
@@ -481,8 +490,7 @@ def _set_h_g_recursion(self,node:_Node):
481
490
self ._set_h_g_recursion (node .children [i ])
482
491
483
492
def _set_sub_h_params_recursion (self ,node :_Node ):
484
- # node.sub_model.set_h_params(**self.sub_h_params)
485
- node .sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params )
493
+ node .sub_model .set_h_params (** self .sub_h_params )
486
494
if not node .leaf :
487
495
for i in range (self .c_num_children_vec [node .k ]):
488
496
self ._set_sub_h_params_recursion (node .children [i ])
@@ -493,8 +501,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
493
501
node .h_g = 0
494
502
else :
495
503
node .h_g = self .h_g
496
- # node.sub_model.set_h_params(**self.sub_h_params)
497
- node .sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params )
504
+ node .sub_model .set_h_params (** self .sub_h_params )
498
505
if not node .leaf :
499
506
for i in range (self .c_num_children_vec [node .k ]):
500
507
self ._set_h_params_recursion (node .children [i ],None )
@@ -576,8 +583,7 @@ def set_h_params(self,
576
583
self ._set_h_g_recursion (h_root )
577
584
578
585
if sub_h_params is not None :
579
- self .SubModel .GenModel (seed = self .rng ,** sub_h_params )
580
- self .sub_h_params = copy .deepcopy (sub_h_params )
586
+ self .sub_h_params = self .SubModel .GenModel (seed = self .rng ,** sub_h_params ).get_h_params ()
581
587
if self .h_metatree_list :
582
588
for h_root in self .h_metatree_list :
583
589
self ._set_sub_h_params_recursion (h_root )
@@ -686,12 +692,12 @@ def gen_params(self,feature_fix=False,threshold_fix=False,tree_fix=False,thresho
686
692
if ``'random'``, self.c_ranges will be recursively divided by at random intervals.
687
693
"""
688
694
if feature_fix :
689
- warnings .warn (
690
- "If feature_fix=True, tree will be generated according to "
691
- + "self.h_g not any element of self.h_metatree_list." ,ResultWarning )
692
695
if tree_fix :
693
696
self ._gen_params_recursion_feature_and_tree_fix (self .root ,threshold_fix ,threshold_type )
694
697
else :
698
+ warnings .warn (
699
+ "If feature_fix=True, tree will be generated according to "
700
+ + "self.h_g not any element of self.h_metatree_list." ,ResultWarning )
695
701
self ._gen_params_recursion (self .root ,None ,True ,threshold_fix ,threshold_type )
696
702
else :
697
703
if threshold_fix or tree_fix :
@@ -1277,8 +1283,7 @@ def _set_h0_g_recursion(self,node:_Node):
1277
1283
self ._set_h0_g_recursion (node .children [i ])
1278
1284
1279
1285
def _set_sub_h0_params_recursion (self ,node :_Node ):
1280
- # node.sub_model.set_h0_params(**self.sub_h0_params)
1281
- node .sub_model = self .SubModel .LearnModel (** self .sub_h0_params )
1286
+ node .sub_model .set_h0_params (** self .sub_h0_params )
1282
1287
if not node .leaf :
1283
1288
for i in range (self .c_num_children_vec [node .k ]):
1284
1289
self ._set_sub_h0_params_recursion (node .children [i ])
@@ -1289,8 +1294,7 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
1289
1294
node .h_g = 0
1290
1295
else :
1291
1296
node .h_g = self .h0_g
1292
- # node.sub_model.set_h0_params(**self.sub_h0_params)
1293
- node .sub_model = self .SubModel .LearnModel (** self .sub_h0_params )
1297
+ node .sub_model .set_h0_params (** self .sub_h0_params )
1294
1298
if not node .leaf :
1295
1299
for i in range (self .c_num_children_vec [node .k ]):
1296
1300
self ._set_h0_params_recursion (node .children [i ],None )
@@ -1339,8 +1343,7 @@ def _set_hn_g_recursion(self,node:_Node):
1339
1343
self ._set_hn_g_recursion (node .children [i ])
1340
1344
1341
1345
def _set_sub_hn_params_recursion (self ,node :_Node ):
1342
- # node.sub_model.set_hn_params(**self.sub_hn_params)
1343
- node .sub_model = self .SubModel .LearnModel (** self .sub_hn_params )
1346
+ node .sub_model .set_hn_params (** self .sub_hn_params )
1344
1347
if not node .leaf :
1345
1348
for i in range (self .c_num_children_vec [node .k ]):
1346
1349
self ._set_sub_hn_params_recursion (node .children [i ])
@@ -1351,8 +1354,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
1351
1354
node .h_g = 0
1352
1355
else :
1353
1356
node .h_g = self .hn_g
1354
- # node.sub_model.set_hn_params(**self.sub_hn_params)
1355
- node .sub_model = self .SubModel .LearnModel (** self .sub_hn_params )
1357
+ node .sub_model .set_hn_params (** self .sub_hn_params )
1356
1358
if not node .leaf :
1357
1359
for i in range (self .c_num_children_vec [node .k ]):
1358
1360
self ._set_hn_params_recursion (node .children [i ],None )
@@ -1382,7 +1384,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
1382
1384
if node .children [i ] is None :
1383
1385
node .children [i ] = _Node (
1384
1386
node .depth + 1 ,
1385
- sub_model = self .SubModel .LearnModel (** self .sub_hn_params ),
1387
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params ),
1386
1388
)
1387
1389
node .children [i ].k_candidates = child_k_candidates
1388
1390
node .children [i ].ranges = np .array (node .ranges )
@@ -1434,8 +1436,7 @@ def set_h0_params(self,
1434
1436
self ._set_h0_g_recursion (h0_root )
1435
1437
1436
1438
if sub_h0_params is not None :
1437
- self .SubModel .LearnModel (** sub_h0_params )
1438
- self .sub_h0_params = copy .deepcopy (sub_h0_params )
1439
+ self .sub_h0_params = self .SubModel .LearnModel (** sub_h0_params ).get_h0_params ()
1439
1440
if self .h0_metatree_list :
1440
1441
for h0_root in self .h0_metatree_list :
1441
1442
self ._set_sub_h0_params_recursion (h0_root )
@@ -1568,8 +1569,7 @@ def set_hn_params(self,
1568
1569
self ._set_hn_g_recursion (hn_root )
1569
1570
1570
1571
if sub_hn_params is not None :
1571
- self .SubModel .LearnModel (** sub_hn_params )
1572
- self .sub_hn_params = copy .deepcopy (sub_hn_params )
1572
+ self .sub_hn_params = self .SubModel .LearnModel (** self .sub_h0_params ).set_hn_params (** sub_hn_params ).get_hn_params ()
1573
1573
if self .hn_metatree_list :
1574
1574
for hn_root in self .hn_metatree_list :
1575
1575
self ._set_sub_hn_params_recursion (hn_root )
@@ -1595,7 +1595,7 @@ def set_hn_params(self,
1595
1595
0 ,
1596
1596
self ._root_k_candidates ,
1597
1597
self .hn_g ,
1598
- sub_model = self .SubModel .LearnModel (** self .sub_hn_params ),
1598
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params ),
1599
1599
ranges = self .c_ranges ,
1600
1600
)
1601
1601
)
@@ -1678,7 +1678,7 @@ def _copy_tree_from_sklearn_tree(self,new_node:_Node, original_tree,node_id):
1678
1678
new_node .depth + 1 ,
1679
1679
child_k_candidates ,
1680
1680
h_g = self .h0_g ,
1681
- sub_model = self .SubModel .LearnModel (** self .sub_h0_params ),
1681
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params ) ,
1682
1682
ranges = np .array (new_node .ranges )
1683
1683
)
1684
1684
if new_node .thresholds is not None :
@@ -1688,7 +1688,7 @@ def _copy_tree_from_sklearn_tree(self,new_node:_Node, original_tree,node_id):
1688
1688
new_node .depth + 1 ,
1689
1689
child_k_candidates ,
1690
1690
h_g = self .h0_g ,
1691
- sub_model = self .SubModel .LearnModel (** self .sub_h0_params ),
1691
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params ) ,
1692
1692
ranges = np .array (new_node .ranges )
1693
1693
)
1694
1694
if new_node .thresholds is not None :
@@ -1784,9 +1784,9 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
1784
1784
"""
1785
1785
if np .any (self .c_num_children_vec != 2 ):
1786
1786
raise (ParameterFormatError ("MTRF is supported only when all the elements of c_num_children_vec is 2." ))
1787
- if self .SubModel in DISCRETE_MODELS :
1787
+ if self .SubModel in CLF_MODELS :
1788
1788
randomforest = RandomForestClassifier (n_estimators = n_estimators ,max_depth = self .c_max_depth ,** kwargs )
1789
- if self .SubModel in CONTINUOUS_MODELS :
1789
+ if self .SubModel in REG_MODELS :
1790
1790
randomforest = RandomForestRegressor (n_estimators = n_estimators ,max_depth = self .c_max_depth ,** kwargs )
1791
1791
1792
1792
x = np .empty ([y .shape [0 ],self .c_dim_features ])
@@ -1800,7 +1800,7 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
1800
1800
0 ,
1801
1801
self ._root_k_candidates ,
1802
1802
self .hn_g ,
1803
- sub_model = self .SubModel .LearnModel (** self .sub_hn_params ),
1803
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params ),
1804
1804
ranges = self .c_ranges ,
1805
1805
)
1806
1806
for i in range (n_estimators )
@@ -1957,7 +1957,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1957
1957
def _map_recursion_add_nodes (self ,node :_Node ):
1958
1958
if node .depth == self .c_max_depth or not node .k_candidates : # leaf node
1959
1959
node .h_g = 0.0
1960
- node .sub_model = self .SubModel .LearnModel (** self .sub_hn_params )
1960
+ node .sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params )
1961
1961
node .leaf = True
1962
1962
node .map_leaf = True
1963
1963
else : # inner node
@@ -2181,7 +2181,7 @@ def _visualize_model_recursion_none(self,tree_graph,depth,k_candidates,ranges,no
2181
2181
child_k_candidates .remove (k )
2182
2182
label_string += f'hn_g={ self .hn_g :.2f} \\ lp_v={ tmp_p_v :.2f} \\ lsub_params={{'
2183
2183
2184
- sub_model = self .SubModel .LearnModel (** self .sub_hn_params )
2184
+ sub_model = self .SubModel .LearnModel (** self .sub_h0_params ). set_hn_params ( ** self . sub_hn_params )
2185
2185
try :
2186
2186
sub_params = sub_model .estimate_params (loss = '0-1' ,dict_out = True )
2187
2187
except :
@@ -2451,3 +2451,34 @@ def pred_and_update(self,x_continuous=None,x_categorical=None,y=None,loss="squar
2451
2451
prediction = self .make_prediction (loss = loss )
2452
2452
self .update_posterior (x_continuous ,x_categorical ,y ,alg_type = 'given_MT' )
2453
2453
return prediction
2454
+
2455
+ def reset_hn_params (self ):
2456
+ """Reset the hyperparameters of the posterior distribution to their initial values.
2457
+
2458
+ They are reset to the output of `self.get_h0_params()`.
2459
+ Note that the parameters of the predictive distribution are also calculated from them.
2460
+ """
2461
+ self .set_hn_params (
2462
+ hn_k_weight_vec = self .h0_k_weight_vec ,
2463
+ hn_g = self .h0_g ,
2464
+ sub_hn_params = self .SubModel .LearnModel (** self .sub_h0_params ).get_hn_params (),
2465
+ hn_metatree_list = self .h0_metatree_list ,
2466
+ hn_metatree_prob_vec = self .h0_metatree_prob_vec ,
2467
+ )
2468
+ return self
2469
+
2470
+ def overwrite_h0_params (self ):
2471
+ """Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
2472
+
2473
+ They are overwitten by the output of `self.get_hn_params()`.
2474
+ Note that the parameters of the predictive distribution are also calculated from them.
2475
+ """
2476
+ tmp = self .SubModel .LearnModel (** self .sub_h0_params ).set_hn_params (** self .sub_hn_params )
2477
+ self .set_h0_params (
2478
+ h0_k_weight_vec = self .hn_k_weight_vec ,
2479
+ h0_g = self .hn_g ,
2480
+ sub_h0_params = tmp .overwrite_h0_params ().get_h0_params (),
2481
+ h0_metatree_list = self .hn_metatree_list ,
2482
+ h0_metatree_prob_vec = self .hn_metatree_prob_vec ,
2483
+ )
2484
+ return self
0 commit comments