@@ -711,6 +711,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
711
711
self .c_dim_continuous ,'self.c_dim_continuous' ,
712
712
ParameterFormatError
713
713
)
714
+ x_continuous = x_continuous .reshape (- 1 ,self .c_dim_continuous )
714
715
_check .shape_consistency (
715
716
x_continuous .shape [0 ],'x_continuous.shape[0]' ,
716
717
sample_size ,'sample_size' ,
@@ -729,6 +730,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
729
730
self .c_dim_categorical ,'self.c_dim_categorical' ,
730
731
ParameterFormatError
731
732
)
733
+ x_categorical = x_categorical .reshape (- 1 ,self .c_dim_categorical )
732
734
_check .shape_consistency (
733
735
x_categorical .shape [0 ],'x_categorical.shape[0]' ,
734
736
sample_size ,'sample_size' ,
@@ -765,6 +767,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
765
767
self .c_dim_categorical ,'self.c_dim_categorical' ,
766
768
ParameterFormatError
767
769
)
770
+ x_categorical = x_categorical .reshape (- 1 ,self .c_dim_categorical )
768
771
_check .shape_consistency (
769
772
x_categorical .shape [0 ],'x_categorical.shape[0]' ,
770
773
x_continuous .shape [0 ],'x_continuous.shape[0]' ,
0 commit comments