Skip to content

Commit 8112691

Browse files
committed
Fixed unit tests
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent 9a614f4 commit 8112691

File tree

2 files changed

13 additions & 25 deletions

2 files changed

13 additions & 25 deletions

sklearn/ensemble/_forest.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -784,9 +784,14 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):
784784
)
785785

786786
# get quantiles across all leaf node samples
787-
y_hat[idx, ...] = np.quantile(
788-
leaf_node_samples, quantiles, axis=0, method=method
789-
)
787+
try:
788+
y_hat[idx, ...] = np.quantile(
789+
leaf_node_samples, quantiles, axis=0, method=method
790+
)
791+
except TypeError:
792+
y_hat[idx, ...] = np.quantile(
793+
leaf_node_samples, quantiles, axis=0, interpolation=method
794+
)
790795

791796
if is_classifier(self):
792797
if self.n_outputs_ == 1:

sklearn/tree/_tree.pyx

Lines changed: 5 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1008,16 +1008,8 @@ cdef class BaseTree:
10081008
cache_mgr = CategoryCacheMgr()
10091009
cache_mgr.populate(self.nodes, self.node_count, self.n_categories)
10101010
cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits
1011-
# cdef vector[UINT64_t] cache = NULL
1012-
10131011
cdef const INT32_t[:] n_categories = self.n_categories
10141012

1015-
# apply Cache to speed up categorical "apply"
1016-
# cache_mgr = CategoryCacheMgr()
1017-
# cache_mgr.populate(self.nodes, self.node_count, self.n_categories)
1018-
# cdef UINT64_t** cat_caches = cache_mgr.bits
1019-
# cdef UINT64_t* cache = NULL
1020-
10211013
with nogil:
10221014
for i in range(n_samples):
10231015
node = self.nodes
@@ -1034,9 +1026,6 @@ cdef class BaseTree:
10341026
node = &self.nodes[node.right_child]
10351027
elif goes_left(
10361028
X_i_node_feature,
1037-
# node.split_value,
1038-
# node.threshold,
1039-
# self.n_categories[node.feature],
10401029
node,
10411030
n_categories,
10421031
cache
@@ -1082,7 +1071,6 @@ cdef class BaseTree:
10821071
cache_mgr = CategoryCacheMgr()
10831072
cache_mgr.populate(self.nodes, self.node_count, self.n_categories)
10841073
cdef vector[vector[UINT64_t]] cat_caches = cache_mgr.bits
1085-
# cdef vector[UINT64_t] cache = NULL
10861074

10871075
cdef const INT32_t[:] n_categories = self.n_categories
10881076
# feature_to_sample as a data structure records the last seen sample
@@ -1114,9 +1102,6 @@ cdef class BaseTree:
11141102

11151103
if goes_left(
11161104
feature_value,
1117-
# node.split_value,
1118-
# node.threshold,
1119-
# self.n_categories[node.feature],
11201105
node,
11211106
n_categories,
11221107
cache
@@ -1650,21 +1635,19 @@ cdef class Tree(BaseTree):
16501635
self.n_classes = NULL
16511636
safe_realloc(&self.n_classes, n_outputs)
16521637

1653-
self.n_categories = NULL
1654-
safe_realloc(&self.n_categories, n_features)
1638+
cdef SIZE_t k
16551639

16561640
# n-categories is a 1D array of size n_features
1657-
# self.n_categories = np.empty(n_features, dtype=np.int32)
1658-
# self.n_categories = n_categories
1641+
self.n_categories = NULL
1642+
safe_realloc(&self.n_categories, n_features)
1643+
for k in range(n_features):
1644+
self.n_categories[k] = n_categories[k]
16591645

16601646
self.max_n_classes = np.max(n_classes)
16611647
self.value_stride = n_outputs * self.max_n_classes
16621648

1663-
cdef SIZE_t k
16641649
for k in range(n_outputs):
16651650
self.n_classes[k] = n_classes[k]
1666-
for k in range(n_features):
1667-
self.n_categories[k] = n_categories[k]
16681651

16691652
# Inner structures
16701653
self.max_depth = 0

0 commit comments

Comments
 (0)