@@ -79,15 +79,21 @@ class _Node:
79
79
def __init__ (self ,
80
80
depth ,
81
81
c_num_children ,
82
+ k_candidates = None ,
83
+ h_g = 0.5 ,
84
+ k = None ,
85
+ sub_model = None ,
86
+ leaf = False ,
87
+ map_leaf = False
82
88
):
83
89
self .depth = depth
84
90
self .children = [None for i in range (c_num_children )] # child nodes
85
- self .k_candidates = None
86
- self .h_g = 0.5
87
- self .k = None
88
- self .sub_model = None
89
- self .leaf = False
90
- self .map_leaf = False
91
+ self .k_candidates = k_candidates
92
+ self .h_g = h_g
93
+ self .k = k
94
+ self .sub_model = sub_model
95
+ self .leaf = leaf
96
+ self .map_leaf = map_leaf
91
97
92
98
class GenModel (base .Generative ):
93
99
""" The stochastice data generative model and the prior distribution
@@ -168,16 +174,18 @@ def __init__(
168
174
)
169
175
170
176
# params
171
- self .root = _Node (0 ,self .c_num_children )
172
- self .root .k_candidates = list (range (self .c_k ))
173
- self .root .h_g = self .h_g
174
- self .root .k = 0
175
- self .root .sub_model = self .SubModel .GenModel (** self .sub_h_params )
176
- self .root .leaf = True
177
+ self .root = _Node (
178
+ 0 ,
179
+ self .c_num_children ,
180
+ list (range (self .c_k )),
181
+ self .h_g ,
182
+ sub_model = self .SubModel .GenModel (** self .sub_h_params ),
183
+ leaf = True
184
+ )
177
185
178
186
self .set_params (root )
179
187
180
- def _gen_params_recursion (self ,node :_Node ,feature_fix ):
188
+ def _gen_params_recursion (self ,node :_Node ,h_node : _Node , feature_fix ):
181
189
""" generate parameters recursively
182
190
183
191
Parameters
@@ -187,29 +195,65 @@ def _gen_params_recursion(self,node:_Node,feature_fix):
187
195
feature_fix : bool
188
196
a bool parameter show the feature is fixed or not
189
197
"""
190
- if node .depth == self .c_d_max or node .depth == self .c_k or self .rng .random () > node .h_g : # leaf node
191
- node .sub_model = self .SubModel .GenModel (** self .sub_h_params )
192
- node .sub_model .gen_params ()
198
+ if h_node is None :
193
199
if node .depth == self .c_d_max :
194
200
node .h_g = 0
195
- node .leaf = True
196
- else : # inner node
197
- if feature_fix == False or node .k is None :
198
- node .k = self .rng .choice (node .k_candidates ,
199
- p = self .h_k_prob_vec [node .k_candidates ]/ self .h_k_prob_vec [node .k_candidates ].sum ())
200
- child_k_candidates = copy .copy (node .k_candidates )
201
- child_k_candidates .remove (node .k )
202
- node .leaf = False
203
- for i in range (self .c_num_children ):
204
- if feature_fix == False or node .children [i ] is None :
205
- node .children [i ] = _Node (node .depth + 1 ,self .c_num_children )
201
+ else :
202
+ node .h_g = self .h_g
203
+ # node.sub_model.set_h_params(**self.sub_h_params)
204
+ node .sub_model = self .SubModel .GenModel (** self .sub_h_params )
205
+ if node .depth == self .c_d_max or node .depth == self .c_k or self .rng .random () > self .h_g : # leaf node
206
+ node .sub_model .gen_params ()
207
+ node .leaf = True
208
+ else : # inner node
209
+ if feature_fix == False or node .k is None :
210
+ node .k = self .rng .choice (node .k_candidates ,
211
+ p = self .h_k_prob_vec [node .k_candidates ]/ self .h_k_prob_vec [node .k_candidates ].sum ())
212
+ child_k_candidates = copy .copy (node .k_candidates )
213
+ child_k_candidates .remove (node .k )
214
+ node .leaf = False
215
+ for i in range (self .c_num_children ):
216
+ if node .children [i ] is None :
217
+ node .children [i ] = _Node (
218
+ node .depth + 1 ,
219
+ self .c_num_children ,
220
+ h_g = self .h_g ,
221
+ sub_model = self .SubModel .GenModel (** self .sub_h_params ),
222
+ )
206
223
node .children [i ].k_candidates = child_k_candidates
207
- node .children [i ].h_g = self .h_g
208
- else :
224
+ self ._gen_params_recursion (node .children [i ],None ,feature_fix )
225
+ else :
226
+ if node .depth == self .c_d_max :
227
+ node .h_g = 0
228
+ else :
229
+ node .h_g = h_node .h_g
230
+ try :
231
+ sub_h_params = h_node .sub_model .get_h_params ()
232
+ except :
233
+ sub_h_params = h_node .sub_model .get_hn_params ()
234
+ node .sub_model .set_h_params (* sub_h_params .values ())
235
+ if node .depth == self .c_d_max or node .depth == self .c_k or self .rng .random () > h_node .h_g : # leaf node
236
+ node .sub_model .gen_params ()
237
+ node .leaf = True
238
+ else : # inner node
239
+ if feature_fix == False or node .k is None :
240
+ node .k = self .rng .choice (node .k_candidates ,
241
+ p = self .h_k_prob_vec [node .k_candidates ]/ self .h_k_prob_vec [node .k_candidates ].sum ())
242
+ child_k_candidates = copy .copy (node .k_candidates )
243
+ child_k_candidates .remove (node .k )
244
+ node .leaf = False
245
+ for i in range (self .c_num_children ):
246
+ if node .children [i ] is None :
247
+ node .children [i ] = _Node (
248
+ node .depth + 1 ,
249
+ self .c_num_children ,
250
+ h_g = self .h_g ,
251
+ sub_model = self .SubModel .GenModel (** self .sub_h_params ),
252
+ )
209
253
node .children [i ].k_candidates = child_k_candidates
210
- self ._gen_params_recursion (node .children [i ],feature_fix )
254
+ self ._gen_params_recursion (node . children [ i ], h_node .children [i ],feature_fix )
211
255
212
- def _gen_params_recursion_tree_fix (self ,node :_Node ,feature_fix ):
256
+ def _gen_params_recursion_tree_fix (self ,node :_Node ,h_node : _Node , feature_fix ):
213
257
""" generate parameters recursively for fixed tree
214
258
215
259
Parameters
@@ -219,27 +263,51 @@ def _gen_params_recursion_tree_fix(self,node:_Node,feature_fix):
219
263
feature_fix : bool
220
264
a bool parameter show the feature is fixed or not
221
265
"""
222
- if node .leaf : # leaf node
266
+ if h_node is None :
267
+ if node .depth == self .c_d_max :
268
+ node .h_g = 0
269
+ else :
270
+ node .h_g = self .h_g
271
+ # node.sub_model.set_h_params(**self.sub_h_params)
223
272
node .sub_model = self .SubModel .GenModel (** self .sub_h_params )
224
- node .sub_model .gen_params ()
273
+ if node .leaf : # leaf node
274
+ node .sub_model .gen_params ()
275
+ node .leaf = True
276
+ else : # inner node
277
+ if feature_fix == False or node .k is None :
278
+ node .k = self .rng .choice (node .k_candidates ,
279
+ p = self .h_k_prob_vec [node .k_candidates ]/ self .h_k_prob_vec [node .k_candidates ].sum ())
280
+ child_k_candidates = copy .copy (node .k_candidates )
281
+ child_k_candidates .remove (node .k )
282
+ node .leaf = False
283
+ for i in range (self .c_num_children ):
284
+ if node .children [i ] is not None :
285
+ node .children [i ].k_candidates = child_k_candidates
286
+ self ._gen_params_recursion_tree_fix (node .children [i ],None ,feature_fix )
287
+ else :
225
288
if node .depth == self .c_d_max :
226
289
node .h_g = 0
227
- node .leaf = True
228
- else : # inner node
229
- if feature_fix == False or node .k is None :
230
- node .k = self .rng .choice (node .k_candidates ,
231
- p = self .h_k_prob_vec [node .k_candidates ]/ self .h_k_prob_vec [node .k_candidates ].sum ())
232
- child_k_candidates = copy .copy (node .k_candidates )
233
- child_k_candidates .remove (node .k )
234
- node .leaf = False
235
- for i in range (self .c_num_children ):
236
- if feature_fix == False or node .children [i ] is None :
237
- node .children [i ] = _Node (node .depth + 1 ,self .c_num_children )
238
- node .children [i ].k_candidates = child_k_candidates
239
- node .children [i ].h_g = self .h_g
240
- else :
241
- node .children [i ].k_candidates = child_k_candidates
242
- self ._gen_params_recursion_tree_fix (node .children [i ],feature_fix )
290
+ else :
291
+ node .h_g = h_node .h_g
292
+ try :
293
+ sub_h_params = h_node .sub_model .get_h_params ()
294
+ except :
295
+ sub_h_params = h_node .sub_model .get_hn_params ()
296
+ node .sub_model .set_h_params (* sub_h_params .values ())
297
+ if node .leaf : # leaf node
298
+ node .sub_model .gen_params ()
299
+ node .leaf = True
300
+ else : # inner node
301
+ if feature_fix == False or node .k is None :
302
+ node .k = self .rng .choice (node .k_candidates ,
303
+ p = self .h_k_prob_vec [node .k_candidates ]/ self .h_k_prob_vec [node .k_candidates ].sum ())
304
+ child_k_candidates = copy .copy (node .k_candidates )
305
+ child_k_candidates .remove (node .k )
306
+ node .leaf = False
307
+ for i in range (self .c_num_children ):
308
+ if node .children [i ] is not None :
309
+ node .children [i ].k_candidates = child_k_candidates
310
+ self ._gen_params_recursion_tree_fix (node .children [i ],h_node ,feature_fix )
243
311
244
312
def _set_params_recursion (self ,node :_Node ,original_tree_node :_Node ):
245
313
""" copy parameters from a fixed tree
@@ -260,9 +328,14 @@ def _set_params_recursion(self,node:_Node,original_tree_node:_Node):
260
328
node .k = original_tree_node .k
261
329
child_k_candidates = copy .copy (node .k_candidates )
262
330
child_k_candidates .remove (node .k )
331
+ node .leaf = False
263
332
for i in range (self .c_num_children ):
264
- node .children [i ] = _Node (node .depth + 1 ,self .c_num_children )
265
- node .children [i ].k_candidates = child_k_candidates
333
+ node .children [i ] = _Node (
334
+ node .depth + 1 ,
335
+ self .c_num_children ,
336
+ child_k_candidates ,
337
+ self .h_g ,
338
+ )
266
339
self ._set_params_recursion (node .children [i ],original_tree_node .children [i ])
267
340
268
341
def _gen_sample_recursion (self ,node :_Node ,x ):
@@ -291,7 +364,11 @@ def _visualize_model_recursion(self,tree_graph,node:_Node,node_id,parent_id,sibl
291
364
tmp_p_v = p_v
292
365
293
366
# add node information
294
- label_string = f'k={ node .k } \\ lh_g={ node .h_g :.2f} \\ lp_v={ tmp_p_v :.2f} \\ lsub_params={{'
367
+ if node .leaf :
368
+ label_string = 'k=None\\ l'
369
+ else :
370
+ label_string = f'k={ node .k } \\ l'
371
+ label_string += f'h_g={ node .h_g :.2f} \\ lp_v={ tmp_p_v :.2f} \\ lsub_params={{'
295
372
if node .leaf :
296
373
sub_params = node .sub_model .get_params ()
297
374
for key ,value in sub_params .items ():
@@ -332,28 +409,32 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
332
409
node .h_g = 0
333
410
else :
334
411
node .h_g = self .h_g
335
- node .sub_model .set_h_params (** self .sub_h_params )
412
+ # node.sub_model.set_h_params(**self.sub_h_params)
413
+ node .sub_model = self .SubModel .GenModel (** self .sub_h_params )
336
414
for i in range (self .c_k ):
337
415
if node .children [i ] is not None :
338
416
self ._set_h_params_recursion (node .children [i ],None )
339
417
else :
340
- node .h_g = original_tree_node .h_g
418
+ if node .depth == self .c_d_max :
419
+ node .h_g = 0
420
+ else :
421
+ node .h_g = original_tree_node .h_g
341
422
try :
342
423
sub_h_params = node .sub_model .get_h_params ()
343
424
except :
344
425
sub_h_params = node .sub_model .get_hn_params ()
345
- node .sub_model .set_h_params (
346
- * sub_h_params .values ()
347
- )
426
+ node .sub_model .set_h_params (* sub_h_params .values ())
348
427
if original_tree_node .leaf or node .depth == self .c_d_max : # leaf node
349
428
node .leaf = True
350
- if node .depth == self .c_d_max :
351
- node .h_g = 0
352
429
else :
353
430
node .leaf = False
354
431
for i in range (self .c_k ):
355
432
if node .children [i ] is None :
356
- node .children [i ] = _Node (node .depth + 1 ,self .c_k )
433
+ node .children [i ] = _Node (
434
+ node .depth + 1 ,
435
+ self .c_k ,
436
+ sub_model = self .SubModel .GenModel (** self .sub_h_params ),
437
+ )
357
438
self ._set_h_params_recursion (node .children [i ],original_tree_node .children [i ])
358
439
359
440
def set_h_params (self ,
@@ -398,7 +479,6 @@ def set_h_params(self,
398
479
for h_root in self .h_metatree_list :
399
480
self ._set_h_params_recursion (h_root ,None )
400
481
401
-
402
482
if sub_h_params is not None :
403
483
self .SubModel .GenModel (** sub_h_params )
404
484
self .sub_h_params = copy .deepcopy (sub_h_params )
@@ -474,7 +554,7 @@ def get_h_params(self):
474
554
"h_metatree_list" :self .h_metatree_list ,
475
555
"h_metatree_prob_vec" :self .h_metatree_prob_vec }
476
556
477
- def gen_params (self ,feature_fix = False ,tree_fix = False , from_list = False ):
557
+ def gen_params (self ,feature_fix = False ,tree_fix = False ):
478
558
"""Generate the parameter from the prior distribution.
479
559
480
560
The generated vaule is set at ``self.root``.
@@ -486,13 +566,17 @@ def gen_params(self,feature_fix=False,tree_fix=False,from_list=False):
486
566
tree_fix : bool
487
567
If ``True``, tree shape will be fixed, by default ``False``.
488
568
"""
489
- if from_list == True and len ( self .h_metatree_list ) > 0 :
569
+ if self .h_metatree_list :
490
570
tmp_root = self .rng .choice (self .h_metatree_list ,p = self .h_metatree_prob_vec )
491
- self .set_params (tmp_root )
492
- elif tree_fix :
493
- self ._gen_params_recursion_tree_fix (self .root ,feature_fix )
571
+ if tree_fix :
572
+ self ._gen_params_recursion_tree_fix (self .root ,tmp_root ,feature_fix )
573
+ else :
574
+ self ._gen_params_recursion (self .root ,tmp_root ,feature_fix )
494
575
else :
495
- self ._gen_params_recursion (self .root ,feature_fix )
576
+ if tree_fix :
577
+ self ._gen_params_recursion_tree_fix (self .root ,None ,feature_fix )
578
+ else :
579
+ self ._gen_params_recursion (self .root ,None ,feature_fix )
496
580
497
581
def set_params (self ,root = None ):
498
582
"""Set the parameter of the sthocastic data generative model.
@@ -507,6 +591,14 @@ def set_params(self,root=None):
507
591
raise (ParameterFormatError (
508
592
"root must be an instance of metatree._Node"
509
593
))
594
+ self .root = _Node (
595
+ 0 ,
596
+ self .c_num_children ,
597
+ list (range (self .c_k )),
598
+ self .h_g ,
599
+ sub_model = self .SubModel .GenModel (** self .sub_h_params ),
600
+ leaf = True
601
+ )
510
602
self ._set_params_recursion (self .root ,root )
511
603
512
604
def get_params (self ):
0 commit comments