Skip to content

Commit 6bd570c

Browse files
committed
Add reshaping of x
1 parent 6be85f0 commit 6bd570c

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

bayesml/metatree/_metatree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
711711
self.c_dim_continuous,'self.c_dim_continuous',
712712
ParameterFormatError
713713
)
714+
x_continuous = x_continuous.reshape(-1,self.c_dim_continuous)
714715
_check.shape_consistency(
715716
x_continuous.shape[0],'x_continuous.shape[0]',
716717
sample_size,'sample_size',
@@ -729,6 +730,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
729730
self.c_dim_categorical,'self.c_dim_categorical',
730731
ParameterFormatError
731732
)
733+
x_categorical = x_categorical.reshape(-1,self.c_dim_categorical)
732734
_check.shape_consistency(
733735
x_categorical.shape[0],'x_categorical.shape[0]',
734736
sample_size,'sample_size',
@@ -765,6 +767,7 @@ def gen_sample(self,sample_size=None,x_continuous=None,x_categorical=None):
765767
self.c_dim_categorical,'self.c_dim_categorical',
766768
ParameterFormatError
767769
)
770+
x_categorical = x_categorical.reshape(-1,self.c_dim_categorical)
768771
_check.shape_consistency(
769772
x_categorical.shape[0],'x_categorical.shape[0]',
770773
x_continuous.shape[0],'x_continuous.shape[0]',

0 commit comments

Comments
 (0)