Skip to content

Commit 1499550

Browse files
authored
FIX mutual_info_regression when X is of integer dtype (scikit-learn#26748)
1 parent 10997c9 commit 1499550

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,10 @@ TODO: update at the time of the release.
9393
when using a custom initialization. The default value of this parameter will change
9494
from `None` to `auto` in version 1.6.
9595
:pr:`26634` by :user:`Alexandre Landeau <AlexL>` and :user:`Alexandre Vigny <avigny>`.
96+
97+
98+
:mod:`sklearn.feature_selection`
99+
................................
100+
101+
- |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the
102+
result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao <Charlie-XIAO>`.

sklearn/feature_selection/_mutual_info.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,12 @@ def _estimate_mi(
280280

281281
rng = check_random_state(random_state)
282282
if np.any(continuous_mask):
283-
if copy:
284-
X = X.copy()
285-
283+
X = X.astype(np.float64, copy=copy)
286284
X[:, continuous_mask] = scale(
287285
X[:, continuous_mask], with_mean=False, copy=False
288286
)
289287

290288
# Add small noise to continuous features as advised in Kraskov et. al.
291-
X = X.astype(np.float64, copy=False)
292289
means = np.maximum(1, np.mean(np.abs(X[:, continuous_mask]), axis=0))
293290
X[:, continuous_mask] += (
294291
1e-10

sklearn/feature_selection/tests/test_mutual_info.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,18 @@ def test_mutual_information_symmetry_classif_regression(correlated, global_rando
236236
)
237237

238238
assert mi_classif == pytest.approx(mi_regression)
239+
240+
241+
def test_mutual_info_regression_X_int_dtype(global_random_seed):
242+
"""Check that results agree when X is integer dtype and float dtype.
243+
244+
Non-regression test for Issue #26696.
245+
"""
246+
rng = np.random.RandomState(global_random_seed)
247+
X = rng.randint(100, size=(100, 10))
248+
X_float = X.astype(np.float64, copy=True)
249+
y = rng.randint(100, size=100)
250+
251+
expected = mutual_info_regression(X_float, y, random_state=global_random_seed)
252+
result = mutual_info_regression(X, y, random_state=global_random_seed)
253+
assert_allclose(result, expected)

0 commit comments

Comments
 (0)