Skip to content

Commit 41279d7

Browse files
committed
Pass self.rng to SubModel.GenModel
1 parent 99cdd2e commit 41279d7

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

bayesml/metatree/_metatree_x_discrete.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(
159159
self.c_num_children,
160160
list(range(self.c_k)),
161161
self.h_g,
162-
sub_model=self.SubModel.GenModel(**self.sub_h_params),
162+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
163163
leaf=True
164164
)
165165

@@ -181,7 +181,7 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node,feature_fix):
181181
else:
182182
node.h_g = self.h_g
183183
# node.sub_model.set_h_params(**self.sub_h_params)
184-
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
184+
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
185185
if node.depth == self.c_d_max or node.depth == self.c_k or self.rng.random() > self.h_g: # leaf node
186186
node.sub_model.gen_params()
187187
node.leaf = True
@@ -198,7 +198,7 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node,feature_fix):
198198
node.depth+1,
199199
self.c_num_children,
200200
h_g=self.h_g,
201-
sub_model=self.SubModel.GenModel(**self.sub_h_params),
201+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
202202
)
203203
node.children[i].k_candidates = child_k_candidates
204204
self._gen_params_recursion(node.children[i],None,feature_fix)
@@ -228,7 +228,7 @@ def _gen_params_recursion(self,node:_Node,h_node:_Node,feature_fix):
228228
node.depth+1,
229229
self.c_num_children,
230230
h_g=self.h_g,
231-
sub_model=self.SubModel.GenModel(**self.sub_h_params),
231+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
232232
)
233233
node.children[i].k_candidates = child_k_candidates
234234
self._gen_params_recursion(node.children[i],h_node.children[i],feature_fix)
@@ -249,7 +249,7 @@ def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node,feature_fix):
249249
else:
250250
node.h_g = self.h_g
251251
# node.sub_model.set_h_params(**self.sub_h_params)
252-
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
252+
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
253253
if node.leaf: # leaf node
254254
node.sub_model.gen_params()
255255
node.leaf = True
@@ -325,7 +325,7 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
325325
self.c_num_children,
326326
child_k_candidates,
327327
self.h_g,
328-
sub_model=self.SubModel.GenModel(**self.sub_h_params)
328+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
329329
)
330330
self._set_params_recursion(node.children[i],original_tree_node.children[i])
331331

@@ -401,7 +401,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
401401
else:
402402
node.h_g = self.h_g
403403
# node.sub_model.set_h_params(**self.sub_h_params)
404-
node.sub_model = self.SubModel.GenModel(**self.sub_h_params)
404+
node.sub_model = self.SubModel.GenModel(seed=self.rng,**self.sub_h_params)
405405
for i in range(self.c_num_children):
406406
if node.children[i] is not None:
407407
self._set_h_params_recursion(node.children[i],None)
@@ -424,7 +424,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
424424
node.children[i] = _Node(
425425
node.depth+1,
426426
self.c_num_children,
427-
sub_model=self.SubModel.GenModel(**self.sub_h_params),
427+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
428428
)
429429
self._set_h_params_recursion(node.children[i],original_tree_node.children[i])
430430

@@ -471,7 +471,7 @@ def set_h_params(self,
471471
self._set_h_params_recursion(h_root,None)
472472

473473
if sub_h_params is not None:
474-
self.SubModel.GenModel(**sub_h_params)
474+
self.SubModel.GenModel(seed=self.rng,**sub_h_params)
475475
self.sub_h_params = copy.deepcopy(sub_h_params)
476476
if self.h_metatree_list:
477477
for h_root in self.h_metatree_list:
@@ -587,7 +587,7 @@ def set_params(self,root=None):
587587
self.c_num_children,
588588
list(range(self.c_k)),
589589
self.h_g,
590-
sub_model=self.SubModel.GenModel(**self.sub_h_params),
590+
sub_model=self.SubModel.GenModel(seed=self.rng,**self.sub_h_params),
591591
leaf=True
592592
)
593593
self._set_params_recursion(self.root,root)

bayesml/metatree/metatree_test.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@
55
import numpy as np
66
import copy
77

8-
gen_model = GenModel(4,3,2,h_g=0.75)
9-
gen_model.gen_params()
10-
gen_model.visualize_model('tree.pdf')
11-
x,y = gen_model.gen_sample(1000)
8+
gen_model1 = GenModel(4,3,2,h_g=0.75,seed=0)
9+
gen_model1.gen_params()
10+
gen_model1.visualize_model('tree.pdf')
11+
x1,y1 = gen_model1.gen_sample(10)
1212

13-
learn_model = LearnModel(4,3,2)
14-
learn_model.update_posterior(x,y,n_estimators=1)
15-
learn_model.visualize_posterior('tree2.pdf')
13+
gen_model2 = GenModel(4,3,2,h_g=0.75,seed=0)
14+
gen_model2.gen_params()
15+
gen_model2.visualize_model('tree2.pdf')
16+
x2,y2 = gen_model2.gen_sample(10)
1617

17-
learn_model.calc_pred_dist(np.zeros(4,dtype=int))
18-
print(learn_model.make_prediction(loss='squared'))
19-
20-
learn_model.calc_pred_dist(np.ones(4,dtype=int))
21-
print(learn_model.make_prediction(loss='squared'))
18+
print(x1-x2)
19+
print(y1-y2)

0 commit comments

Comments
 (0)