Skip to content

Commit 09ebf4f

Browse files
committed
Revise set_h_params in GenModel
1 parent 41279d7 commit 09ebf4f

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,22 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
385385

386386
return node_id
387387

388+
def _set_h_g_recursion(self,node:_Node):
389+
if node.depth == self.c_d_max:
390+
node.h_g = 0
391+
else:
392+
node.h_g = self.h_g
393+
for i in range(self.c_num_children):
394+
if node.children[i] is not None:
395+
self._set_h_g_recursion(node.children[i])
396+
397+
def _set_sub_h_params_recursion(self,node:_Node):
398+
# node.sub_model.set_h_params(**self.sub_h_params)
399+
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
400+
for i in range(self.c_num_children):
401+
if node.children[i] is not None:
402+
self._set_sub_h_params_recursion(node.children[i])
403+
388404
def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
389405
""" copy parameters from a fixed tree
390406
@@ -418,12 +434,16 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
418434
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
419435
node.leaf = True
420436
else:
437+
node.k = original_tree_node.k
438+
child_k_candidates = copy.copy(node.k_candidates)
439+
child_k_candidates.remove(node.k)
421440
node.leaf = False
422441
for i in range(self.c_num_children):
423442
if node.children[i] is None:
424443
node.children[i] = _Node(
425444
node.depth+1,
426445
self.c_num_children,
446+
child_k_candidates,
427447
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
428448
)
429449
self._set_h_params_recursion(node.children[i],original_tree_node.children[i])
@@ -468,14 +488,14 @@ def set_h_params(self,
468488
self.h_g = _check.float_in_closed01(h_g,'h_g',ParameterFormatError)
469489
if self.h_metatree_list:
470490
for h_root in self.h_metatree_list:
471-
self._set_h_params_recursion(h_root,None)
491+
self._set_h_g_recursion(h_root)
472492

473493
if sub_h_params is not None:
474494
self.SubModel.GenModel(seed=self.rng,**sub_h_params)
475495
self.sub_h_params = copy.deepcopy(sub_h_params)
476496
if self.h_metatree_list:
477497
for h_root in self.h_metatree_list:
478-
self._set_h_params_recursion(h_root,None)
498+
self._set_sub_h_params_recursion(h_root)
479499

480500
if h_metatree_list is not None:
481501
if not isinstance(h_metatree_list,list):
@@ -488,7 +508,22 @@ def set_h_params(self,
488508
raise(ParameterFormatError(
489509
"all elements of h_metatree_list must be instances of metatree._Node or empty"
490510
))
491-
self.h_metatree_list = copy.deepcopy(h_metatree_list)
511+
diff = len(h_metatree_list) - len(self.h_metatree_list)
512+
if diff < 0:
513+
del self.h_metatree_list[diff:]
514+
elif diff > 0:
515+
for i in range(diff):
516+
self.h_metatree_list.append(
517+
_Node(
518+
0,
519+
self.c_num_children,
520+
list(range(self.c_k)),
521+
self.h_g,
522+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
523+
)
524+
)
525+
for i in range(len(self.h_metatree_list)):
526+
self._set_h_params_recursion(self.h_metatree_list[i],h_metatree_list[i])
492527
if h_metatree_prob_vec is not None:
493528
self.h_metatree_prob_vec = np.copy(
494529
_check.float_vec_sum_1(
@@ -962,7 +997,7 @@ def set_h0_params(self,
962997
h0_g : float, optional
963998
A real number in :math:`[0, 1]`, by default None
964999
sub_h0_params : dict, optional
965-
h0_params for self.SubModel, by default None
1000+
h0_params for self.SubModel.LearnModel, by default None
9661001
h0_metatree_list : list of metatree._Node, optional
9671002
Root nodes of meta-trees, by default None
9681003
h0_metatree_prob_vec : numpy.ndarray, optional
@@ -1096,7 +1131,7 @@ def set_hn_params(self,
10961131
hn_g : float, optional
10971132
A real number in :math:`[0, 1]`, by default None
10981133
sub_hn_params : dict, optional
1099-
hn_params for self.SubModel, by default None
1134+
hn_params for self.SubModel.LearnModel, by default None
11001135
hn_metatree_list : list of metatree._Node, optional
11011136
Root nodes of meta-trees, by default None
11021137
hn_metatree_prob_vec : numpy.ndarray, optional

0 commit comments

Comments
 (0)