
Commit 031d2f8

antoinebaker, ogrisel, and jeremiedbb authored

FIX Draw indices using sample_weight in Bagging (scikit-learn#31414)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>

1 parent d4d4af8 commit 031d2f8

File tree

5 files changed: +286 -101 lines changed

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+- :class:`ensemble.BaggingClassifier`, :class:`ensemble.BaggingRegressor`
+  and :class:`ensemble.IsolationForest` now use `sample_weight` to draw
+  the samples instead of forwarding them multiplied by a uniformly sampled
+  mask to the underlying estimators. Furthermore, `max_samples` is now
+  interpreted as a fraction of `sample_weight.sum()` instead of `X.shape[0]`
+  when passed as a float.
+  By :user:`Antoine Baker <antoinebaker>`.
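
In practice, the weights now decide which rows are drawn into each bag. A minimal sketch of the new behavior, assuming this patch is applied (the data and weights are made up for illustration; `estimators_samples_` records the indices drawn for each ensemble member):

import numpy as np
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.rand(50, 2)
y = X.sum(axis=1)

# Only the first 10 rows carry weight, so only they can be drawn.
sample_weight = np.zeros(50)
sample_weight[:10] = 1.0

reg = BaggingRegressor(DecisionTreeRegressor(), random_state=0)
reg.fit(X, y, sample_weight=sample_weight)

# With the default float max_samples=1.0, each bag now holds
# int(1.0 * sample_weight.sum()) = 10 draws, all from the weighted rows.
for sampled_indices in reg.estimators_samples_:
    assert set(sampled_indices) <= set(range(10))
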

sklearn/ensemble/_bagging.py

Lines changed: 82 additions & 65 deletions
@@ -72,6 +72,7 @@ def _generate_bagging_indices(
     n_samples,
     max_features,
     max_samples,
+    sample_weight,
 ):
     """Randomly draw feature and sample indices."""
     # Get valid random state
@@ -81,18 +82,37 @@
     feature_indices = _generate_indices(
         random_state, bootstrap_features, n_features, max_features
     )
-    sample_indices = _generate_indices(
-        random_state, bootstrap_samples, n_samples, max_samples
-    )
+    if sample_weight is None:
+        sample_indices = _generate_indices(
+            random_state, bootstrap_samples, n_samples, max_samples
+        )
+    else:
+        normalized_sample_weight = sample_weight / np.sum(sample_weight)
+        sample_indices = random_state.choice(
+            n_samples,
+            max_samples,
+            replace=bootstrap_samples,
+            p=normalized_sample_weight,
+        )

     return feature_indices, sample_indices


+def _consumes_sample_weight(estimator):
+    if _routing_enabled():
+        request_or_router = get_routing_for_object(estimator)
+        consumes_sample_weight = request_or_router.consumes("fit", ("sample_weight",))
+    else:
+        consumes_sample_weight = has_fit_parameter(estimator, "sample_weight")
+    return consumes_sample_weight
+
+
 def _parallel_build_estimators(
     n_estimators,
     ensemble,
     X,
     y,
+    sample_weight,
     seeds,
     total_n_estimators,
     verbose,
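
The weighted branch above is a plain categorical draw; a standalone sketch of what `_generate_bagging_indices` now does when weights are provided (toy numbers, not from the commit):

import numpy as np

random_state = np.random.RandomState(0)
n_samples, max_samples = 4, 4
sample_weight = np.array([0.1, 0.1, 0.1, 2.7])

# Normalize the weights into a probability vector, then draw indices
# with replacement: heavily weighted rows dominate the bag.
normalized_sample_weight = sample_weight / np.sum(sample_weight)
sample_indices = random_state.choice(
    n_samples, max_samples, replace=True, p=normalized_sample_weight
)
# array([3, 3, 3, 3]) -- row 3 holds 90% of the probability mass
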
@@ -108,22 +128,12 @@ def _parallel_build_estimators(
     bootstrap_features = ensemble.bootstrap_features
     has_check_input = has_fit_parameter(ensemble.estimator_, "check_input")
     requires_feature_indexing = bootstrap_features or max_features != n_features
+    consumes_sample_weight = _consumes_sample_weight(ensemble.estimator_)

     # Build estimators
     estimators = []
     estimators_features = []

-    # TODO: (slep6) remove if condition for unrouted sample_weight when metadata
-    # routing can't be disabled.
-    support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
-    if not _routing_enabled() and (
-        not support_sample_weight and fit_params.get("sample_weight") is not None
-    ):
-        raise ValueError(
-            "The base estimator doesn't support sample weight, but sample_weight is "
-            "passed to the fit method."
-        )
-
     for i in range(n_estimators):
         if verbose > 1:
             print(
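
With routing disabled, the `_consumes_sample_weight` helper introduced above falls back to `has_fit_parameter`, which simply inspects the signature of the estimator's `fit`; a quick sketch:

from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.utils.validation import has_fit_parameter

has_fit_parameter(LinearRegression(), "sample_weight")     # True
has_fit_parameter(KNeighborsRegressor(), "sample_weight")  # False -> index rows instead
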
@@ -139,7 +149,8 @@ def _parallel_build_estimators(
         else:
             estimator_fit = estimator.fit

-        # Draw random feature, sample indices
+        # Draw random feature, sample indices (using normalized sample_weight
+        # as probabilities if provided).
         features, indices = _generate_bagging_indices(
             random_state,
             bootstrap_features,
@@ -148,45 +159,22 @@ def _parallel_build_estimators(
             n_samples,
             max_features,
             max_samples,
+            sample_weight,
         )

         fit_params_ = fit_params.copy()

-        # TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
-        # routing can't be disabled.
-        # 1. If routing is enabled, we will check if the routing supports sample
-        # weight and use it if it does.
-        # 2. If routing is not enabled, we will check if the base
-        # estimator supports sample_weight and use it if it does.
-
         # Note: Row sampling can be achieved either through setting sample_weight or
-        # by indexing. The former is more efficient. Therefore, use this method
+        # by indexing. The former is more memory efficient. Therefore, use this method
         # if possible, otherwise use indexing.
-        if _routing_enabled():
-            request_or_router = get_routing_for_object(ensemble.estimator_)
-            consumes_sample_weight = request_or_router.consumes(
-                "fit", ("sample_weight",)
-            )
-        else:
-            consumes_sample_weight = support_sample_weight
         if consumes_sample_weight:
-            # Draw sub samples, using sample weights, and then fit
-            curr_sample_weight = _check_sample_weight(
-                fit_params_.pop("sample_weight", None), X
-            ).copy()
-
-            if bootstrap:
-                sample_counts = np.bincount(indices, minlength=n_samples)
-                curr_sample_weight *= sample_counts
-            else:
-                not_indices_mask = ~indices_to_mask(indices, n_samples)
-                curr_sample_weight[not_indices_mask] = 0
-
-            fit_params_["sample_weight"] = curr_sample_weight
+            # Row sampling by setting sample_weight
+            indices_as_sample_weight = np.bincount(indices, minlength=n_samples)
+            fit_params_["sample_weight"] = indices_as_sample_weight
             X_ = X[:, features] if requires_feature_indexing else X
             estimator_fit(X_, y, **fit_params_)
         else:
-            # cannot use sample_weight, so use indexing
+            # Row sampling by indexing
             y_ = _safe_indexing(y, indices)
             X_ = _safe_indexing(X, indices)
             fit_params_ = _check_method_params(X, params=fit_params_, indices=indices)
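
Both branches realize the same resampling; a toy sketch of why the integer counts from `np.bincount` are equivalent to materializing the indexed rows (hypothetical indices, not from the commit):

import numpy as np

n_samples = 5
indices = np.array([0, 2, 2, 4])  # a hypothetical bootstrap draw

# Branch 1: weight-based resampling. One integer count per original row,
# so no copy of X is needed; row 2 counts twice, rows 1 and 3 drop out.
indices_as_sample_weight = np.bincount(indices, minlength=n_samples)
# array([1, 0, 2, 0, 1])

# Branch 2: explicit indexing. Materializes the four resampled rows.
X = np.arange(10).reshape(5, 2)
X_resampled = X[indices]  # row 2 is physically duplicated
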
@@ -354,9 +342,11 @@ def fit(self, X, y, sample_weight=None, **fit_params):
             regression).

         sample_weight : array-like of shape (n_samples,), default=None
-            Sample weights. If None, then samples are equally weighted.
-            Note that this is supported only if the base estimator supports
-            sample weighting.
+            Sample weights. If None, then samples are equally weighted. Used as
+            probabilities to sample the training set. Note that the expected
+            frequency semantics for the `sample_weight` parameter are only
+            fulfilled when sampling with replacement (`bootstrap=True`).
+
         **fit_params : dict
             Parameters to pass to the underlying estimators.

@@ -386,6 +376,15 @@ def fit(self, X, y, sample_weight=None, **fit_params):
             multi_output=True,
         )

+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
+
+            if not self.bootstrap:
+                warn(
+                    f"When fitting {self.__class__.__name__} with sample_weight "
+                    f"it is recommended to use bootstrap=True, got {self.bootstrap}."
+                )
+
         return self._fit(
             X,
             y,
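
A sketch of the new guardrail, assuming this patch: weights combined with `bootstrap=False` still fit, but emit a warning because draws without replacement cannot honor the frequency semantics of the weights (toy data below):

import numpy as np
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X = rng.rand(30, 3)
y = (X[:, 0] > 0.5).astype(int)
sample_weight = rng.rand(30)  # total weight well below n_samples

clf = BaggingClassifier(DecisionTreeClassifier(), bootstrap=False, random_state=0)
# UserWarning: When fitting BaggingClassifier with sample_weight
# it is recommended to use bootstrap=True, got False.
clf.fit(X, y, sample_weight=sample_weight)
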
@@ -435,8 +434,6 @@ def _fit(

         sample_weight : array-like of shape (n_samples,), default=None
             Sample weights. If None, then samples are equally weighted.
-            Note that this is supported only if the base estimator supports
-            sample weighting.

         **fit_params : dict, default=None
             Parameters to pass to the :term:`fit` method of the underlying
@@ -457,30 +454,38 @@ def _fit(
         # Check parameters
         self._validate_estimator(self._get_estimator())

-        if sample_weight is not None:
-            fit_params["sample_weight"] = sample_weight
-
         if _routing_enabled():
             routed_params = process_routing(self, "fit", **fit_params)
         else:
             routed_params = Bunch()
             routed_params.estimator = Bunch(fit=fit_params)
-            if "sample_weight" in fit_params:
-                routed_params.estimator.fit["sample_weight"] = fit_params[
-                    "sample_weight"
-                ]

         if max_depth is not None:
             self.estimator_.max_depth = max_depth

         # Validate max_samples
         if max_samples is None:
             max_samples = self.max_samples
-        elif not isinstance(max_samples, numbers.Integral):
-            max_samples = int(max_samples * X.shape[0])

-        if max_samples > X.shape[0]:
-            raise ValueError("max_samples must be <= n_samples")
+        if not isinstance(max_samples, numbers.Integral):
+            if sample_weight is None:
+                max_samples = max(int(max_samples * X.shape[0]), 1)
+            else:
+                sw_sum = np.sum(sample_weight)
+                if sw_sum <= 1:
+                    raise ValueError(
+                        f"The total sum of sample weights is {sw_sum}, which prevents "
+                        "resampling with a fractional value for max_samples="
+                        f"{max_samples}. Either pass max_samples as an integer or "
+                        "use a larger sample_weight."
+                    )
+                max_samples = max(int(max_samples * sw_sum), 1)
+
+        if not self.bootstrap and max_samples > X.shape[0]:
+            raise ValueError(
+                f"Effective max_samples={max_samples} must be <= n_samples="
+                f"{X.shape[0]} to be able to sample without replacement."
+            )

         # Store validated integer row sampling value
         self._max_samples = max_samples
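
In numbers, the revised validation changes what a float `max_samples` means; a short sketch of the arithmetic with made-up weights:

import numpy as np

max_samples = 0.5
sample_weight = np.array([1.0, 1.0, 6.0, 2.0])  # n_samples = 4, total weight = 10.0

# Weighted: a fraction of the total weight ...
max(int(max_samples * np.sum(sample_weight)), 1)  # 5 draws
# ... versus the old rule, a fraction of the number of rows.
max(int(max_samples * 4), 1)                      # 2 draws
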
@@ -499,6 +504,11 @@ def _fit(
         # Store validated integer feature sampling value
         self._max_features = max_features

+        # Store sample_weight (needed in _get_estimators_indices). Note that
+        # we intentionally do not materialize `sample_weight=None` as an array
+        # of ones to avoid unnecessarily cluttering trained estimator pickles.
+        self._sample_weight = sample_weight
+
         # Other checks
         if not self.bootstrap and self.oob_score:
             raise ValueError("Out of bag estimation only available if bootstrap=True")
@@ -552,6 +562,7 @@ def _fit(
                 self,
                 X,
                 y,
+                sample_weight,
                 seeds[starts[i] : starts[i + 1]],
                 total_n_estimators,
                 verbose=self.verbose,
@@ -596,6 +607,7 @@ def _get_estimators_indices(self):
                 self._n_samples,
                 self._max_features,
                 self._max_samples,
+                self._sample_weight,
             )

             yield feature_indices, sample_indices
@@ -726,7 +738,8 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
        replacement by default, see `bootstrap` for more details).

        - If int, then draw `max_samples` samples.
-        - If float, then draw `max_samples * X.shape[0]` samples.
+        - If float, then draw `max_samples * X.shape[0]` unweighted samples
+          or `max_samples * sample_weight.sum()` weighted samples.

    max_features : int or float, default=1.0
        The number of features to draw from X to train each base estimator (
@@ -737,8 +750,10 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
        - If float, then draw `max(1, int(max_features * n_features_in_))` features.

    bootstrap : bool, default=True
-        Whether samples are drawn with replacement. If False, sampling
-        without replacement is performed.
+        Whether samples are drawn with replacement. If False, sampling without
+        replacement is performed. If fitting with `sample_weight`, it is
+        strongly recommended to choose True, as only drawing with replacement
+        will ensure the expected frequency semantics of `sample_weight`.

    bootstrap_features : bool, default=False
        Whether features are drawn with replacement.
@@ -1245,8 +1260,10 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
        - If float, then draw `max(1, int(max_features * n_features_in_))` features.

    bootstrap : bool, default=True
-        Whether samples are drawn with replacement. If False, sampling
-        without replacement is performed.
+        Whether samples are drawn with replacement. If False, sampling without
+        replacement is performed. If fitting with `sample_weight`, it is
+        strongly recommended to choose True, as only drawing with replacement
+        will ensure the expected frequency semantics of `sample_weight`.

    bootstrap_features : bool, default=False
        Whether features are drawn with replacement.

sklearn/ensemble/_iforest.py

Lines changed: 11 additions & 2 deletions
@@ -20,7 +20,12 @@
 from ..utils._chunking import get_chunk_n_rows
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
 from ..utils.parallel import Parallel, delayed
-from ..utils.validation import _num_samples, check_is_fitted, validate_data
+from ..utils.validation import (
+    _check_sample_weight,
+    _num_samples,
+    check_is_fitted,
+    validate_data,
+)
 from ._bagging import BaseBagging

 __all__ = ["IsolationForest"]
@@ -317,6 +322,10 @@ def fit(self, X, y=None, sample_weight=None):
         X = validate_data(
             self, X, accept_sparse=["csc"], dtype=tree_dtype, ensure_all_finite=False
         )
+
+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
+
         if issparse(X):
             # Pre-sort indices to avoid that each individual tree of the
             # ensemble sorts the indices.
@@ -350,7 +359,7 @@ def fit(self, X, y=None, sample_weight=None):
         super()._fit(
             X,
             y,
-            max_samples,
+            max_samples=max_samples,
             max_depth=max_depth,
             sample_weight=sample_weight,
             check_input=False,
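
End to end, `IsolationForest` now validates the weights and hands them to the shared bagging machinery, so heavily weighted rows are preferentially drawn into each tree's subsample; a short usage sketch assuming this patch:

import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.RandomState(0)
X = rng.randn(200, 2)
sample_weight = rng.rand(200)

iso = IsolationForest(random_state=0).fit(X, sample_weight=sample_weight)
anomaly_scores = iso.score_samples(X)  # lower means more anomalous
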
