Skip to content

Commit ee3e617

Browse files
affanv14amueller
authored andcommitted
[MRG+2] adding multilabel support for score_func (scikit-learn#7676)
* added multilabel support for score function * added test for multilabel score function * updated whats_new.rst * updated whats_new.rst with working link
1 parent 568c002 commit ee3e617

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ Bug fixes
9595
`#6497 <https://github.com/scikit-learn/scikit-learn/pull/6497>`_
9696
by `Sebastian Säger`_
9797

98-
98+
- Fixes issue in :ref:`univariate_feature_selection` where score
99+
functions were not accepting multi-label targets.(`#7676
100+
<https://github.com/scikit-learn/scikit-learn/pull/7676>`_)
101+
by `Mohammed Affan`_
102+
99103
.. _changes_0_18:
100104

101105
Version 0.18

sklearn/feature_selection/tests/test_feature_select.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,21 @@ def test_tied_pvalues():
519519
assert_not_in(9998, Xt)
520520

521521

522+
def test_scorefunc_multilabel():
523+
# Test whether k-best and percentiles works with multilabels with chi2.
524+
525+
X = np.array([[10000, 9999, 0], [100, 9999, 0], [1000, 99, 0]])
526+
y = [[1, 1], [0, 1], [1, 0]]
527+
528+
Xt = SelectKBest(chi2, k=2).fit_transform(X, y)
529+
assert_equal(Xt.shape, (3, 2))
530+
assert_not_in(0, Xt)
531+
532+
Xt = SelectPercentile(chi2, percentile=67).fit_transform(X, y)
533+
assert_equal(Xt.shape, (3, 2))
534+
assert_not_in(0, Xt)
535+
536+
522537
def test_tied_scores():
523538
# Test for stable sorting in k-best with tied scores.
524539
X_train = np.array([[0, 0, 0], [1, 1, 1]])

sklearn/feature_selection/univariate_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def fit(self, X, y):
319319
self : object
320320
Returns self.
321321
"""
322-
X, y = check_X_y(X, y, ['csr', 'csc'])
322+
X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True)
323323

324324
if not callable(self.score_func):
325325
raise TypeError("The score function should be a callable, %s (%s) "

0 commit comments

Comments
 (0)