Skip to content

Commit b6f2101

Browse files
committed
Remove _GenNode and add _Node
1 parent de4d369 commit b6f2101

File tree

2 files changed

+112
-56
lines changed

2 files changed

+112
-56
lines changed

bayesml/contexttree/_contexttree.py

Lines changed: 94 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
_CMAP = plt.get_cmap("Blues")
1717

18-
class _GenNode:
18+
class _Node:
1919
""" The node class used by generative model and the prior distribution
2020
2121
Parameters
@@ -28,14 +28,14 @@ class _GenNode:
2828
def __init__(self,
2929
depth,
3030
c_k,
31-
h_g,
3231
):
3332
self.depth = depth
3433
self.children = [None for i in range(c_k)] # child nodes
35-
self.h_g = h_g
34+
self.h_g = 0.5
3635
self.h_beta_vec = np.ones(c_k) / 2
3736
self.theta_vec = np.ones(c_k) / c_k
3837
self.leaf = False
38+
self.map_leaf = False
3939

4040
class GenModel(base.Generative):
4141
""" The stochastice data generative model and the prior distribution
@@ -46,11 +46,7 @@ class GenModel(base.Generative):
4646
A positive integer
4747
c_d_max : int, optional
4848
A positive integer, by default 10
49-
theta_vec : numpy.ndarray, optional
50-
A vector of real numbers in :math:`[0, 1]`,
51-
by default [1/c_k, 1/c_k, ... , 1/c_k]
52-
Sum of its elements must be 1.0.
53-
root : contexttree._GenNode, optional
49+
root : contexttree._Node, optional
5450
A root node of a context tree,
5551
by default a tree consists of only one node.
5652
h_g : float, optional
@@ -59,7 +55,7 @@ class GenModel(base.Generative):
5955
A vector of positive real numbers,
6056
by default [1/2, 1/2, ... , 1/2].
6157
If a single real number is input, it will be broadcasted.
62-
h_root : contexttree._GenNode, optional
58+
h_root : contexttree._Node, optional
6359
A root node of a superposed tree for hyperparameters
6460
by default ``None``
6561
seed : {None, int}, optional
@@ -71,7 +67,6 @@ def __init__(
7167
c_k,
7268
c_d_max=10,
7369
*,
74-
theta_vec=None,
7570
root=None,
7671
h_g=0.5,
7772
h_beta_vec=None,
@@ -95,80 +90,132 @@ def __init__(
9590
)
9691

9792
# params
98-
self.theta_vec = np.ones(self.c_k) / self.c_k
99-
self.root = _GenNode(0,self.c_k,self.h_g)
93+
self.root = _Node(0,self.c_k)
94+
self.root.h_g = self.h_g
95+
self.root.h_beta_vec[:] = self.h_beta_vec
10096
self.root.leaf = True
10197

10298
self.set_params(
103-
theta_vec,
10499
root,
105100
)
106101

107-
def _gen_params_recursion(self,node,h_node):
102+
def _gen_params_recursion(self,node:_Node,h_node:_Node):
108103
""" generate parameters recursively"""
109-
if node.depth == self.c_d_max:
110-
node.h_g = 0
111104
if h_node is None:
105+
if node.depth == self.c_d_max:
106+
node.h_g = 0
107+
else:
108+
node.h_g = self.h_g
109+
node.h_beta_vec[:] = self.h_beta_vec
112110
if node.depth == self.c_d_max or self.rng.random() > self.h_g: # 葉ノード
113111
node.theta_vec[:] = self.rng.dirichlet(self.h_beta_vec)
114112
node.leaf = True
115113
else: # 内部ノード
116114
node.leaf = False
117115
for i in range(self.c_k):
118116
if node.children[i] is None:
119-
node.children[i] = _GenNode(node.depth+1,self.c_k,self.h_g)
117+
node.children[i] = _Node(node.depth+1,self.c_k)
120118
self._gen_params_recursion(node.children[i],None)
121119
else:
120+
if node.depth == self.c_d_max:
121+
node.h_g = 0
122+
else:
123+
node.h_g = h_node.h_g
124+
node.h_beta_vec[:] = h_node.h_beta_vec
122125
if node.depth == self.c_d_max or self.rng.random() > h_node.h_g: # 葉ノード
123126
node.theta_vec[:] = self.rng.dirichlet(h_node.h_beta_vec)
124127
node.leaf = True
125128
else: # 内部ノード
126129
node.leaf = False
127130
for i in range(self.c_k):
128131
if node.children[i] is None:
129-
node.children[i] = _GenNode(node.depth+1,self.c_k,self.h_g)
132+
node.children[i] = _Node(node.depth+1,self.c_k)
130133
self._gen_params_recursion(node.children[i],h_node.children[i])
131134

132-
def _gen_params_recursion_tree_fix(self,node,h_node):
135+
def _gen_params_recursion_tree_fix(self,node:_Node,h_node:_Node):
133136
""" generate parameters recursively for fixed tree"""
134137
if h_node is None:
138+
if node.depth == self.c_d_max:
139+
node.h_g = 0
140+
else:
141+
node.h_g = self.h_g
142+
node.h_beta_vec[:] = self.h_beta_vec
135143
if node.leaf: # 葉ノード
136144
node.theta_vec[:] = self.rng.dirichlet(self.h_beta_vec)
137145
else: # 内部ノード
138146
for i in range(self.c_k):
139147
if node.children[i] is not None:
140148
self._gen_params_recursion_tree_fix(node.children[i],None)
141149
else:
150+
if node.depth == self.c_d_max:
151+
node.h_g = 0
152+
else:
153+
node.h_g = h_node.h_g
154+
node.h_beta_vec[:] = h_node.h_beta_vec
142155
if node.leaf: # 葉ノード
143156
node.theta_vec[:] = self.rng.dirichlet(h_node.h_beta_vec)
144157
else: # 内部ノード
145158
for i in range(self.c_k):
146159
if node.children[i] is not None:
147160
self._gen_params_recursion_tree_fix(node.children[i],h_node.children[i])
148161

149-
def _set_recursion(self,node,original_tree_node):
162+
def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
150163
""" copy parameters from a fixed tree
151164
152165
Parameters
153166
----------
154167
node : object
155-
a object form GenNode class
168+
a object from _Node class
156169
original_tree_node : object
157-
a object form GenNode class
170+
a object from _Node class
158171
"""
159-
node.h_g = original_tree_node.h_g
160-
node.h_beta_vec[:] = original_tree_node.h_beta_vec
161172
node.theta_vec[:] = original_tree_node.theta_vec
162173
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
163174
node.leaf = True
164-
if node.depth == self.c_d_max:
165-
node.h_g = 0
166175
else:
167176
node.leaf = False
168177
for i in range(self.c_k):
169-
node.children[i] = _GenNode(node.depth+1,self.c_k,self.h_g)
170-
self._set_recursion(node.children[i],original_tree_node.children[i])
171-
178+
if node.children[i] is None:
179+
node.children[i] = _Node(node.depth+1,self.c_k)
180+
self._set_params_recursion(node.children[i],original_tree_node.children[i])
181+
182+
def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
183+
""" copy parameters from a fixed tree
184+
185+
Parameters
186+
----------
187+
node : object
188+
a object from _Node class
189+
original_tree_node : object
190+
a object from _Node class
191+
"""
192+
if original_tree_node is None:
193+
node.h_g = self.h_g
194+
node.h_beta_vec[:] = self.h_beta_vec
195+
if node.depth == self.c_d_max: # 葉ノード
196+
node.leaf = True
197+
if node.depth == self.c_d_max:
198+
node.h_g = 0
199+
else:
200+
node.leaf = False
201+
for i in range(self.c_k):
202+
if node.children[i] is None:
203+
node.children[i] = _Node(node.depth+1,self.c_k)
204+
self._set_h_params_recursion(node.children[i],None)
205+
else:
206+
node.h_g = original_tree_node.h_g
207+
node.h_beta_vec[:] = original_tree_node.h_beta_vec
208+
if original_tree_node.leaf or node.depth == self.c_d_max: # 葉ノード
209+
node.leaf = True
210+
if node.depth == self.c_d_max:
211+
node.h_g = 0
212+
else:
213+
node.leaf = False
214+
for i in range(self.c_k):
215+
if node.children[i] is None:
216+
node.children[i] = _Node(node.depth+1,self.c_k)
217+
self._set_h_params_recursion(node.children[i],original_tree_node.children[i])
218+
172219
def _gen_sample_recursion(self,node,x):
173220
"""Generate a sample from the stochastic data generative model.
174221
@@ -230,31 +277,35 @@ def set_h_params(self,
230277
h_beta_vec : numpy.ndarray, optional
231278
A vector of positive real numbers,
232279
by default ``None``
233-
h_root : contexttree._GenNode, optional
280+
h_root : contexttree._Node, optional
234281
A root node of a superposed tree for hyperparameters
235282
by default ``None``
236283
"""
237284
if h_g is not None:
238285
self.h_g = _check.float_in_closed01(h_g,'h_g',ParameterFormatError)
286+
if self.h_root is not None:
287+
self._set_h_params_recursion(self.h_root,None)
239288

240289
if h_beta_vec is not None:
241290
_check.pos_floats(h_beta_vec,'h_beta_vec',ParameterFormatError)
242291
self.h_beta_vec[:] = h_beta_vec
292+
if self.h_root is not None:
293+
self._set_h_params_recursion(self.h_root,None)
243294

244295
if h_root is not None:
245-
if type(h_root) is not _GenNode:
296+
if type(h_root) is not _Node:
246297
raise(ParameterFormatError(
247-
"h_root must be an instance of contexttree._GenNode"
298+
"h_root must be an instance of contexttree._Node"
248299
))
249-
self.h_root = _GenNode(0,self.c_k,self.h_g)
250-
self._set_recursion(self.h_root,h_root)
300+
self.h_root = _Node(0,self.c_k)
301+
self._set_h_params_recursion(self.h_root,h_root)
251302

252303
def get_h_params(self):
253304
"""Get the hyperparameters of the prior distribution.
254305
255306
Returns
256307
-------
257-
h_params : dict of {str: float, numpy.ndarray, contexttree._GenNode}
308+
h_params : dict of {str: float, numpy.ndarray, contexttree._Node}
258309
* ``"h_g"`` : the value of ``self.h_g``
259310
* ``"h_beta_vec"`` : the value of ``self.h_beta_vec``
260311
* ``"h_root"`` : the value of ``self.h_root``
@@ -278,43 +329,30 @@ def gen_params(self,tree_fix=False):
278329
else:
279330
self._gen_params_recursion(self.root,self.h_root)
280331

281-
def set_params(self,theta_vec=None,root=None):
332+
def set_params(self,root=None):
282333
"""Set the parameter of the sthocastic data generative model.
283334
284335
Parameters
285336
----------
286-
theta_vec : numpy.ndarray, optional
287-
A vector of real numbers in :math:`[0, 1]`,
288-
by default None.
289-
Sum of its elements must be 1.0.
290-
root : contexttree._GenNode, optional
337+
root : contexttree._Node, optional
291338
A root node of a contexttree, by default None.
292339
"""
293-
if theta_vec is not None:
294-
_check.float_vec_sum_1(theta_vec, "theta_vec", ParameterFormatError)
295-
_check.shape_consistency(
296-
theta_vec.shape[0],"theta_vec.shape[0]",
297-
self.c_k,"self.c_k",
298-
ParameterFormatError
299-
)
300-
self.theta_vec[:] = theta_vec
301340
if root is not None:
302-
if type(root) is not _GenNode:
341+
if type(root) is not _Node:
303342
raise(ParameterFormatError(
304-
"root must be an instance of metatree._GenNode"
343+
"root must be an instance of metatree._Node"
305344
))
306-
self._set_recursion(self.root,root)
345+
self._set_params_recursion(self.root,root)
307346

308347
def get_params(self):
309348
"""Get the parameter of the sthocastic data generative model.
310349
311350
Returns
312351
-------
313352
params : dict of {str:float}
314-
* ``"theta_vec"`` : The value of ``self.theta_vec``.
315353
* ``"root"`` : The value of ``self.root``.
316354
"""
317-
return {"theta_vec":self.theta_vec,"root":self.root}
355+
return {"root":self.root}
318356

319357
def gen_sample(self,sample_length,initial_values=None):
320358
"""Generate a sample from the stochastic data generative model.
@@ -553,7 +591,7 @@ def set_h0_params(self,
553591
h0_beta_vec : numpy.ndarray, optional
554592
A vector of positive real numbers,
555593
by default ``None``
556-
h0_root : contexttree._GenNode, optional
594+
h0_root : contexttree._Node, optional
557595
A root node of a superposed tree for hyperparameters
558596
by default ``None``
559597
"""
@@ -601,7 +639,7 @@ def set_hn_params(self,
601639
hn_beta_vec : numpy.ndarray, optional
602640
A vector of positive real numbers,
603641
by default ``None``
604-
hn_root : contexttree._GenNode, optional
642+
hn_root : contexttree._Node, optional
605643
A root node of a superposed tree for hyperparameters
606644
by default ``None``
607645
"""
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from bayesml import contexttree
2+
import numpy as np
3+
4+
gen_model = contexttree.GenModel(2,c_d_max=3)
5+
gen_model.set_h_params(h_g=0.99,h_beta_vec=np.ones(1))
6+
gen_model.visualize_model(filename='tree1.pdf')
7+
8+
gen_model.gen_params()
9+
gen_model.visualize_model(filename='tree2.pdf')
10+
11+
params = gen_model.get_params()
12+
# gen_model.set_params(root=params['root'])
13+
# gen_model.visualize_model(filename='tree3.pdf')
14+
15+
gen_model2 = contexttree.GenModel(2,c_d_max=2)
16+
gen_model2.set_h_params(h_root=params['root'])
17+
gen_model2.gen_params()
18+
gen_model2.visualize_model(filename='tree3.pdf')

0 commit comments

Comments
 (0)