Skip to content

Commit 6f74a7c

Browse files
glemaitre authored and jmschrei committed
[MRG+1] ENH add memory to make_pipeline (scikit-learn#8831)
[MRG+2] ENH add memory to make_pipeline
1 parent 0bee058 commit 6f74a7c

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

sklearn/pipeline.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111

1212
from collections import defaultdict
1313

14-
from abc import ABCMeta, abstractmethod
15-
1614
import numpy as np
1715
from scipy import sparse
1816

19-
from .base import clone, BaseEstimator, TransformerMixin
17+
from .base import clone, TransformerMixin
2018
from .externals.joblib import Parallel, delayed, Memory
2119
from .externals import six
2220
from .utils import tosequence
@@ -35,7 +33,7 @@ class Pipeline(_BaseComposition):
3533
Intermediate steps of the pipeline must be 'transforms', that is, they
3634
must implement fit and transform methods.
3735
The final estimator only needs to implement fit.
38-
The transformers in the pipeline can be cached using ```memory`` argument.
36+
The transformers in the pipeline can be cached using ``memory`` argument.
3937
4038
The purpose of the pipeline is to assemble several steps that can be
4139
cross-validated together while setting different parameters.
@@ -527,13 +525,27 @@ def _name_estimators(estimators):
527525
return list(zip(names, estimators))
528526

529527

530-
def make_pipeline(*steps):
528+
def make_pipeline(*steps, **kwargs):
531529
"""Construct a Pipeline from the given estimators.
532530
533531
This is a shorthand for the Pipeline constructor; it does not require, and
534532
does not permit, naming the estimators. Instead, their names will be set
535533
to the lowercase of their types automatically.
536534
535+
Parameters
536+
----------
537+
*steps : list of estimators,
538+
539+
memory : Instance of joblib.Memory or string, optional (default=None)
540+
Used to cache the fitted transformers of the pipeline. By default,
541+
no caching is performed. If a string is given, it is the path to
542+
the caching directory. Enabling caching triggers a clone of
543+
the transformers before fitting. Therefore, the transformer
544+
instance given to the pipeline cannot be inspected
545+
directly. Use the attribute ``named_steps`` or ``steps`` to
546+
inspect estimators within the pipeline. Caching the
547+
transformers is advantageous when fitting is time consuming.
548+
537549
Examples
538550
--------
539551
>>> from sklearn.naive_bayes import GaussianNB
@@ -549,7 +561,11 @@ def make_pipeline(*steps):
549561
-------
550562
p : Pipeline
551563
"""
552-
return Pipeline(_name_estimators(steps))
564+
memory = kwargs.pop('memory', None)
565+
if kwargs:
566+
raise TypeError('Unknown keyword arguments: "{}"'
567+
.format(list(kwargs.keys())[0]))
568+
return Pipeline(_name_estimators(steps), memory=memory)
553569

554570

555571
def _fit_one_transformer(transformer, X, y):

sklearn/tests/test_pipeline.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,12 @@ def test_make_pipeline():
637637
assert_equal(pipe.steps[1][0], "transf-2")
638638
assert_equal(pipe.steps[2][0], "fitparamt")
639639

640+
assert_raise_message(
641+
TypeError,
642+
'Unknown keyword arguments: "random_parameter"',
643+
make_pipeline, t1, t2, random_parameter='rnd'
644+
)
645+
640646

641647
def test_feature_union_weights():
642648
# test feature union with transformer weights
@@ -911,3 +917,14 @@ def test_pipeline_memory():
911917
assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_)
912918
finally:
913919
shutil.rmtree(cachedir)
920+
921+
922+
def test_make_pipeline_memory():
923+
cachedir = mkdtemp()
924+
memory = Memory(cachedir=cachedir)
925+
pipeline = make_pipeline(DummyTransf(), SVC(), memory=memory)
926+
assert_true(pipeline.memory is memory)
927+
pipeline = make_pipeline(DummyTransf(), SVC())
928+
assert_true(pipeline.memory is None)
929+
930+
shutil.rmtree(cachedir)

0 commit comments

Comments
 (0)