Skip to content

Commit 7f27f22

Browse files
committed
Refine sub_h0_params and sub_hn_params sharing
1 parent 42aa4d6 commit 7f27f22

File tree

2 files changed

+89
-44
lines changed

2 files changed

+89
-44
lines changed

bayesml/metatree/_metatree.py

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@
3939
# linearregression,
4040
exponential,
4141
}
42+
CLF_MODELS = {
43+
bernoulli,
44+
# categorical,
45+
}
46+
REG_MODELS = {
47+
normal,
48+
# multivariate_normal,
49+
# linearregression,
50+
exponential,
51+
poisson,
52+
}
4253

4354
class _Node:
4455
def __init__(self,
@@ -252,8 +263,7 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node,feature_fix,threshold_fix
252263
node.h_g = 0
253264
else:
254265
node.h_g = self.h_g
255-
# node.sub_model.set_h_params(**self.sub_h_params)
256-
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
266+
node.sub_model.set_h_params(**self.sub_h_params)
257267
if node.depth == self.c_max_depth or not node.k_candidates or self.rng.random() > self.h_g: # leaf node
258268
node.sub_model.gen_params()
259269
node.leaf = True
@@ -332,8 +342,7 @@ def _gen_params_recursion_feature_and_tree_fix(self,node:_Node,threshold_fix,thr
332342
node.h_g = 0
333343
else:
334344
node.h_g = self.h_g
335-
# node.sub_model.set_h_params(**self.sub_h_params)
336-
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
345+
node.sub_model.set_h_params(**self.sub_h_params)
337346
if node.leaf: # leaf node
338347
node.sub_model.gen_params()
339348
node.leaf = True
@@ -481,8 +490,7 @@ def _set_h_g_recursion(self,node:_Node):
481490
self._set_h_g_recursion(node.children[i])
482491

483492
def _set_sub_h_params_recursion(self,node:_Node):
484-
# node.sub_model.set_h_params(**self.sub_h_params)
485-
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
493+
node.sub_model.set_h_params(**self.sub_h_params)
486494
if not node.leaf:
487495
for i in range(self.c_num_children_vec[node.k]):
488496
self._set_sub_h_params_recursion(node.children[i])
@@ -493,8 +501,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
493501
node.h_g = 0
494502
else:
495503
node.h_g = self.h_g
496-
# node.sub_model.set_h_params(**self.sub_h_params)
497-
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
504+
node.sub_model.set_h_params(**self.sub_h_params)
498505
if not node.leaf:
499506
for i in range(self.c_num_children_vec[node.k]):
500507
self._set_h_params_recursion(node.children[i],None)
@@ -576,8 +583,7 @@ def set_h_params(self,
576583
self._set_h_g_recursion(h_root)
577584

578585
if sub_h_params is not None:
579-
self.SubModel.GenModel(seed=self.rng,**sub_h_params)
580-
self.sub_h_params = copy.deepcopy(sub_h_params)
586+
self.sub_h_params = self.SubModel.GenModel(seed=self.rng,**sub_h_params).get_h_params()
581587
if self.h_metatree_list:
582588
for h_root in self.h_metatree_list:
583589
self._set_sub_h_params_recursion(h_root)
@@ -686,12 +692,12 @@ def gen_params(self,feature_fix=False,threshold_fix=False,tree_fix=False,thresho
686692
if ``'random'``, self.c_ranges will be recursively divided by at random intervals.
687693
"""
688694
if feature_fix:
689-
warnings.warn(
690-
"If feature_fix=True, tree will be generated according to "
691-
+"self.h_g not any element of self.h_metatree_list.",ResultWarning)
692695
if tree_fix:
693696
self._gen_params_recursion_feature_and_tree_fix(self.root,threshold_fix,threshold_type)
694697
else:
698+
warnings.warn(
699+
"If feature_fix=True, tree will be generated according to "
700+
+"self.h_g not any element of self.h_metatree_list.",ResultWarning)
695701
self._gen_params_recursion(self.root,None,True,threshold_fix,threshold_type)
696702
else:
697703
if threshold_fix or tree_fix:
@@ -1277,8 +1283,7 @@ def _set_h0_g_recursion(self,node:_Node):
12771283
self._set_h0_g_recursion(node.children[i])
12781284

12791285
def _set_sub_h0_params_recursion(self,node:_Node):
1280-
# node.sub_model.set_h0_params(**self.sub_h0_params)
1281-
node.sub_model = self.SubModel.LearnModel(**self.sub_h0_params)
1286+
node.sub_model.set_h0_params(**self.sub_h0_params)
12821287
if not node.leaf:
12831288
for i in range(self.c_num_children_vec[node.k]):
12841289
self._set_sub_h0_params_recursion(node.children[i])
@@ -1289,8 +1294,7 @@ def _set_h0_params_recursion(self,node:_Node,original_tree_node:_Node):
12891294
node.h_g = 0
12901295
else:
12911296
node.h_g = self.h0_g
1292-
# node.sub_model.set_h0_params(**self.sub_h0_params)
1293-
node.sub_model = self.SubModel.LearnModel(**self.sub_h0_params)
1297+
node.sub_model.set_h0_params(**self.sub_h0_params)
12941298
if not node.leaf:
12951299
for i in range(self.c_num_children_vec[node.k]):
12961300
self._set_h0_params_recursion(node.children[i],None)
@@ -1339,8 +1343,7 @@ def _set_hn_g_recursion(self,node:_Node):
13391343
self._set_hn_g_recursion(node.children[i])
13401344

13411345
def _set_sub_hn_params_recursion(self,node:_Node):
1342-
# node.sub_model.set_hn_params(**self.sub_hn_params)
1343-
node.sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
1346+
node.sub_model.set_hn_params(**self.sub_hn_params)
13441347
if not node.leaf:
13451348
for i in range(self.c_num_children_vec[node.k]):
13461349
self._set_sub_hn_params_recursion(node.children[i])
@@ -1351,8 +1354,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
13511354
node.h_g = 0
13521355
else:
13531356
node.h_g = self.hn_g
1354-
# node.sub_model.set_hn_params(**self.sub_hn_params)
1355-
node.sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
1357+
node.sub_model.set_hn_params(**self.sub_hn_params)
13561358
if not node.leaf:
13571359
for i in range(self.c_num_children_vec[node.k]):
13581360
self._set_hn_params_recursion(node.children[i],None)
@@ -1382,7 +1384,7 @@ def _set_hn_params_recursion(self,node:_Node,original_tree_node:_Node):
13821384
if node.children[i] is None:
13831385
node.children[i] = _Node(
13841386
node.depth+1,
1385-
sub_model=self.SubModel.LearnModel(**self.sub_hn_params),
1387+
sub_model=self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params),
13861388
)
13871389
node.children[i].k_candidates = child_k_candidates
13881390
node.children[i].ranges = np.array(node.ranges)
@@ -1434,8 +1436,7 @@ def set_h0_params(self,
14341436
self._set_h0_g_recursion(h0_root)
14351437

14361438
if sub_h0_params is not None:
1437-
self.SubModel.LearnModel(**sub_h0_params)
1438-
self.sub_h0_params = copy.deepcopy(sub_h0_params)
1439+
self.sub_h0_params = self.SubModel.LearnModel(**sub_h0_params).get_h0_params()
14391440
if self.h0_metatree_list:
14401441
for h0_root in self.h0_metatree_list:
14411442
self._set_sub_h0_params_recursion(h0_root)
@@ -1568,8 +1569,7 @@ def set_hn_params(self,
15681569
self._set_hn_g_recursion(hn_root)
15691570

15701571
if sub_hn_params is not None:
1571-
self.SubModel.LearnModel(**sub_hn_params)
1572-
self.sub_hn_params = copy.deepcopy(sub_hn_params)
1572+
self.sub_hn_params = self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**sub_hn_params).get_hn_params()
15731573
if self.hn_metatree_list:
15741574
for hn_root in self.hn_metatree_list:
15751575
self._set_sub_hn_params_recursion(hn_root)
@@ -1595,7 +1595,7 @@ def set_hn_params(self,
15951595
0,
15961596
self._root_k_candidates,
15971597
self.hn_g,
1598-
sub_model=self.SubModel.LearnModel(**self.sub_hn_params),
1598+
sub_model=self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params),
15991599
ranges=self.c_ranges,
16001600
)
16011601
)
@@ -1678,7 +1678,7 @@ def _copy_tree_from_sklearn_tree(self,new_node:_Node, original_tree,node_id):
16781678
new_node.depth+1,
16791679
child_k_candidates,
16801680
h_g=self.h0_g,
1681-
sub_model=self.SubModel.LearnModel(**self.sub_h0_params),
1681+
sub_model=self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params),
16821682
ranges=np.array(new_node.ranges)
16831683
)
16841684
if new_node.thresholds is not None:
@@ -1688,7 +1688,7 @@ def _copy_tree_from_sklearn_tree(self,new_node:_Node, original_tree,node_id):
16881688
new_node.depth+1,
16891689
child_k_candidates,
16901690
h_g=self.h0_g,
1691-
sub_model=self.SubModel.LearnModel(**self.sub_h0_params),
1691+
sub_model=self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params),
16921692
ranges=np.array(new_node.ranges)
16931693
)
16941694
if new_node.thresholds is not None:
@@ -1784,9 +1784,9 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
17841784
"""
17851785
if np.any(self.c_num_children_vec != 2):
17861786
raise(ParameterFormatError("MTRF is supported only when all the elements of c_num_children_vec is 2."))
1787-
if self.SubModel in DISCRETE_MODELS:
1787+
if self.SubModel in CLF_MODELS:
17881788
randomforest = RandomForestClassifier(n_estimators=n_estimators,max_depth=self.c_max_depth,**kwargs)
1789-
if self.SubModel in CONTINUOUS_MODELS:
1789+
if self.SubModel in REG_MODELS:
17901790
randomforest = RandomForestRegressor(n_estimators=n_estimators,max_depth=self.c_max_depth,**kwargs)
17911791

17921792
x = np.empty([y.shape[0],self.c_dim_features])
@@ -1800,7 +1800,7 @@ def _MTRF(self,x_continuous,x_categorical,y,n_estimators=100,**kwargs):
18001800
0,
18011801
self._root_k_candidates,
18021802
self.hn_g,
1803-
sub_model=self.SubModel.LearnModel(**self.sub_hn_params),
1803+
sub_model=self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params),
18041804
ranges=self.c_ranges,
18051805
)
18061806
for i in range(n_estimators)
@@ -1957,7 +1957,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
19571957
def _map_recursion_add_nodes(self,node:_Node):
19581958
if node.depth == self.c_max_depth or not node.k_candidates: # leaf node
19591959
node.h_g = 0.0
1960-
node.sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
1960+
node.sub_model = self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params)
19611961
node.leaf = True
19621962
node.map_leaf = True
19631963
else: # inner node
@@ -2181,7 +2181,7 @@ def _visualize_model_recursion_none(self,tree_graph,depth,k_candidates,ranges,no
21812181
child_k_candidates.remove(k)
21822182
label_string += f'hn_g={self.hn_g:.2f}\\lp_v={tmp_p_v:.2f}\\lsub_params={{'
21832183

2184-
sub_model = self.SubModel.LearnModel(**self.sub_hn_params)
2184+
sub_model = self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params)
21852185
try:
21862186
sub_params = sub_model.estimate_params(loss='0-1',dict_out=True)
21872187
except:
@@ -2451,3 +2451,34 @@ def pred_and_update(self,x_continuous=None,x_categorical=None,y=None,loss="squar
24512451
prediction = self.make_prediction(loss=loss)
24522452
self.update_posterior(x_continuous,x_categorical,y,alg_type='given_MT')
24532453
return prediction
2454+
2455+
def reset_hn_params(self):
2456+
"""Reset the hyperparameters of the posterior distribution to their initial values.
2457+
2458+
They are reset to the output of `self.get_h0_params()`.
2459+
Note that the parameters of the predictive distribution are also calculated from them.
2460+
"""
2461+
self.set_hn_params(
2462+
hn_k_weight_vec=self.h0_k_weight_vec,
2463+
hn_g=self.h0_g,
2464+
sub_hn_params=self.SubModel.LearnModel(**self.sub_h0_params).get_hn_params(),
2465+
hn_metatree_list=self.h0_metatree_list,
2466+
hn_metatree_prob_vec=self.h0_metatree_prob_vec,
2467+
)
2468+
return self
2469+
2470+
def overwrite_h0_params(self):
2471+
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
2472+
2473+
They are overwitten by the output of `self.get_hn_params()`.
2474+
Note that the parameters of the predictive distribution are also calculated from them.
2475+
"""
2476+
tmp = self.SubModel.LearnModel(**self.sub_h0_params).set_hn_params(**self.sub_hn_params)
2477+
self.set_h0_params(
2478+
h0_k_weight_vec=self.hn_k_weight_vec,
2479+
h0_g=self.hn_g,
2480+
sub_h0_params=tmp.overwrite_h0_params().get_h0_params(),
2481+
h0_metatree_list=self.hn_metatree_list,
2482+
h0_metatree_prob_vec=self.hn_metatree_prob_vec,
2483+
)
2484+
return self

bayesml/metatree/metatree_test.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,44 @@
77
import time
88

99
dim_continuous = 2
10-
dim_categorical = 0
10+
dim_categorical = 2
1111

1212
gen_model = metatree.GenModel(
1313
c_dim_continuous=dim_continuous,
1414
c_dim_categorical=dim_categorical,
15-
h_g=1.0,
15+
c_max_depth=3,
16+
h_g=0.75,
1617
SubModel=normal,
1718
sub_h_params={'h_kappa':0.1})
18-
# sub_h_params={'h_alpha':0.1,'h_beta':0.1})
19-
gen_model.gen_params(threshold_type='even')
20-
# gen_model.visualize_model(filename='tree.pdf')
19+
# # sub_h_params={'h_alpha':0.3,'h_beta':0.3})
20+
gen_model.gen_params(threshold_type='random')
21+
gen_model.visualize_model(filename='tree.pdf')
2122

2223
x_continuous,x_categorical,y = gen_model.gen_sample(200)
2324

2425
learn_model = metatree.LearnModel(
2526
c_dim_continuous=dim_continuous,
2627
c_dim_categorical=dim_categorical,
2728
c_num_children_vec=2,
29+
c_max_depth=2,
30+
h0_g=0.75,
2831
SubModel=normal,
2932
sub_h0_params={'h0_kappa':0.1})
30-
# sub_h0_params={'h0_alpha':0.1,'h0_beta':0.1})
33+
# sub_h0_params={'h0_alpha':0.3,'h0_beta':0.3})
3134

32-
start = time.time()
33-
learn_model.update_posterior(x_continuous,x_categorical,y)
34-
end = time.time()
35+
hn_params = learn_model.get_hn_params()
36+
hn_params['sub_hn_params']['hn_kappa'] = 0.2
37+
print(learn_model.get_h0_params())
38+
learn_model.set_hn_params(sub_hn_params=hn_params['sub_hn_params'])
39+
print(learn_model.get_hn_params())
3540

36-
print(end-start)
41+
# # start = time.time()
42+
learn_model.update_posterior(x_continuous,x_categorical,y,n_estimators=5)
43+
# # end = time.time()
44+
45+
learn_model.overwrite_h0_params()
46+
print(learn_model.get_h0_params())
47+
print(learn_model.get_hn_params())
48+
49+
# learn_model.visualize_posterior()
50+
# # print(end-start)

0 commit comments

Comments
 (0)