Skip to content

Commit df0fae2

Browse files
committed
Complete merge
Signed-off-by: Adam Li <adam2392@gmail.com>
2 parents 6b57c58 + 1499550 commit df0fae2

File tree

4 files changed

+49
-3
lines changed

4 files changed

+49
-3
lines changed

sklearn/ensemble/_forest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,6 +2316,25 @@ class ExtraTreesClassifier(ForestClassifier):
23162316
23172317
.. versionadded:: 1.4
23182318
2319+
monotonic_cst : array-like of int of shape (n_features), default=None
2320+
Indicates the monotonicity constraint to enforce on each feature.
2321+
- 1: monotonically increasing
2322+
- 0: no constraint
2323+
- -1: monotonically decreasing
2324+
2325+
If monotonic_cst is None, no constraints are applied.
2326+
2327+
Monotonicity constraints are not supported for:
2328+
- multiclass classifications (i.e. when `n_classes > 2`),
2329+
- multioutput classifications (i.e. when `n_outputs_ > 1`),
2330+
- classifications trained on data with missing values.
2331+
2332+
The constraints hold over the probability of the positive class.
2333+
2334+
Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
2335+
2336+
.. versionadded:: 1.4
2337+
23192338
Attributes
23202339
----------
23212340
estimator_ : :class:`~sklearn.tree.ExtraTreesClassifier`

sklearn/tree/_splitter.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ cdef class BaseSplitter:
7171
# +1: monotonic increase
7272
cdef const cnp.int8_t[:] monotonic_cst
7373
cdef bint with_monotonic_cst
74-
7574
cdef const DOUBLE_t[:] sample_weight
7675

7776
# The samples vector `samples` is maintained by the Splitter object such

sklearn/tree/_splitter.pyx

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ cdef class Splitter(BaseSplitter):
300300
"""Copy the value of node samples[start:end] into dest."""
301301

302302
self.criterion.node_value(dest)
303-
303+
304304
cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil:
305305
"""Clip the value in dest between lower_bound and upper_bound for monotonic constraints."""
306306

@@ -310,6 +310,11 @@ cdef class Splitter(BaseSplitter):
310310
"""Copy the samples[start:end] into dest."""
311311
self.criterion.node_samples(dest)
312312

313+
cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil:
314+
"""Clip the value in dest between lower_bound and upper_bound for monotonic constraints."""
315+
316+
self.criterion.clip_node_value(dest, lower_bound, upper_bound)
317+
313318
cdef double node_impurity(self) noexcept nogil:
314319
"""Return the impurity of the current node."""
315320

@@ -564,6 +569,18 @@ cdef inline int node_split_best(
564569

565570
criterion.update(current_split.pos)
566571

572+
# Reject if monotonicity constraints are not satisfied
573+
if (
574+
with_monotonic_cst and
575+
monotonic_cst[current_split.feature] != 0 and
576+
not criterion.check_monotonicity(
577+
monotonic_cst[current_split.feature],
578+
lower_bound,
579+
upper_bound,
580+
)
581+
):
582+
continue
583+
567584
# Reject if min_weight_leaf is not satisfied
568585
if splitter.check_postsplit_conditions() == 1:
569586
continue
@@ -915,6 +932,18 @@ cdef inline int node_split_random(
915932
):
916933
continue
917934

935+
# Reject if monotonicity constraints are not satisfied
936+
if (
937+
with_monotonic_cst and
938+
monotonic_cst[current_split.feature] != 0 and
939+
not criterion.check_monotonicity(
940+
monotonic_cst[current_split.feature],
941+
lower_bound,
942+
upper_bound,
943+
)
944+
):
945+
continue
946+
918947
current_proxy_improvement = criterion.proxy_impurity_improvement()
919948

920949
if current_proxy_improvement > best_proxy_improvement:

sklearn/tree/_tree.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
364364
"is_left": 1,
365365
"impurity": split.impurity_left,
366366
"n_constant_features": n_constant_features,
367-
"n_constant_features": n_constant_features,
368367
"lower_bound": left_child_min,
369368
"upper_bound": left_child_max,
370369
})

0 commit comments

Comments
 (0)