@@ -411,7 +411,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
411
411
node .h_g = self .h_g
412
412
# node.sub_model.set_h_params(**self.sub_h_params)
413
413
node .sub_model = self .SubModel .GenModel (** self .sub_h_params )
414
- for i in range (self .c_k ):
414
+ for i in range (self .c_num_children ):
415
415
if node .children [i ] is not None :
416
416
self ._set_h_params_recursion (node .children [i ],None )
417
417
else :
@@ -428,7 +428,7 @@ def _set_h_params_recursion(self,node:_Node,original_tree_node:_Node):
428
428
node .leaf = True
429
429
else :
430
430
node .leaf = False
431
- for i in range (self .c_k ):
431
+ for i in range (self .c_num_children ):
432
432
if node .children [i ] is None :
433
433
node .children [i ] = _Node (
434
434
node .depth + 1 ,
@@ -611,7 +611,7 @@ def get_params(self):
611
611
"""
612
612
return {"root" :self .root }
613
613
614
- def gen_sample (self ,sample_size ,x = None ):
614
+ def gen_sample (self ,sample_size = None ,x = None ):
615
615
"""Generate a sample from the stochastic data generative model.
616
616
617
617
Parameters
@@ -630,10 +630,20 @@ def gen_sample(self,sample_size,x=None):
630
630
y : numpy ndarray
631
631
1 dimensional array whose size is ``sammple_size``.
632
632
"""
633
- _check .pos_int (sample_size ,'sample_size' ,DataFormatError )
634
-
635
- if x is None :
633
+ if x is not None :
634
+ _check .int_vecs (x ,'x' ,DataFormatError )
635
+ _check .shape_consistency (
636
+ x .shape [- 1 ],'x.shape[-1]' ,
637
+ self .c_k ,'self.c_k' ,
638
+ ParameterFormatError
639
+ )
640
+ x = x .reshape (- 1 ,self .c_k )
641
+ sample_size = x .shape [0 ]
642
+ elif sample_size is not None :
643
+ _check .pos_int (sample_size ,'sample_size' ,DataFormatError )
636
644
x = self .rng .choice (self .c_num_children ,(sample_size ,self .c_k ))
645
+ else :
646
+ raise (DataFormatError ("Either of the sample_size and the x must be given as a input." ))
637
647
638
648
if self .SubModel in DISCRETE_MODELS :
639
649
y = np .empty (sample_size ,dtype = int )
0 commit comments