Merge examples · scikit-learn/scikit-learn@8fcacb7 · GitHub

Commit 8fcacb7

Merge examples
Author: Guillaume Lemaitre
1 parent 05e1dde, commit 8fcacb7

2 files changed: +50 -87 lines changed

2 files changed

+50
-87
lines changed

examples/plot_compare_reduction.py

File mode changed: 100644 → 100755
Lines changed: 50 additions & 10 deletions
@@ -1,32 +1,46 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
-=================================================================
-Selecting dimensionality reduction with Pipeline and GridSearchCV
-=================================================================
+======================================================================
+Selecting dimensionality reduction with Pipeline, CachedPipeline, and\
+GridSearchCV
+======================================================================
 
 This example constructs a pipeline that does dimensionality
 reduction followed by prediction with a support vector
-classifier. It demonstrates the use of GridSearchCV and
-Pipeline to optimize over different classes of estimators in a
-single CV run -- unsupervised PCA and NMF dimensionality
+classifier. It demonstrates the use of ``GridSearchCV`` and
+``Pipeline`` to optimize over different classes of estimators in a
+single CV run -- unsupervised ``PCA`` and ``NMF`` dimensionality
 reductions are compared to univariate feature selection during
 the grid search.
+
+Additionally, ``Pipeline`` can be exchanged with ``CachedPipeline``
+to memoize the transformers within the pipeline, avoiding to fit
+again the same transformers over and over.
+
+Note that the use of ``CachedPipeline`` becomes interesting when the
+fitting of a transformer is costly.
 """
 # Authors: Robert McGibbon, Joel Nothman
 
+###############################################################################
+# Illustration of ``Pipeline`` and ``GridSearchCV``
+###############################################################################
+# This section illustrates the use of a ``Pipeline`` with
+# ``GridSearchCV``
+
 from __future__ import print_function, division
 
+from tempfile import mkdtemp
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.datasets import load_digits
 from sklearn.model_selection import GridSearchCV
-from sklearn.pipeline import Pipeline
+from sklearn.pipeline import Pipeline, CachedPipeline
 from sklearn.svm import LinearSVC
 from sklearn.decomposition import PCA, NMF
 from sklearn.feature_selection import SelectKBest, chi2
-
-print(__doc__)
+from sklearn.externals.joblib import Memory
 
 pipe = Pipeline([
     ('reduce_dim', PCA()),
@@ -73,3 +87,29 @@
 plt.ylim((0, 1))
 plt.legend(loc='upper left')
 plt.show()
+
+###############################################################################
+# Illustration of ``CachedPipeline`` instead of ``Pipeline``
+###############################################################################
+# It is sometimes interesting to store the state of a specific transformer
+# since it could be used again. Using a pipeline in ``GridSearchCV`` triggers
+# such situations. Therefore, we replace ``Pipeline`` with ``CachedPipeline``
+# to memoize the transfomers within the pipeline.
+#
+# .. warning::
+#     Note that this example is, however, only an illustration since for this
+#     specific case fitting PCA is not necessarily slower than loading the
+#     cache. Hence, use ``CachedPipeline`` when the fitting of a transformer
+#     is costly.
+
+# Create a temporary folder to store the transformers of the pipeline
+cachedir = mkdtemp()
+memory = Memory(cachedir=cachedir, verbose=10)
+cached_pipe = CachedPipeline([('reduce_dim', PCA()),
+                              ('classify', LinearSVC())],
+                             memory=memory)
+
+# This time, a cached pipeline will be used within the grid search
+grid = GridSearchCV(cached_pipe, cv=3, n_jobs=2, param_grid=param_grid)
+digits = load_digits()
+grid.fit(digits.data, digits.target)
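
The memoization idea behind the cached pipeline in this diff can be sketched with the standard library alone. This is only an illustration under stated assumptions, not the `CachedPipeline` or joblib `Memory` implementation: `fit_mean`, `cached_fit`, and the `stats` counter are hypothetical names, the "transformer" is a stand-in that learns a mean, and the loop over `C` imitates a grid search that varies only a downstream classifier parameter.

```python
import hashlib
import json
import shutil
from pathlib import Path
from tempfile import mkdtemp


def fit_mean(X):
    """Stand-in for a costly transformer fit (hypothetical): learn the mean of X."""
    return {"mean": sum(X) / len(X)}


def cached_fit(fit_func, X, cachedir, stats):
    """Return the fitted state for (fit_func, X), reusing a cached file if present."""
    # The cache key depends on the fit function's name and on the training data.
    key = hashlib.sha256(json.dumps([fit_func.__name__, X]).encode()).hexdigest()
    path = Path(cachedir) / (key + ".json")
    if path.exists():
        stats["hits"] += 1
        return json.loads(path.read_text())
    stats["misses"] += 1
    state = fit_func(X)
    path.write_text(json.dumps(state))
    return state


# Create a temporary folder for the cache, as the example in the diff does.
cachedir = mkdtemp()
stats = {"hits": 0, "misses": 0}
X = [1.0, 2.0, 3.0, 6.0]

results = []
for C in [1, 10, 100, 1000]:  # hypothetical downstream classifier parameter
    # The first step sees the same inputs for every C, so it is fitted only once.
    state = cached_fit(fit_mean, X, cachedir, stats)
    results.append(state["mean"])

# Remove the temporary cache folder once the search is done.
shutil.rmtree(cachedir)
print(stats)
```

As the warning in the diff notes, caching of this kind only pays off when the fit is more expensive than reading the cache from disk. Released versions of scikit-learn expose comparable caching through the `memory` argument of `Pipeline`.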

examples/plot_compare_reduction_cached.py

Lines changed: 0 additions & 77 deletions
This file was deleted.
