Skip to content

Commit de99d95

Browse files
committed
Revise set_h_params
1 parent a9942ec commit de99d95

File tree

2 files changed

+69
-7
lines changed

2 files changed

+69
-7
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,45 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
317317

318318
return node_id
319319

320+
def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
321+
""" copy parameters from a fixed tree
322+
323+
Parameters
324+
----------
325+
node : object
326+
a object from _Node class
327+
original_tree_node : object
328+
a object from _Node class
329+
"""
330+
if original_tree_node is None:
331+
if node.depth == self.c_d_max:
332+
node.h_g = 0
333+
else:
334+
node.h_g = self.h_g
335+
node.sub_model.set_h_params(**self.sub_h_params)
336+
for i in range(self.c_k):
337+
if node.children[i] is not None:
338+
self._set_h_params_recursion(node.children[i],None)
339+
else:
340+
node.h_g = original_tree_node.h_g
341+
try:
342+
sub_h_params = node.sub_model.get_h_params()
343+
except:
344+
sub_h_params = node.sub_model.get_hn_params()
345+
node.sub_model.set_h_params(
346+
*sub_h_params.values()
347+
)
348+
if original_tree_node.leaf or node.depth == self.c_d_max: # leaf node
349+
node.leaf = True
350+
if node.depth == self.c_d_max:
351+
node.h_g = 0
352+
else:
353+
node.leaf = False
354+
for i in range(self.c_k):
355+
if node.children[i] is None:
356+
node.children[i] = _Node(node.depth+1,self.c_k)
357+
self._set_h_params_recursion(node.children[i],original_tree_node.children[i])
358+
320359
def set_h_params(self,
321360
h_k_prob_vec = None,
322361
h_g=None,
@@ -355,12 +394,29 @@ def set_h_params(self,
355394

356395
if h_g is not None:
357396
self.h_g = _check.float_in_closed01(h_g,'h_g',ParameterFormatError)
397+
if self.h_metatree_list:
398+
for h_root in self.h_metatree_list:
399+
self._set_h_params_recursion(h_root,None)
400+
358401

359402
if sub_h_params is not None:
403+
self.SubModel.GenModel(**sub_h_params)
360404
self.sub_h_params = copy.deepcopy(sub_h_params)
361-
self.SubModel.GenModel(**self.sub_h_params)
405+
if self.h_metatree_list:
406+
for h_root in self.h_metatree_list:
407+
self._set_h_params_recursion(h_root,None)
362408

363409
if h_metatree_list is not None:
410+
if not isinstance(h_metatree_list,list):
411+
raise(ParameterFormatError(
412+
"h_metatree_list must be a list"
413+
))
414+
if h_metatree_list:
415+
for h_root in h_metatree_list:
416+
if type(h_root) is not _Node:
417+
raise(ParameterFormatError(
418+
"all elements of h_metatree_list must be instances of metatree._Node or empty"
419+
))
364420
self.h_metatree_list = copy.deepcopy(h_metatree_list)
365421
if h_metatree_prob_vec is not None:
366422
self.h_metatree_prob_vec = np.copy(
@@ -370,9 +426,12 @@ def set_h_params(self,
370426
ParameterFormatError
371427
)
372428
)
373-
elif len(self.h_metatree_list) > 0:
374-
metatree_num = len(self.h_metatree_list)
375-
self.h_metatree_prob_vec = np.ones(metatree_num) / metatree_num
429+
else:
430+
if h_metatree_list:
431+
metatree_num = len(self.h_metatree_list)
432+
self.h_metatree_prob_vec = np.ones(metatree_num) / metatree_num
433+
else:
434+
self.h_metatree_prob_vec = None
376435
elif h_metatree_prob_vec is not None:
377436
self.h_metatree_prob_vec = np.copy(
378437
_check.float_vec_sum_1(
@@ -387,11 +446,15 @@ def set_h_params(self,
387446
raise(ParameterFormatError(
388447
"Length of h_metatree_list and dimension of h_metatree_prob_vec must be the same."
389448
))
390-
else:
449+
elif self.h_metatree_prob_vec is None:
391450
if len(self.h_metatree_list) > 0:
392451
raise(ParameterFormatError(
393452
"Length of h_metatree_list must be zero when self.h_metatree_prob_vec is None."
394453
))
454+
else:
455+
raise(ParameterFormatError(
456+
"self.h_metatree_prob_vec must be None or a numpy.ndarray."
457+
))
395458

396459
def get_h_params(self):
397460
"""Get the hyperparameters of the prior distribution.

bayesml/metatree/metatree_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@
33
import numpy as np
44

55
gen_model = metatree.GenModel(3,3,h_g=0.7)
6-
gen_model.gen_params()
7-
gen_model.visualize_model()
6+
print(gen_model.get_h_params())

0 commit comments

Comments
 (0)