15
15
16
16
_CMAP = plt .get_cmap ("Blues" )
17
17
18
- class _GenNode :
18
+ class _Node :
19
19
""" The node class used by generative model and the prior distribution
20
20
21
21
Parameters
@@ -28,14 +28,14 @@ class _GenNode:
28
28
def __init__ (self ,
29
29
depth ,
30
30
c_k ,
31
- h_g ,
32
31
):
33
32
self .depth = depth
34
33
self .children = [None for i in range (c_k )] # child nodes
35
- self .h_g = h_g
34
+ self .h_g = 0.5
36
35
self .h_beta_vec = np .ones (c_k ) / 2
37
36
self .theta_vec = np .ones (c_k ) / c_k
38
37
self .leaf = False
38
+ self .map_leaf = False
39
39
40
40
class GenModel (base .Generative ):
41
41
""" The stochastice data generative model and the prior distribution
@@ -46,11 +46,7 @@ class GenModel(base.Generative):
46
46
A positive integer
47
47
c_d_max : int, optional
48
48
A positive integer, by default 10
49
- theta_vec : numpy.ndarray, optional
50
- A vector of real numbers in :math:`[0, 1]`,
51
- by default [1/c_k, 1/c_k, ... , 1/c_k]
52
- Sum of its elements must be 1.0.
53
- root : contexttree._GenNode, optional
49
+ root : contexttree._Node, optional
54
50
A root node of a context tree,
55
51
by default a tree consists of only one node.
56
52
h_g : float, optional
@@ -59,7 +55,7 @@ class GenModel(base.Generative):
59
55
A vector of positive real numbers,
60
56
by default [1/2, 1/2, ... , 1/2].
61
57
If a single real number is input, it will be broadcasted.
62
- h_root : contexttree._GenNode , optional
58
+ h_root : contexttree._Node , optional
63
59
A root node of a superposed tree for hyperparameters
64
60
by default ``None``
65
61
seed : {None, int}, optional
@@ -71,7 +67,6 @@ def __init__(
71
67
c_k ,
72
68
c_d_max = 10 ,
73
69
* ,
74
- theta_vec = None ,
75
70
root = None ,
76
71
h_g = 0.5 ,
77
72
h_beta_vec = None ,
@@ -95,80 +90,132 @@ def __init__(
95
90
)
96
91
97
92
# params
98
- self .theta_vec = np .ones (self .c_k ) / self .c_k
99
- self .root = _GenNode (0 ,self .c_k ,self .h_g )
93
+ self .root = _Node (0 ,self .c_k )
94
+ self .root .h_g = self .h_g
95
+ self .root .h_beta_vec [:] = self .h_beta_vec
100
96
self .root .leaf = True
101
97
102
98
self .set_params (
103
- theta_vec ,
104
99
root ,
105
100
)
106
101
107
- def _gen_params_recursion (self ,node ,h_node ):
102
+ def _gen_params_recursion (self ,node : _Node ,h_node : _Node ):
108
103
""" generate parameters recursively"""
109
- if node .depth == self .c_d_max :
110
- node .h_g = 0
111
104
if h_node is None :
105
+ if node .depth == self .c_d_max :
106
+ node .h_g = 0
107
+ else :
108
+ node .h_g = self .h_g
109
+ node .h_beta_vec [:] = self .h_beta_vec
112
110
if node .depth == self .c_d_max or self .rng .random () > self .h_g : # 葉ノード
113
111
node .theta_vec [:] = self .rng .dirichlet (self .h_beta_vec )
114
112
node .leaf = True
115
113
else : # 内部ノード
116
114
node .leaf = False
117
115
for i in range (self .c_k ):
118
116
if node .children [i ] is None :
119
- node .children [i ] = _GenNode (node .depth + 1 ,self .c_k , self . h_g )
117
+ node .children [i ] = _Node (node .depth + 1 ,self .c_k )
120
118
self ._gen_params_recursion (node .children [i ],None )
121
119
else :
120
+ if node .depth == self .c_d_max :
121
+ node .h_g = 0
122
+ else :
123
+ node .h_g = h_node .h_g
124
+ node .h_beta_vec [:] = h_node .h_beta_vec
122
125
if node .depth == self .c_d_max or self .rng .random () > h_node .h_g : # 葉ノード
123
126
node .theta_vec [:] = self .rng .dirichlet (h_node .h_beta_vec )
124
127
node .leaf = True
125
128
else : # 内部ノード
126
129
node .leaf = False
127
130
for i in range (self .c_k ):
128
131
if node .children [i ] is None :
129
- node .children [i ] = _GenNode (node .depth + 1 ,self .c_k , self . h_g )
132
+ node .children [i ] = _Node (node .depth + 1 ,self .c_k )
130
133
self ._gen_params_recursion (node .children [i ],h_node .children [i ])
131
134
132
- def _gen_params_recursion_tree_fix (self ,node ,h_node ):
135
+ def _gen_params_recursion_tree_fix (self ,node : _Node ,h_node : _Node ):
133
136
""" generate parameters recursively for fixed tree"""
134
137
if h_node is None :
138
+ if node .depth == self .c_d_max :
139
+ node .h_g = 0
140
+ else :
141
+ node .h_g = self .h_g
142
+ node .h_beta_vec [:] = self .h_beta_vec
135
143
if node .leaf : # 葉ノード
136
144
node .theta_vec [:] = self .rng .dirichlet (self .h_beta_vec )
137
145
else : # 内部ノード
138
146
for i in range (self .c_k ):
139
147
if node .children [i ] is not None :
140
148
self ._gen_params_recursion_tree_fix (node .children [i ],None )
141
149
else :
150
+ if node .depth == self .c_d_max :
151
+ node .h_g = 0
152
+ else :
153
+ node .h_g = h_node .h_g
154
+ node .h_beta_vec [:] = h_node .h_beta_vec
142
155
if node .leaf : # 葉ノード
143
156
node .theta_vec [:] = self .rng .dirichlet (h_node .h_beta_vec )
144
157
else : # 内部ノード
145
158
for i in range (self .c_k ):
146
159
if node .children [i ] is not None :
147
160
self ._gen_params_recursion_tree_fix (node .children [i ],h_node .children [i ])
148
161
149
- def _set_recursion (self ,node ,original_tree_node ):
162
+ def _set_params_recursion (self ,node : _Node ,original_tree_node : _Node ):
150
163
""" copy parameters from a fixed tree
151
164
152
165
Parameters
153
166
----------
154
167
node : object
155
- a object form GenNode class
168
+ a object from _Node class
156
169
original_tree_node : object
157
- a object form GenNode class
170
+ a object from _Node class
158
171
"""
159
- node .h_g = original_tree_node .h_g
160
- node .h_beta_vec [:] = original_tree_node .h_beta_vec
161
172
node .theta_vec [:] = original_tree_node .theta_vec
162
173
if original_tree_node .leaf or node .depth == self .c_d_max : # 葉ノード
163
174
node .leaf = True
164
- if node .depth == self .c_d_max :
165
- node .h_g = 0
166
175
else :
167
176
node .leaf = False
168
177
for i in range (self .c_k ):
169
- node .children [i ] = _GenNode (node .depth + 1 ,self .c_k ,self .h_g )
170
- self ._set_recursion (node .children [i ],original_tree_node .children [i ])
171
-
178
+ if node .children [i ] is None :
179
+ node .children [i ] = _Node (node .depth + 1 ,self .c_k )
180
+ self ._set_params_recursion (node .children [i ],original_tree_node .children [i ])
181
+
182
+ def _set_h_params_recursion (self ,node :_Node ,original_tree_node :_Node ):
183
+ """ copy parameters from a fixed tree
184
+
185
+ Parameters
186
+ ----------
187
+ node : object
188
+ a object from _Node class
189
+ original_tree_node : object
190
+ a object from _Node class
191
+ """
192
+ if original_tree_node is None :
193
+ node .h_g = self .h_g
194
+ node .h_beta_vec [:] = self .h_beta_vec
195
+ if node .depth == self .c_d_max : # 葉ノード
196
+ node .leaf = True
197
+ if node .depth == self .c_d_max :
198
+ node .h_g = 0
199
+ else :
200
+ node .leaf = False
201
+ for i in range (self .c_k ):
202
+ if node .children [i ] is None :
203
+ node .children [i ] = _Node (node .depth + 1 ,self .c_k )
204
+ self ._set_h_params_recursion (node .children [i ],None )
205
+ else :
206
+ node .h_g = original_tree_node .h_g
207
+ node .h_beta_vec [:] = original_tree_node .h_beta_vec
208
+ if original_tree_node .leaf or node .depth == self .c_d_max : # 葉ノード
209
+ node .leaf = True
210
+ if node .depth == self .c_d_max :
211
+ node .h_g = 0
212
+ else :
213
+ node .leaf = False
214
+ for i in range (self .c_k ):
215
+ if node .children [i ] is None :
216
+ node .children [i ] = _Node (node .depth + 1 ,self .c_k )
217
+ self ._set_h_params_recursion (node .children [i ],original_tree_node .children [i ])
218
+
172
219
def _gen_sample_recursion (self ,node ,x ):
173
220
"""Generate a sample from the stochastic data generative model.
174
221
@@ -230,31 +277,35 @@ def set_h_params(self,
230
277
h_beta_vec : numpy.ndarray, optional
231
278
A vector of positive real numbers,
232
279
by default ``None``
233
- h_root : contexttree._GenNode , optional
280
+ h_root : contexttree._Node , optional
234
281
A root node of a superposed tree for hyperparameters
235
282
by default ``None``
236
283
"""
237
284
if h_g is not None :
238
285
self .h_g = _check .float_in_closed01 (h_g ,'h_g' ,ParameterFormatError )
286
+ if self .h_root is not None :
287
+ self ._set_h_params_recursion (self .h_root ,None )
239
288
240
289
if h_beta_vec is not None :
241
290
_check .pos_floats (h_beta_vec ,'h_beta_vec' ,ParameterFormatError )
242
291
self .h_beta_vec [:] = h_beta_vec
292
+ if self .h_root is not None :
293
+ self ._set_h_params_recursion (self .h_root ,None )
243
294
244
295
if h_root is not None :
245
- if type (h_root ) is not _GenNode :
296
+ if type (h_root ) is not _Node :
246
297
raise (ParameterFormatError (
247
- "h_root must be an instance of contexttree._GenNode "
298
+ "h_root must be an instance of contexttree._Node "
248
299
))
249
- self .h_root = _GenNode (0 ,self .c_k , self . h_g )
250
- self ._set_recursion (self .h_root ,h_root )
300
+ self .h_root = _Node (0 ,self .c_k )
301
+ self ._set_h_params_recursion (self .h_root ,h_root )
251
302
252
303
def get_h_params (self ):
253
304
"""Get the hyperparameters of the prior distribution.
254
305
255
306
Returns
256
307
-------
257
- h_params : dict of {str: float, numpy.ndarray, contexttree._GenNode }
308
+ h_params : dict of {str: float, numpy.ndarray, contexttree._Node }
258
309
* ``"h_g"`` : the value of ``self.h_g``
259
310
* ``"h_beta_vec"`` : the value of ``self.h_beta_vec``
260
311
* ``"h_root"`` : the value of ``self.h_root``
@@ -278,43 +329,30 @@ def gen_params(self,tree_fix=False):
278
329
else :
279
330
self ._gen_params_recursion (self .root ,self .h_root )
280
331
281
- def set_params (self ,theta_vec = None , root = None ):
332
+ def set_params (self ,root = None ):
282
333
"""Set the parameter of the sthocastic data generative model.
283
334
284
335
Parameters
285
336
----------
286
- theta_vec : numpy.ndarray, optional
287
- A vector of real numbers in :math:`[0, 1]`,
288
- by default None.
289
- Sum of its elements must be 1.0.
290
- root : contexttree._GenNode, optional
337
+ root : contexttree._Node, optional
291
338
A root node of a contexttree, by default None.
292
339
"""
293
- if theta_vec is not None :
294
- _check .float_vec_sum_1 (theta_vec , "theta_vec" , ParameterFormatError )
295
- _check .shape_consistency (
296
- theta_vec .shape [0 ],"theta_vec.shape[0]" ,
297
- self .c_k ,"self.c_k" ,
298
- ParameterFormatError
299
- )
300
- self .theta_vec [:] = theta_vec
301
340
if root is not None :
302
- if type (root ) is not _GenNode :
341
+ if type (root ) is not _Node :
303
342
raise (ParameterFormatError (
304
- "root must be an instance of metatree._GenNode "
343
+ "root must be an instance of metatree._Node "
305
344
))
306
- self ._set_recursion (self .root ,root )
345
+ self ._set_params_recursion (self .root ,root )
307
346
308
347
def get_params (self ):
309
348
"""Get the parameter of the sthocastic data generative model.
310
349
311
350
Returns
312
351
-------
313
352
params : dict of {str:float}
314
- * ``"theta_vec"`` : The value of ``self.theta_vec``.
315
353
* ``"root"`` : The value of ``self.root``.
316
354
"""
317
- return {"theta_vec" : self . theta_vec , " root" :self .root }
355
+ return {"root" :self .root }
318
356
319
357
def gen_sample (self ,sample_length ,initial_values = None ):
320
358
"""Generate a sample from the stochastic data generative model.
@@ -553,7 +591,7 @@ def set_h0_params(self,
553
591
h0_beta_vec : numpy.ndarray, optional
554
592
A vector of positive real numbers,
555
593
by default ``None``
556
- h0_root : contexttree._GenNode , optional
594
+ h0_root : contexttree._Node , optional
557
595
A root node of a superposed tree for hyperparameters
558
596
by default ``None``
559
597
"""
@@ -601,7 +639,7 @@ def set_hn_params(self,
601
639
hn_beta_vec : numpy.ndarray, optional
602
640
A vector of positive real numbers,
603
641
by default ``None``
604
- hn_root : contexttree._GenNode , optional
642
+ hn_root : contexttree._Node , optional
605
643
A root node of a superposed tree for hyperparameters
606
644
by default ``None``
607
645
"""
0 commit comments