[MRG+1] EXA Adding cv indices example by choldgraf · Pull Request #11475 · scikit-learn/scikit-learn

[MRG+1] EXA Adding cv indices example #11475

Merged 12 commits on Jul 31, 2018
49 changes: 49 additions & 0 deletions doc/modules/cross_validation.rst
@@ -323,6 +323,14 @@ Example of 2-fold cross-validation on a dataset with 4 samples::
[2 3] [0 1]
[0 1] [2 3]

Here is a visualization of the cross-validation behavior. Note that
:class:`KFold` is not affected by classes or groups.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_004.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%
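
One quick way to confirm this (a minimal sketch, assuming ``X`` is the
four-sample dataset from the example above; the ``y`` and ``groups`` labels
here are purely illustrative)::

    import numpy as np
    from sklearn.model_selection import KFold

    kf = KFold(n_splits=2)
    y = np.array([0, 0, 1, 1])       # illustrative class labels
    groups = np.array([0, 1, 0, 1])  # illustrative group labels
    for (tr_a, te_a), (tr_b, te_b) in zip(kf.split(X), kf.split(X, y, groups)):
        # identical splits: KFold ignores both y and groups
        assert np.array_equal(tr_a, tr_b) and np.array_equal(te_a, te_b)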

Each fold consists of two arrays: the first contains the indices of the
*training set*, and the second those of the *test set*.
Thus, one can create the training/test sets using numpy indexing::
@@ -471,6 +479,14 @@ Here is a usage example::
[2 7 5 8 0 3 4] [6 1 9]
[4 1 0 6 8 9 3] [5 2 7]

Here is a visualization of the cross-validation behavior. Note that
:class:`ShuffleSplit` is not affected by classes or groups.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_006.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%

:class:`ShuffleSplit` is thus a good alternative to :class:`KFold` cross
validation, allowing finer control over the number of iterations and the
proportion of samples on each side of the train / test split.
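
For instance, the number of iterations and the proportion held out can be
chosen independently (a minimal sketch, assuming ``X`` is the dataset from
the usage example above; the parameter values are illustrative)::

    from sklearn.model_selection import ShuffleSplit

    ss = ShuffleSplit(n_splits=5, test_size=0.25, random_state=0)
    for train_index, test_index in ss.split(X):
        # five independent splits, each holding out 25% of the samples
        print("%d train / %d test" % (len(train_index), len(test_index)))
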
@@ -506,6 +522,13 @@ two slightly unbalanced classes::
[0 1 3 4 5 8 9] [2 6 7]
[0 1 2 4 5 6 7] [3 8 9]

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_007.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%

:class:`RepeatedStratifiedKFold` can be used to repeat Stratified K-Fold n times
with different randomization in each repetition.
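
A minimal sketch of its use (assuming ``X`` and ``y`` from the stratified
example above; the parameter values are illustrative)::

    from sklearn.model_selection import RepeatedStratifiedKFold

    rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=0)
    for train_index, test_index in rskf.split(X, y):
        # 2 repeats x 2 folds = 4 stratified splits in total
        pass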

@@ -517,6 +540,13 @@ Stratified Shuffle Split
:class:`StratifiedShuffleSplit` is a variation of *ShuffleSplit* that returns
stratified splits, *i.e.*, splits that preserve the same percentage for each
target class as in the complete set.
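
A minimal sketch of its use (assuming ``X`` and ``y`` as in the stratified
K-fold example above; the parameter values are illustrative)::

    from sklearn.model_selection import StratifiedShuffleSplit

    sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
    for train_index, test_index in sss.split(X, y):
        # each split preserves the class proportions of y
        pass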

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_009.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%

.. _group_cv:

Cross-validation iterators for grouped data.
@@ -569,6 +599,12 @@ Each subject is in a different testing fold, and the same subject is never in
both testing and training. Notice that the folds do not have exactly the same
size due to the imbalance in the data.
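
This guarantee can be checked directly (a minimal sketch, assuming ``X``,
``y`` and ``groups`` from the example above, with ``groups`` converted to
an array for indexing)::

    import numpy as np
    from sklearn.model_selection import GroupKFold

    groups = np.asarray(groups)
    gkf = GroupKFold(n_splits=3)
    for train, test in gkf.split(X, y, groups=groups):
        # no subject appears in both the training and the test set
        assert set(groups[train]).isdisjoint(groups[test])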

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_005.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%

Leave One Group Out
^^^^^^^^^^^^^^^^^^^
@@ -645,6 +681,13 @@ Here is a usage example::
[2 3 4 5] [0 1 6 7]
[4 5 6 7] [0 1 2 3]

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_008.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%

This class is useful when the behavior of :class:`LeavePGroupsOut` is
desired, but the number of groups is large enough that generating all
possible partitions with :math:`P` groups withheld would be prohibitively
@@ -709,6 +752,12 @@ Example of 3-split time series cross-validation on a dataset with 6 samples::
[0 1 2 3] [4]
[0 1 2 3 4] [5]

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_010.png
:target: ../auto_examples/model_selection/plot_cv_indices.html
:align: center
:scale: 75%
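
If the training window should not grow without bound, the ``max_train_size``
parameter caps the number of training samples (a minimal sketch, assuming
``X`` is the six-sample dataset from the example above; the value is
illustrative)::

    from sklearn.model_selection import TimeSeriesSplit

    tscv = TimeSeriesSplit(n_splits=3, max_train_size=3)
    for train, test in tscv.split(X):
        # each training set keeps at most the 3 most recent samples
        pass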

A note on shuffling
===================
149 changes: 149 additions & 0 deletions examples/model_selection/plot_cv_indices.py
@@ -0,0 +1,149 @@
"""
Visualizing cross-validation behavior in scikit-learn
=====================================================

Choosing the right cross-validation object is a crucial part of fitting a
model properly. There are many ways to split data into training and test
sets in order to avoid model overfitting, to preserve class proportions in
the test sets, to keep groups of samples together, and so on.

This example visualizes the behavior of several common scikit-learn objects
for comparison.
"""

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
                                     StratifiedKFold, GroupShuffleSplit,
                                     GroupKFold, StratifiedShuffleSplit)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

np.random.seed(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4

###############################################################################
# Visualize our data
# ------------------
#
# First, we must understand the structure of our data. It has 100 randomly
# generated input datapoints, 3 classes split unevenly across datapoints,
# and 10 "groups" split evenly across datapoints.
#
# As we'll see, some cross-validation objects do specific things with
# labeled data, others behave differently with grouped data, and others
# do not use this information.
#
# To begin, we'll visualize our data.

# Generate the class/group data
n_points = 100
X = np.random.randn(n_points, 10)

percentiles_classes = [.1, .3, .6]
y = np.hstack([[ii] * int(n_points * perc)
               for ii, perc in enumerate(percentiles_classes)])

# Ten evenly sized groups, each spanning ten consecutive samples
groups = np.hstack([[ii] * 10 for ii in range(10)])


def visualize_groups(classes, groups, name):
    # Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_',
               lw=50, cmap=cmap_data)
    ax.scatter(range(len(groups)), [3.5] * len(groups), c=classes, marker='_',
               lw=50, cmap=cmap_data)
    ax.set(ylim=[-1, 5], yticks=[.5, 3.5],
           yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index")


visualize_groups(y, groups, 'no groups')

###############################################################################
# Define a function to visualize cross-validation behavior
# --------------------------------------------------------
#
# We'll define a function that lets us visualize the behavior of each
# cross-validation object. We'll perform 4 splits of the data. On each
# split, we'll visualize the indices chosen for the training set
# (in blue) and the test set (in red).


def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    # Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        # Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        # Visualize the results
        ax.scatter(range(len(indices)), [ii + .5] * len(indices),
                   c=indices, marker='_', lw=lw, cmap=cmap_cv,
                   vmin=-.2, vmax=1.2)

    # Plot the data classes and groups at the end
    ax.scatter(range(len(X)), [ii + 1.5] * len(X),
               c=y, marker='_', lw=lw, cmap=cmap_data)

    ax.scatter(range(len(X)), [ii + 2.5] * len(X),
               c=group, marker='_', lw=lw, cmap=cmap_data)

    # Formatting
    yticklabels = list(range(n_splits)) + ['class', 'group']
    ax.set(yticks=np.arange(n_splits + 2) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits + 2.2, -.2], xlim=[0, 100])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
    return ax


###############################################################################
# Let's see how it looks for the ``KFold`` cross-validation object:

fig, ax = plt.subplots()
cv = KFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)

###############################################################################
# As you can see, by default the KFold cross-validation iterator does not
# take either datapoint class or group into consideration. We can change this
# by using ``StratifiedKFold`` instead, like so.

fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)

###############################################################################
# In this case, the cross-validation retained the same ratio of classes across
# each CV split. Next we'll visualize this behavior for a number of CV
# iterators.
#
# Visualize cross-validation indices for many CV objects
# ------------------------------------------------------
#
# Let's visually compare the cross-validation behavior of many
# scikit-learn cross-validation objects. Below we loop through several
# common ones, visualizing the behavior of each.
#
# Note how some use the group/class information while others do not.

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
       GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit]


for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
              ['Testing set', 'Training set'], loc=(1.02, .8))
    # Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=.7)
plt.show()