@@ -385,6 +385,22 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
385
385
386
386
return node_id
387
387
388
+ def _set_h_g_recursion (self ,node :_Node ):
389
+ if node .depth == self .c_d_max :
390
+ node .h_g = 0
391
+ else :
392
+ node .h_g = self .h_g
393
+ for i in range (self .c_num_children ):
394
+ if node .children [i ] is not None :
395
+ self ._set_h_g_recursion (node .children [i ])
396
+
397
+ def _set_sub_h_params_recursion (self ,node :_Node ):
398
+ # node.sub_model.set_h_params(**self.sub_h_params)
399
+ node .sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params )
400
+ for i in range (self .c_num_children ):
401
+ if node .children [i ] is not None :
402
+ self ._set_sub_h_params_recursion (node .children [i ])
403
+
388
404
def _set_h_params_recursion (self ,node :_Node ,original_tree_node :_Node ):
389
405
""" copy parameters from a fixed tree
390
406
@@ -418,12 +434,16 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
418
434
if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
419
435
node .leaf = True
420
436
else :
437
+ node .k = original_tree_node .k
438
+ child_k_candidates = copy .copy (node .k_candidates )
439
+ child_k_candidates .remove (node .k )
421
440
node .leaf = False
422
441
for i in range (self .c_num_children ):
423
442
if node .children [i ] is None :
424
443
node .children [i ] = _Node (
425
444
node .depth + 1 ,
426
445
self .c_num_children ,
446
+ child_k_candidates ,
427
447
sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params ),
428
448
)
429
449
self ._set_h_params_recursion (node .children [i ],original_tree_node .children [i ])
@@ -468,14 +488,14 @@ def set_h_params(self,
468
488
self .h_g = _check .float_in_closed01 (h_g ,'h_g' ,ParameterFormatError )
469
489
if self .h_metatree_list :
470
490
for h_root in self .h_metatree_list :
471
- self ._set_h_params_recursion (h_root , None )
491
+ self ._set_h_g_recursion (h_root )
472
492
473
493
if sub_h_params is not None :
474
494
self .SubModel .GenModel (seed = self .rng ,** sub_h_params )
475
495
self .sub_h_params = copy .deepcopy (sub_h_params )
476
496
if self .h_metatree_list :
477
497
for h_root in self .h_metatree_list :
478
- self ._set_h_params_recursion (h_root , None )
498
+ self ._set_sub_h_params_recursion (h_root )
479
499
480
500
if h_metatree_list is not None :
481
501
if not isinstance (h_metatree_list ,list ):
@@ -488,7 +508,22 @@ def set_h_params(self,
488
508
raise (ParameterFormatError (
489
509
"all elements of h_metatree_list must be instances of metatree._Node or empty"
490
510
))
491
- self .h_metatree_list = copy .deepcopy (h_metatree_list )
511
+ diff = len (h_metatree_list ) - len (self .h_metatree_list )
512
+ if diff < 0 :
513
+ del self .h_metatree_list [diff :]
514
+ elif diff > 0 :
515
+ for i in range (diff ):
516
+ self .h_metatree_list .append (
517
+ _Node (
518
+ 0 ,
519
+ self .c_num_children ,
520
+ list (range (self .c_k )),
521
+ self .h_g ,
522
+ sub_model = self .SubModel .GenModel (seed = self .rng ,** self .sub_h_params ),
523
+ )
524
+ )
525
+ for i in range (len (self .h_metatree_list )):
526
+ self ._set_h_params_recursion (self .h_metatree_list [i ],h_metatree_list [i ])
492
527
if h_metatree_prob_vec is not None :
493
528
self .h_metatree_prob_vec = np .copy (
494
529
_check .float_vec_sum_1 (
@@ -962,7 +997,7 @@ def set_h0_params(self,
962
997
h0_g : float, optional
963
998
A real number in :math:`[0, 1]`, by default None
964
999
sub_h0_params : dict, optional
965
- h0_params for self.SubModel, by default None
1000
+ h0_params for self.SubModel.LearnModel , by default None
966
1001
h0_metatree_list : list of metatree._Node, optional
967
1002
Root nodes of meta-trees, by default None
968
1003
h0_metatree_prob_vec : numpy.ndarray, optional
@@ -1096,7 +1131,7 @@ def set_hn_params(self,
1096
1131
hn_g : float, optional
1097
1132
A real number in :math:`[0, 1]`, by default None
1098
1133
sub_hn_params : dict, optional
1099
- hn_params for self.SubModel, by default None
1134
+ hn_params for self.SubModel.LearnModel , by default None
1100
1135
hn_metatree_list : list of metatree._Node, optional
1101
1136
Root nodes of meta-trees, by default None
1102
1137
hn_metatree_prob_vec : numpy.ndarray, optional
0 commit comments