Skip to content

Commit 159b126

Browse files
committed
Bug fix
1 parent 1249653 commit 159b126

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
411411
node.h_g = self.h_g
412412
# node.sub_model.set_h_params(**self.sub_h_params)
413413
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
414-
for i in range(self.c_k):
414+
for i in range(self.c_num_children):
415415
if node.children[i] is not None:
416416
self._set_h_params_recursion(node.children[i],None)
417417
else:
@@ -428,7 +428,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
428428
node.leaf = True
429429
else:
430430
node.leaf = False
431-
for i in range(self.c_k):
431+
for i in range(self.c_num_children):
432432
if node.children[i] is None:
433433
node.children[i] = _Node(
434434
node.depth+1,
@@ -611,7 +611,7 @@ def get_params(self):
611611
"""
612612
return {"root":self.root}
613613

614-
def gen_sample(self,sample_size,x=None):
614+
def gen_sample(self,sample_size=None,x=None):
615615
"""Generate a sample from the stochastic data generative model.
616616
617617
Parameters
@@ -630,10 +630,20 @@ def gen_sample(self,sample_size,x=None):
630630
y : numpy ndarray
631631
1 dimensional array whose size is ``sammple_size``.
632632
"""
633-
_check.pos_int(sample_size,'sample_size',DataFormatError)
634-
635-
if x is None:
633+
if x is not None:
634+
_check.int_vecs(x,'x',DataFormatError)
635+
_check.shape_consistency(
636+
x.shape[-1],'x.shape[-1]',
637+
self.c_k,'self.c_k',
638+
ParameterFormatError
639+
)
640+
x = x.reshape(-1,self.c_k)
641+
sample_size = x.shape[0]
642+
elif sample_size is not None:
643+
_check.pos_int(sample_size,'sample_size',DataFormatError)
636644
x = self.rng.choice(self.c_num_children,(sample_size,self.c_k))
645+
else:
646+
raise(DataFormatError("Either of the sample_size and the x must be given as a input."))
637647

638648
if self.SubModel in DISCRETE_MODELS:
639649
y = np.empty(sample_size,dtype=int)

bayesml/metatree/metatree_test.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,32 @@
33
import numpy as np
44
import copy
55

6-
gen_model = metatree.GenModel(3,3,h_g=0.7)
6+
gen_model = metatree.GenModel(5,3,h_g=0.75,SubModel=normal)
77
gen_model.gen_params()
88
gen_model.visualize_model(filename='tree.pdf')
9-
params1 = copy.deepcopy(gen_model.get_params())
10-
gen_model.gen_params(feature_fix=True)
11-
gen_model.visualize_model(filename='tree2.pdf')
12-
params2 = copy.deepcopy(gen_model.get_params())
9+
x,y = gen_model.gen_sample(sample_size=100)
10+
print(x)
11+
print(y)
12+
# gen_model.gen_params(feature_fix=True,tree_fix=True)
13+
# gen_model.visualize_model(filename='tree2.pdf')
14+
# gen_model.gen_params(feature_fix=True,tree_fix=True)
15+
# gen_model.visualize_model(filename='tree3.pdf')
1316

14-
# gen_model2 = metatree.GenModel(3,3)
17+
# gen_model2 = metatree.GenModel(3,3,h_g=0.01)
18+
# gen_model2.gen_params()
1519
# gen_model2.visualize_model(filename='tree3.pdf')
16-
# gen_model2.set_h_params(h_metatree_list=[params1['root'],params2['root']])
20+
# gen_model2.set_params(params1['root'])
21+
# gen_model2.visualize_model(filename='tree4.pdf')
22+
# gen_model2.gen_params(feature_fix=True)
23+
# print(gen_model2.get_h_params())
24+
# gen_model2.set_h_params(sub_h_params={'h_beta':100.0},h_g=0.99)
25+
# print(gen_model2.get_h_params())
26+
# gen_model2.gen_params()
27+
# gen_model2.visualize_model(filename='tree3.pdf')
28+
29+
# gen_model2.set_h_params(h_g=0.99)
30+
# gen_model2.gen_params()
31+
# gen_model2.visualize_model(filename='tree4.pdf')
32+
1733
# gen_model2.gen_params(feature_fix=True)
1834
# gen_model2.visualize_model(filename='tree4.pdf')

0 commit comments

Comments
 (0)