@@ -317,6 +317,45 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
317
317
318
318
return node_id
319
319
320
+ def _set_h_params_recursion (self ,node :_Node ,original_tree_node :_Node ):
321
+ """ copy parameters from a fixed tree
322
+
323
+ Parameters
324
+ ----------
325
+ node : object
326
+ a object from _Node class
327
+ original_tree_node : object
328
+ a object from _Node class
329
+ """
330
+ if original_tree_node is None :
331
+ if node .depth == self .c_d_max :
332
+ node .h_g = 0
333
+ else :
334
+ node .h_g = self .h_g
335
+ node .sub_model .set_h_params (** self .sub_h_params )
336
+ for i in range (self .c_k ):
337
+ if node .children [i ] is not None :
338
+ self ._set_h_params_recursion (node .children [i ],None )
339
+ else :
340
+ node .h_g = original_tree_node .h_g
341
+ try :
342
+ sub_h_params = node .sub_model .get_h_params ()
343
+ except :
344
+ sub_h_params = node .sub_model .get_hn_params ()
345
+ node .sub_model .set_h_params (
346
+ * sub_h_params .values ()
347
+ )
348
+ if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
349
+ node .leaf = True
350
+ if node .depth == self .c_d_max :
351
+ node .h_g = 0
352
+ else :
353
+ node .leaf = False
354
+ for i in range (self .c_k ):
355
+ if node .children [i ] is None :
356
+ node .children [i ] = _Node (node .depth + 1 ,self .c_k )
357
+ self ._set_h_params_recursion (node .children [i ],original_tree_node .children [i ])
358
+
320
359
def set_h_params (self ,
321
360
h_k_prob_vec = None ,
322
361
h_g = None ,
@@ -355,12 +394,29 @@ def set_h_params(self,
355
394
356
395
if h_g is not None :
357
396
self .h_g = _check .float_in_closed01 (h_g ,'h_g' ,ParameterFormatError )
397
+ if self .h_metatree_list :
398
+ for h_root in self .h_metatree_list :
399
+ self ._set_h_params_recursion (h_root ,None )
400
+
358
401
359
402
if sub_h_params is not None :
403
+ self .SubModel .GenModel (** sub_h_params )
360
404
self .sub_h_params = copy .deepcopy (sub_h_params )
361
- self .SubModel .GenModel (** self .sub_h_params )
405
+ if self .h_metatree_list :
406
+ for h_root in self .h_metatree_list :
407
+ self ._set_h_params_recursion (h_root ,None )
362
408
363
409
if h_metatree_list is not None :
410
+ if not isinstance (h_metatree_list ,list ):
411
+ raise (ParameterFormatError (
412
+ "h_metatree_list must be a list"
413
+ ))
414
+ if h_metatree_list :
415
+ for h_root in h_metatree_list :
416
+ if type (h_root ) is not _Node :
417
+ raise (ParameterFormatError (
418
+ "all elements of h_metatree_list must be instances of metatree._Node or empty"
419
+ ))
364
420
self .h_metatree_list = copy .deepcopy (h_metatree_list )
365
421
if h_metatree_prob_vec is not None :
366
422
self .h_metatree_prob_vec = np .copy (
@@ -370,9 +426,12 @@ def set_h_params(self,
370
426
ParameterFormatError
371
427
)
372
428
)
373
- elif len (self .h_metatree_list ) > 0 :
374
- metatree_num = len (self .h_metatree_list )
375
- self .h_metatree_prob_vec = np .ones (metatree_num ) / metatree_num
429
+ else :
430
+ if h_metatree_list :
431
+ metatree_num = len (self .h_metatree_list )
432
+ self .h_metatree_prob_vec = np .ones (metatree_num ) / metatree_num
433
+ else :
434
+ self .h_metatree_prob_vec = None
376
435
elif h_metatree_prob_vec is not None :
377
436
self .h_metatree_prob_vec = np .copy (
378
437
_check .float_vec_sum_1 (
@@ -387,11 +446,15 @@ def set_h_params(self,
387
446
raise (ParameterFormatError (
388
447
"Length of h_metatree_list and dimension of h_metatree_prob_vec must be the same."
389
448
))
390
- else :
449
+ elif self . h_metatree_prob_vec is None :
391
450
if len (self .h_metatree_list ) > 0 :
392
451
raise (ParameterFormatError (
393
452
"Length of h_metatree_list must be zero when self.h_metatree_prob_vec is None."
394
453
))
454
+ else :
455
+ raise (ParameterFormatError (
456
+ "self.h_metatree_prob_vec must be None or a numpy.ndarray."
457
+ ))
395
458
396
459
def get_h_params (self ):
397
460
"""Get the hyperparameters of the prior distribution.
0 commit comments