Commit 5b20d48

[MRG] ENH enable setting pipeline components as parameters (scikit-learn#1769)
Pipeline and FeatureUnion steps may now be set with set_params, and transformers may be replaced with None to effectively remove them. Also test and improve duck-typing of Pipeline methods.
1 parent ebb0645 commit 5b20d48
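In a nutshell, the new behaviour looks like this; a minimal sketch against this branch (the pipeline below is illustrative, not taken from the diff):

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA
    from sklearn.svm import SVC

    pipe = Pipeline([('reduce_dim', PCA()), ('clf', SVC())])

    # Steps are now ordinary parameters, so a whole step can be swapped...
    pipe.set_params(reduce_dim=StandardScaler())

    # ...or set to None, which effectively removes the transformer.
    pipe.set_params(reduce_dim=None)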

File tree

7 files changed (+761, -166 lines)


doc/modules/pipeline.rst

Lines changed: 27 additions & 11 deletions
@@ -37,17 +37,16 @@ is an estimator object::
     >>> from sklearn.pipeline import Pipeline
     >>> from sklearn.svm import SVC
     >>> from sklearn.decomposition import PCA
-    >>> estimators = [('reduce_dim', PCA()), ('svm', SVC())]
-    >>> clf = Pipeline(estimators)
-    >>> clf # doctest: +NORMALIZE_WHITESPACE
+    >>> estimators = [('reduce_dim', PCA()), ('clf', SVC())]
+    >>> pipe = Pipeline(estimators)
+    >>> pipe # doctest: +NORMALIZE_WHITESPACE
     Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power=4,
       n_components=None, random_state=None, svd_solver='auto', tol=0.0,
-      whiten=False)), ('svm', SVC(C=1.0, cache_size=200, class_weight=None,
+      whiten=False)), ('clf', SVC(C=1.0, cache_size=200, class_weight=None,
       coef0=0.0, decision_function_shape=None, degree=3, gamma='auto',
       kernel='rbf', max_iter=-1, probability=False, random_state=None,
       shrinking=True, tol=0.001, verbose=False))])

-
 The utility function :func:`make_pipeline` is a shorthand
 for constructing pipelines;
 it takes a variable number of estimators and returns a pipeline,
@@ -64,23 +63,23 @@ filling in the names automatically::

 The estimators of a pipeline are stored as a list in the ``steps`` attribute::

-    >>> clf.steps[0]
+    >>> pipe.steps[0]
     ('reduce_dim', PCA(copy=True, iterated_power=4, n_components=None, random_state=None,
       svd_solver='auto', tol=0.0, whiten=False))

 and as a ``dict`` in ``named_steps``::

-    >>> clf.named_steps['reduce_dim']
+    >>> pipe.named_steps['reduce_dim']
     PCA(copy=True, iterated_power=4, n_components=None, random_state=None,
       svd_solver='auto', tol=0.0, whiten=False)

 Parameters of the estimators in the pipeline can be accessed using the
 ``<estimator>__<parameter>`` syntax::

-    >>> clf.set_params(svm__C=10) # doctest: +NORMALIZE_WHITESPACE
+    >>> pipe.set_params(clf__C=10) # doctest: +NORMALIZE_WHITESPACE
     Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power=4,
       n_components=None, random_state=None, svd_solver='auto', tol=0.0,
-      whiten=False)), ('svm', SVC(C=10, cache_size=200, class_weight=None,
+      whiten=False)), ('clf', SVC(C=10, cache_size=200, class_weight=None,
       coef0=0.0, decision_function_shape=None, degree=3, gamma='auto',
       kernel='rbf', max_iter=-1, probability=False, random_state=None,
       shrinking=True, tol=0.001, verbose=False))])
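A side note on the ``<estimator>__<parameter>`` syntax above: the same keys are exposed by ``get_params``, which is what lets grid search address them. A tiny illustrative sketch (not part of the diff):

    >>> pipe.get_params()['clf__C']   # nested parameters are readable too
    10
    >>> pipe.named_steps['clf'].C     # equivalent direct attribute access
    10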
@@ -90,9 +89,17 @@ This is particularly important for doing grid searches::

     >>> from sklearn.model_selection import GridSearchCV
     >>> params = dict(reduce_dim__n_components=[2, 5, 10],
-    ...               svm__C=[0.1, 10, 100])
-    >>> grid_search = GridSearchCV(clf, param_grid=params)
+    ...               clf__C=[0.1, 10, 100])
+    >>> grid_search = GridSearchCV(pipe, param_grid=params)
+
+Individual steps may also be replaced as parameters, and non-final steps may be
+ignored by setting them to ``None``::

+    >>> from sklearn.linear_model import LogisticRegression
+    >>> params = dict(reduce_dim=[None, PCA(5), PCA(10)],
+    ...               clf=[SVC(), LogisticRegression()],
+    ...               clf__C=[0.1, 10, 100])
+    >>> grid_search = GridSearchCV(pipe, param_grid=params)

 .. topic:: Examples:

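The new doctest stops short of fitting; run end-to-end, the search looks roughly like this (an illustrative sketch against this branch, using the digits data so that PCA(5) and PCA(10) have enough input features):

    from sklearn.datasets import load_digits
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVC

    pipe = Pipeline([('reduce_dim', PCA()), ('clf', SVC())])
    params = dict(reduce_dim=[None, PCA(5), PCA(10)],
                  clf=[SVC(), LogisticRegression()],
                  clf__C=[0.1, 10, 100])
    grid_search = GridSearchCV(pipe, param_grid=params)

    digits = load_digits()
    grid_search.fit(digits.data, digits.target)
    print(grid_search.best_params_)   # the winning steps and C value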
@@ -172,6 +179,15 @@ Like pipelines, feature unions have a shorthand constructor called
 :func:`make_union` that does not require explicit naming of the components.


+Like ``Pipeline``, individual steps may be replaced using ``set_params``,
+and ignored by setting to ``None``::
+
+    >>> combined.set_params(kernel_pca=None) # doctest: +NORMALIZE_WHITESPACE
+    FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
+      iterated_power=4, n_components=None, random_state=None,
+      svd_solver='auto', tol=0.0, whiten=False)), ('kernel_pca', None)],
+      transformer_weights=None)
+
 .. topic:: Examples:

 * :ref:`sphx_glr_auto_examples_feature_stacker.py`
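For context, ``combined`` in the doctest above is the ``FeatureUnion`` constructed earlier in the same document, roughly as follows (a sketch; the exact construction sits outside this hunk):

    from sklearn.pipeline import FeatureUnion
    from sklearn.decomposition import PCA, KernelPCA

    estimators = [('linear_pca', PCA()), ('kernel_pca', KernelPCA())]
    combined = FeatureUnion(estimators)

    # With this commit, one transformer can be switched off in place:
    combined.set_params(kernel_pca=None)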

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
@@ -286,6 +286,12 @@ Enhancements
   (`#5805 <https://github.com/scikit-learn/scikit-learn/pull/5805>`_)
   By `Ibraim Ganiev`_.

+- Added support for substituting or disabling :class:`pipeline.Pipeline`
+  and :class:`pipeline.FeatureUnion` components using the ``set_params``
+  interface that powers :mod:`sklearn.grid_search`.
+  See :ref:`example_plot_compare_reduction.py`. By `Joel Nothman`_ and
+  `Robert McGibbon`_.
+
 Bug fixes
 .........

examples/plot_compare_reduction.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
=================================================================
Selecting dimensionality reduction with Pipeline and GridSearchCV
=================================================================

This example constructs a pipeline that does dimensionality
reduction followed by prediction with a support vector
classifier. It demonstrates the use of GridSearchCV and
Pipeline to optimize over different classes of estimators in a
single CV run -- unsupervised PCA and NMF dimensionality
reductions are compared to univariate feature selection during
the grid search.
"""
# Authors: Robert McGibbon, Joel Nothman

from __future__ import print_function, division

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA, NMF
from sklearn.feature_selection import SelectKBest, chi2

print(__doc__)

pipe = Pipeline([
    ('reduce_dim', PCA()),
    ('classify', LinearSVC())
])

N_FEATURES_OPTIONS = [2, 4, 8]
C_OPTIONS = [1, 10, 100, 1000]
param_grid = [
    {
        'reduce_dim': [PCA(iterated_power=7), NMF()],
        'reduce_dim__n_components': N_FEATURES_OPTIONS,
        'classify__C': C_OPTIONS
    },
    {
        'reduce_dim': [SelectKBest(chi2)],
        'reduce_dim__k': N_FEATURES_OPTIONS,
        'classify__C': C_OPTIONS
    },
]
reducer_labels = ['PCA', 'NMF', 'KBest(chi2)']

grid = GridSearchCV(pipe, cv=3, n_jobs=2, param_grid=param_grid)
digits = load_digits()
grid.fit(digits.data, digits.target)

mean_scores = np.array(grid.results_['test_mean_score'])
# scores are in the order of param_grid iteration, which is alphabetical
mean_scores = mean_scores.reshape(len(C_OPTIONS), -1, len(N_FEATURES_OPTIONS))
# select score for best C
mean_scores = mean_scores.max(axis=0)
bar_offsets = (np.arange(len(N_FEATURES_OPTIONS)) *
               (len(reducer_labels) + 1) + .5)

plt.figure()
COLORS = 'bgrcmyk'
for i, (label, reducer_scores) in enumerate(zip(reducer_labels, mean_scores)):
    plt.bar(bar_offsets + i, reducer_scores, label=label, color=COLORS[i])

plt.title("Comparing feature reduction techniques")
plt.xlabel('Reduced number of features')
plt.xticks(bar_offsets + len(reducer_labels) / 2, N_FEATURES_OPTIONS)
plt.ylabel('Digit classification accuracy')
plt.ylim((0, 1))
plt.legend(loc='upper left')
plt.show()
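After fitting, the search also exposes the winning configuration directly. A short follow-on sketch (``best_params_`` and ``best_score_`` are standard ``GridSearchCV`` attributes; note that ``results_['test_mean_score']`` above is this development branch's naming, released later as ``cv_results_['mean_test_score']``):

    # The chosen reducer appears as an estimator object (or None) in best_params_:
    print(grid.best_params_)
    print(grid.best_score_)   # mean cross-validated accuracy of the best setting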
