
Commit 9a614f4

Authored by Adam Li (adam2392)
Attempt to bring in categorical support (#46)
Reference Issues/PRs: helps bring the fork in line with the changes in scikit-learn#12866. Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent e9d702b commit 9a614f4

File tree: 15 files changed, +1462 −100 lines

benchmarks/bench_tree_nocats.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
from itertools import product
from timeit import timeit

import numpy as np
import pandas as pd

from sklearn.datasets import fetch_openml
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import OneHotEncoder


def get_data(trunc_ncat):
    # the data is located here: https://www.openml.org/d/4135
    X, y = fetch_openml(data_id=4135, return_X_y=True)
    X = pd.DataFrame(X)

    Xdicts = []
    for trunc in trunc_ncat:
        # Truncate each categorical feature to at most `trunc` distinct values
        # (trunc <= 0 keeps the full cardinality).
        X_trunc = X % trunc if trunc > 0 else X
        # Keep one representative row per unique combination of category values.
        keep_idx = np.array(
            [idx[0] for idx in X_trunc.groupby(list(X.columns)).groups.values()]
        )
        X_trunc = X_trunc.values[keep_idx]
        y_trunc = y[keep_idx]

        X_ohe = OneHotEncoder(categories="auto").fit_transform(X_trunc)

        Xdicts.append({"X": X_trunc, "y": y_trunc, "ohe": False, "trunc": trunc})
        Xdicts.append({"X": X_ohe, "y": y_trunc, "ohe": True, "trunc": trunc})

    return Xdicts


# Training dataset
trunc_factor = [2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 64, 0]
data = get_data(trunc_factor)
results = []
# Loop over classifiers and datasets
for Xydict, clf_type in product(data, [RandomForestClassifier, ExtraTreesClassifier]):
    # Can't use non-truncated categorical data with RandomForest
    # and it becomes intractable with too many categories
    if (
        clf_type is RandomForestClassifier
        and not Xydict["ohe"]
        and (not Xydict["trunc"] or Xydict["trunc"] > 16)
    ):
        continue

    X, y = Xydict["X"], Xydict["y"]
    tech = "One-hot" if Xydict["ohe"] else "NOCATS"
    trunc = "truncated({})".format(Xydict["trunc"]) if Xydict["trunc"] > 0 else "full"
    cat = "none" if Xydict["ohe"] else "all"
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=17).split(X, y)

    traintimes = []
    testtimes = []
    aucs = []
    name = "({}, {}, {})".format(clf_type.__name__, trunc, tech)

    for train, test in cv:
        # Train
        clf = clf_type(
            n_estimators=10,
            max_features=None,
            min_samples_leaf=1,
            random_state=23,
            bootstrap=False,
            max_depth=None,
            categorical=cat,
        )

        traintimes.append(
            timeit(
                "clf.fit(X[train], y[train])",
                "from __main__ import clf, X, y, train",
                number=1,
            )
        )

        """
        # Check that all leaf nodes are pure
        for est in clf.estimators_:
            leaves = est.tree_.children_left < 0
            print(np.max(est.tree_.impurity[leaves]))
            # assert np.all(est.tree_.impurity[leaves] == 0)
        """

        # Test
        probs = []
        testtimes.append(
            timeit(
                "probs.append(clf.predict_proba(X[test]))",
                "from __main__ import probs, clf, X, test",
                number=1,
            )
        )

        aucs.append(roc_auc_score(y[test], probs[0][:, 1]))

    traintimes = np.array(traintimes)
    testtimes = np.array(testtimes)
    aucs = np.array(aucs)
    results.append(
        [
            name,
            traintimes.mean(),
            traintimes.std(),
            testtimes.mean(),
            testtimes.std(),
            aucs.mean(),
            aucs.std(),
        ]
    )

results_df = pd.DataFrame(results)
results_df.columns = [
    "name",
    "train time mean",
    "train time std",
    "test time mean",
    "test time std",
    "auc mean",
    "auc std",
]
results_df = results_df.set_index("name")
print(results_df)
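The skip condition in the benchmark mirrors the cost noted in the `categorical` docstrings added below: without the one-hot workaround, an exhaustive categorical split has to consider every way of assigning categories to the two children. A back-of-the-envelope sketch (plain Python, purely illustrative) of why that search is capped at 64 categories and why the Breiman shortcut matters for binary labels:

# A k-valued categorical feature admits 2**(k - 1) - 1 distinct non-trivial
# binary partitions, so exhaustive search grows exponentially with k.
# The Breiman (1984) shortcut instead orders the categories and scans k - 1
# thresholds, which is why binary-label fits stay linear in k.
for k in (2, 4, 8, 16, 64):
    print(k, 2 ** (k - 1) - 1)
# 2 -> 1, 4 -> 7, 8 -> 127, 16 -> 32767, 64 -> 9223372036854775807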

sklearn/ensemble/_forest.py

Lines changed: 74 additions & 23 deletions
@@ -725,7 +725,7 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):
         ----------
         X : {array-like, sparse matrix} of shape (n_samples, n_features)
             Input data.
-        quantiles : float, optional
+        quantiles : array-like, float, optional
             The quantiles at which to evaluate, by default 0.5 (median).
         method : str, optional
             The method to interpolate, by default 'linear'. Can be any keyword
@@ -746,7 +746,7 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):
         X = self._validate_X_predict(X)

         if not isinstance(quantiles, (np.ndarray, list)):
-            quantiles = np.array([quantiles])
+            quantiles = np.atleast_1d(np.array(quantiles))

         # if we trained a binning tree, then we should re-bin the data
         # XXX: this is inefficient and should be improved to be in line with what
@@ -777,15 +777,15 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):

             # (n_total_leaf_samples, n_outputs)
             leaf_node_samples = np.vstack(
-                (
+                [
                     est.leaf_nodes_samples_[leaf_nodes[jdx]]
                     for jdx, est in enumerate(self.estimators_)
-                )
+                ]
             )

             # get quantiles across all leaf node samples
             y_hat[idx, ...] = np.quantile(
-                leaf_node_samples, quantiles, axis=0, interpolation=method
+                leaf_node_samples, quantiles, axis=0, method=method
             )

             if is_classifier(self):
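The three hunks above tighten `predict_quantiles`: `quantiles` may now be a scalar or an array-like, and `np.quantile` is called with the modern `method=` keyword. A minimal standalone sketch (plain NumPy, independent of the forest code) of the behavior being relied on:

import numpy as np

# np.atleast_1d normalizes a scalar, a list, or a 0-d array to a 1-D array,
# so downstream code can always iterate over quantiles.
for q in (0.5, [0.25, 0.75], np.array(0.9)):
    print(np.atleast_1d(np.array(q)).shape)  # (1,), (2,), (1,)

# NumPy >= 1.22 spells the interpolation strategy `method=`;
# `interpolation=` is the deprecated keyword the diff replaces.
samples = np.random.default_rng(0).normal(size=(100, 2))
print(np.quantile(samples, [0.1, 0.5, 0.9], axis=0, method="nearest").shape)  # (3, 2)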
@@ -1550,6 +1550,17 @@ class RandomForestClassifier(ForestClassifier):

         .. versionadded:: 1.4

+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or `None`. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria. In this case,
+        the runtime is linear in the number of categories.
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier`
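A minimal usage sketch of the parameter documented above (synthetic data; assumes this fork's `RandomForestClassifier` exposing the `categorical` keyword is installed, since upstream scikit-learn does not provide it):

import numpy as np
from sklearn.ensemble import RandomForestClassifier  # fork with categorical support

rng = np.random.default_rng(0)
# Column 0 is ordinal, column 1 holds integer category codes (fewer than 64 levels).
X = np.column_stack([rng.normal(size=200), rng.integers(0, 5, size=200)])
y = (X[:, 1] >= 3).astype(int)  # binary labels, so the Breiman shortcut applies

# Mark only column 1 as categorical; `categorical="all"`, a boolean mask of
# length n_features, or `None` (the default, all ordinal) are also accepted.
clf = RandomForestClassifier(n_estimators=10, categorical=[1], random_state=0)
clf.fit(X, y)
print(clf.predict(X[:5]))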
@@ -1693,6 +1704,7 @@ def __init__(
         max_bins=None,
         store_leaf_values=False,
         monotonic_cst=None,
+        categorical=None,
     ):
         super().__init__(
             estimator=DecisionTreeClassifier(),
@@ -1710,6 +1722,7 @@ def __init__(
                 "ccp_alpha",
                 "store_leaf_values",
                 "monotonic_cst",
+                "categorical",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -1733,6 +1746,7 @@ def __init__(
         self.min_impurity_decrease = min_impurity_decrease
         self.monotonic_cst = monotonic_cst
         self.ccp_alpha = ccp_alpha
+        self.categorical = categorical


 class RandomForestRegressor(ForestRegressor):
@@ -1935,6 +1949,17 @@ class RandomForestRegressor(ForestRegressor):

         .. versionadded:: 1.4

+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or `None`. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria. In this case,
+        the runtime is linear in the number of categories.
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`
@@ -2065,6 +2090,7 @@ def __init__(
         max_bins=None,
         store_leaf_values=False,
         monotonic_cst=None,
+        categorical=None,
     ):
         super().__init__(
             estimator=DecisionTreeRegressor(),
@@ -2082,6 +2108,7 @@ def __init__(
                 "ccp_alpha",
                 "store_leaf_values",
                 "monotonic_cst",
+                "categorical",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -2104,6 +2131,7 @@ def __init__(
         self.min_impurity_decrease = min_impurity_decrease
         self.ccp_alpha = ccp_alpha
         self.monotonic_cst = monotonic_cst
+        self.categorical = categorical


 class ExtraTreesClassifier(ForestClassifier):
@@ -2316,24 +2344,16 @@ class ExtraTreesClassifier(ForestClassifier):

         .. versionadded:: 1.4

-    monotonic_cst : array-like of int of shape (n_features), default=None
-        Indicates the monotonicity constraint to enforce on each feature.
-          - 1: monotonically increasing
-          - 0: no constraint
-          - -1: monotonically decreasing
-
-        If monotonic_cst is None, no constraints are applied.
-
-        Monotonicity constraints are not supported for:
-          - multiclass classifications (i.e. when `n_classes > 2`),
-          - multioutput classifications (i.e. when `n_outputs_ > 1`),
-          - classifications trained on data with missing values.
-
-        The constraints hold over the probability of the positive class.
-
-        Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
-
-        .. versionadded:: 1.4
+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or `None`. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria. In this case,
+        the runtime is linear in the number of categories.

     Attributes
     ----------
@@ -2467,6 +2487,7 @@ def __init__(
         max_bins=None,
         store_leaf_values=False,
         monotonic_cst=None,
+        categorical=None,
     ):
         super().__init__(
             estimator=ExtraTreeClassifier(),
@@ -2484,6 +2505,7 @@ def __init__(
                 "ccp_alpha",
                 "store_leaf_values",
                 "monotonic_cst",
+                "categorical",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -2507,6 +2529,7 @@ def __init__(
         self.min_impurity_decrease = min_impurity_decrease
         self.ccp_alpha = ccp_alpha
         self.monotonic_cst = monotonic_cst
+        self.categorical = categorical


 class ExtraTreesRegressor(ForestRegressor):
@@ -2704,6 +2727,17 @@ class ExtraTreesRegressor(ForestRegressor):

         .. versionadded:: 1.4

+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or `None`. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria. In this case,
+        the runtime is linear in the number of categories.
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor`
@@ -2819,6 +2853,7 @@ def __init__(
         max_bins=None,
         store_leaf_values=False,
         monotonic_cst=None,
+        categorical=None,
     ):
         super().__init__(
             estimator=ExtraTreeRegressor(),
@@ -2836,6 +2871,7 @@ def __init__(
                 "ccp_alpha",
                 "store_leaf_values",
                 "monotonic_cst",
+                "categorical",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -2858,6 +2894,7 @@ def __init__(
         self.min_impurity_decrease = min_impurity_decrease
         self.ccp_alpha = ccp_alpha
         self.monotonic_cst = monotonic_cst
+        self.categorical = categorical


 class RandomTreesEmbedding(TransformerMixin, BaseForest):
@@ -2969,6 +3006,17 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest):
         new forest. See :term:`Glossary <warm_start>` and
         :ref:`gradient_boosting_warm_start` for details.

+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or `None`. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria. In this case,
+        the runtime is linear in the number of categories.
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` instance
@@ -3073,6 +3121,7 @@ def __init__(
         verbose=0,
         warm_start=False,
         store_leaf_values=False,
+        categorical=None,
     ):
         super().__init__(
             estimator=ExtraTreeRegressor(),
@@ -3088,6 +3137,7 @@ def __init__(
                 "min_impurity_decrease",
                 "random_state",
                 "store_leaf_values",
+                "categorical",
             ),
             bootstrap=False,
             oob_score=False,
@@ -3106,6 +3156,7 @@ def __init__(
         self.max_leaf_nodes = max_leaf_nodes
         self.min_impurity_decrease = min_impurity_decrease
         self.sparse_output = sparse_output
+        self.categorical = categorical

     def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
         raise NotImplementedError("OOB score not supported by tree embedding")
