From 9b3d6603f57fac3aa06e157e039e8060d728feab Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Tue, 10 Jul 2018 18:36:35 -0500 Subject: [PATCH 01/10] adding cv indices example --- examples/model_selection/plot_cv_indices.py | 120 ++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 examples/model_selection/plot_cv_indices.py diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py new file mode 100644 index 0000000000000..7531433616d5f --- /dev/null +++ b/examples/model_selection/plot_cv_indices.py @@ -0,0 +1,120 @@ +""" +Visualizing cross-validation behavior in scikit-learn +===================================================== + +Choosing the right cross-validation object is a crucial part of fitting a +model properly. There are many ways to split data into training and test +sets in order to avoid model overfitting, to standardize the number of +labels in test sets, etc. + +This example visualizes the behavior of several common scikit-learn objects +for comparison. +""" + +from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit, + GroupShuffleSplit, GroupKFold) +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Patch + +######################## +# Visualize our data +# ------------------ +# +# First, let's visualize the structure of our raw data. We have 100 data +# points in total, each one with a label attached to it, encoded by +# color. There are 10 labels in total. +# +# As we'll see, some cross-validation objects do specific things with +# labeled data, while others do not. + +# Create our dataset +labels = np.hstack([[ii] * 10 for ii in range(5)]) +labels = np.hstack([labels, labels]) + +# Visualize dataset labels +fig, ax = plt.subplots() +ax.scatter(range(len(labels)), [.5] * len(labels), c=labels, marker='_', + lw=50, cmap='rainbow') +ax.set(ylim=[-3, 4], title="Data labels (color = label)", + xlabel="Sample index") +plt.setp(ax.get_yticklabels() + ax.yaxis.majorTicks, visible=False) + + +############################################################################### +# Define a function to visualize cross-validation behavior +# -------------------------------------------------------- +# +# Next we'll define a function that lets us visualize the behavior of each +# cross-validation object. We'll perform 5 splits of the data. On each +# split, we'll visualize the indices chosen for the training set +# (in blue) and the test set (in red). + +def plot_cv_indices(cv, data, ax, n_splits, lw=10): + """Create a sample plot for indices of a cross-validation object.""" + + # Generate the training/testing visualizations for each CV split + for ii, (tr, tt) in enumerate(cv.split(data, groups=data)): + # Fill in indices with the training/test labels + indices = np.array([np.nan] * len(data)) + indices[tt] = 1 + indices[tr] = 0 + + # Visualize the results + ax.scatter(range(len(indices)), [ii + .5] * len(indices), + c=indices, marker='_', lw=lw, cmap=plt.cm.coolwarm, + vmin=-.2, vmax=1.2) + + # Add white bars for group splits + ixs_splits = np.arange(0, len(data), 10) - .5 + ax.scatter(ixs_splits, [ii + .5] * len(ixs_splits), marker='_', + lw=lw, c='w') + + + # Plot the data at the end + ax.scatter(range(len(data)), [ii + 1.5] * len(data), + c=data, marker='_', lw=lw, cmap=plt.cm.rainbow) + + # Formatting + yticklabels = list(range(n_splits)) + ['labels'] + ax.set(yticks=np.arange(n_splits+1) + .5, yticklabels=yticklabels, + xlabel='Sample index', ylabel="CV iteration", + ylim=[n_splits+1.2, -.2], xlim=[0, 100], + title='{}'.format(type(cv).__name__)) + return ax + + +############################################################################### +# Let's see how it looks for the `KFold` cross-validation object: + +fig, ax = plt.subplots() +n_splits = 5 +cv = KFold(n_splits) +plot_cv_indices(cv, labels, ax, n_splits) + +############################################################################### +# Visualize cross-validation indices for many CV objects +# ------------------------------------------------------ +# +# Finally, let's visually-compare the cross validation behavior for many +# scikit-learn cross-validation objects. Below we will loop through several +# common cross-validation objects, visualizing the behavior of each. +# +# Note that some keep labels together, while others ignore +# label identity completely. Some have overlapping test sets between CV +# splits, while others do not. + +cvs = [ShuffleSplit(n_splits=5), GroupShuffleSplit(n_splits=5), + KFold(n_splits=5), GroupKFold(n_splits=5), TimeSeriesSplit(n_splits=5)] + + +fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) +for cv, ax in zip(cvs, axs.ravel()): + plot_cv_indices(cv, labels, ax, n_splits) + +cmap = plt.cm.coolwarm +axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], + ['Testing set', 'Training set'], loc=(.7, .8)) +plt.setp([ax for ax in axs[1:-1]], xlabel='') +plt.tight_layout() +plt.show() From 577975afc6c7f87d58c5411a680c222231d3c2a6 Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Sat, 14 Jul 2018 08:26:49 -0700 Subject: [PATCH 02/10] updating cv indices example for multiple data groups --- examples/model_selection/plot_cv_indices.py | 145 +++++++++++++++----- 1 file changed, 109 insertions(+), 36 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index 7531433616d5f..4d58aadd7f92d 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -5,58 +5,67 @@ Choosing the right cross-validation object is a crucial part of fitting a model properly. There are many ways to split data into training and test sets in order to avoid model overfitting, to standardize the number of -labels in test sets, etc. +groups in test sets, etc. This example visualizes the behavior of several common scikit-learn objects for comparison. """ from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit, - GroupShuffleSplit, GroupKFold) + StratifiedKFold, GroupShuffleSplit, + GroupKFold) import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Patch -######################## +############################################################################### # Visualize our data # ------------------ # -# First, let's visualize the structure of our raw data. We have 100 data -# points in total, each one with a label attached to it, encoded by -# color. There are 10 labels in total. +# First, we must understand the structure of our data. We'll take a look at +# several datasets. They'll each have 100 randomly generated datapoints, +# but different "labels" that assign the data to groups. # # As we'll see, some cross-validation objects do specific things with # labeled data, while others do not. +# +# To begin, we'll use a dataset in which each datapoint belongs to the same +# group (or, put another way, there are no groups). -# Create our dataset -labels = np.hstack([[ii] * 10 for ii in range(5)]) -labels = np.hstack([labels, labels]) +n_points = 100 +X = np.random.randn(100, 10) +y = np.ones(100) +y[50:] = 0 -# Visualize dataset labels -fig, ax = plt.subplots() -ax.scatter(range(len(labels)), [.5] * len(labels), c=labels, marker='_', - lw=50, cmap='rainbow') -ax.set(ylim=[-3, 4], title="Data labels (color = label)", - xlabel="Sample index") -plt.setp(ax.get_yticklabels() + ax.yaxis.majorTicks, visible=False) +groups = np.ones(n_points) +def visualize_groups(groups, name): + # Visualize dataset groups + fig, ax = plt.subplots() + ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_', + lw=50, cmap='rainbow') + ax.set(ylim=[-3, 4], title="Data groups (color = label)\n{}".format(name), + xlabel="Sample index") + plt.setp(ax.get_yticklabels() + ax.yaxis.majorTicks, visible=False) + +visualize_groups(groups, 'no groups') ############################################################################### # Define a function to visualize cross-validation behavior # -------------------------------------------------------- # -# Next we'll define a function that lets us visualize the behavior of each +# We'll define a function that lets us visualize the behavior of each # cross-validation object. We'll perform 5 splits of the data. On each # split, we'll visualize the indices chosen for the training set # (in blue) and the test set (in red). -def plot_cv_indices(cv, data, ax, n_splits, lw=10): +def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): """Create a sample plot for indices of a cross-validation object.""" # Generate the training/testing visualizations for each CV split - for ii, (tr, tt) in enumerate(cv.split(data, groups=data)): - # Fill in indices with the training/test labels - indices = np.array([np.nan] * len(data)) + for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)): + # Fill in indices with the training/test groups + indices = np.array([np.nan] * len(X)) indices[tt] = 1 indices[tr] = 0 @@ -65,18 +74,13 @@ def plot_cv_indices(cv, data, ax, n_splits, lw=10): c=indices, marker='_', lw=lw, cmap=plt.cm.coolwarm, vmin=-.2, vmax=1.2) - # Add white bars for group splits - ixs_splits = np.arange(0, len(data), 10) - .5 - ax.scatter(ixs_splits, [ii + .5] * len(ixs_splits), marker='_', - lw=lw, c='w') - # Plot the data at the end - ax.scatter(range(len(data)), [ii + 1.5] * len(data), - c=data, marker='_', lw=lw, cmap=plt.cm.rainbow) + ax.scatter(range(len(X)), [ii + 1.5] * len(X), + c=group, marker='_', lw=lw, cmap=plt.cm.rainbow) # Formatting - yticklabels = list(range(n_splits)) + ['labels'] + yticklabels = list(range(n_splits)) + ['groups'] ax.set(yticks=np.arange(n_splits+1) + .5, yticklabels=yticklabels, xlabel='Sample index', ylabel="CV iteration", ylim=[n_splits+1.2, -.2], xlim=[0, 100], @@ -90,29 +94,98 @@ def plot_cv_indices(cv, data, ax, n_splits, lw=10): fig, ax = plt.subplots() n_splits = 5 cv = KFold(n_splits) -plot_cv_indices(cv, labels, ax, n_splits) +plot_cv_indices(cv, X, y, groups, ax, n_splits) ############################################################################### # Visualize cross-validation indices for many CV objects # ------------------------------------------------------ # -# Finally, let's visually-compare the cross validation behavior for many +# Let's visually compare the cross validation behavior for many # scikit-learn cross-validation objects. Below we will loop through several # common cross-validation objects, visualizing the behavior of each. # -# Note that some keep labels together, while others ignore +# In this case, there is only a single group, so grouping behavior won't +# really apply. + +cvs = [ShuffleSplit(n_splits=5), KFold(n_splits=5), + TimeSeriesSplit(n_splits=5)] + + +fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) +for ax, cv in zip(axs, cvs): + plot_cv_indices(cv, X, y, groups, ax, n_splits) + +cmap = plt.cm.coolwarm +axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], + ['Testing set', 'Training set'], loc=(.7, .8)) +plt.setp([ax for ax in axs[1:-1]], xlabel='') +plt.tight_layout() + +############################################################################### +# Using data with balanced groups +# ------------------------------- +# +# Next we'll take a look at some data that has several groups, each with the +# same number of members. Here's what the data look like. + +groups_even = np.hstack([[ii] * 10 for ii in range(5)]) +groups_even = np.hstack([groups_even, groups_even]) +y = groups_even +visualize_groups(groups_even, 'balanced groups') + +############################################################################### +# We'll visualize these groups with several cross-validation objects that +# are relevant to grouped data. +# +# Note that some keep groups together, while others ignore # label identity completely. Some have overlapping test sets between CV # splits, while others do not. + cvs = [ShuffleSplit(n_splits=5), GroupShuffleSplit(n_splits=5), - KFold(n_splits=5), GroupKFold(n_splits=5), TimeSeriesSplit(n_splits=5)] + KFold(n_splits=5), GroupKFold(n_splits=5), StratifiedKFold(n_splits=5), + TimeSeriesSplit(n_splits=5)] fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) -for cv, ax in zip(cvs, axs.ravel()): - plot_cv_indices(cv, labels, ax, n_splits) +for ax, cv in zip(axs, cvs): + plot_cv_indices(cv, X, y, groups_even, ax, n_splits) + +axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], + ['Testing set', 'Training set'], loc=(.7, .8)) +plt.setp([ax for ax in axs[1:-1]], xlabel='') +plt.tight_layout() + +############################################################################### +# Using data in with imbalanced groups +# ------------------------------------ +# +# Finally, let's see how these cross-validation objects behave with imbalanced +# groups. + +percentiles = [.05, .1, .15, .2, .5] +groups_imbalanced = np.hstack([[ii] * int(100 * perc) + for ii, perc in enumerate(percentiles)]) +y = groups_imbalanced +visualize_groups(groups_imbalanced, 'imbalanced groups') + +############################################################################### +# We'll visualize these groups with several cross-validation objects that +# are relevant to grouped data with imbalanced groups. +# +# Several scikit-learn CV objects take special consideration to maintain +# the ratio of group membership present in the data. + + +cvs = [ShuffleSplit(n_splits=5), GroupShuffleSplit(n_splits=5), + KFold(n_splits=5), GroupKFold(n_splits=5), StratifiedKFold(n_splits=5), + TimeSeriesSplit(n_splits=5)] + + +fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) +for ax, cv in zip(axs, cvs): + plot_cv_indices(cv, X, y, groups_imbalanced, ax, n_splits) -cmap = plt.cm.coolwarm axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], ['Testing set', 'Training set'], loc=(.7, .8)) plt.setp([ax for ax in axs[1:-1]], xlabel='') From f6a3b4a5717d532680365ccfe293edb78b56acb1 Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Mon, 16 Jul 2018 14:37:09 -0700 Subject: [PATCH 03/10] cv indices example down to one dataset --- examples/model_selection/plot_cv_indices.py | 164 +++++++------------- 1 file changed, 60 insertions(+), 104 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index 4d58aadd7f92d..a34c7e6e8eaf6 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -17,38 +17,49 @@ import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Patch +np.random.seed(1337) +cmap_data = plt.cm.tab10 +cmap_cv = plt.cm.coolwarm ############################################################################### # Visualize our data # ------------------ # -# First, we must understand the structure of our data. We'll take a look at -# several datasets. They'll each have 100 randomly generated datapoints, -# but different "labels" that assign the data to groups. +# First, we must understand the structure of our data. It has 100 randomly +# generated input datapoints, 5 labels split unevenly across datapoints, +# and five "groups" split unevenly across datapoints. # # As we'll see, some cross-validation objects do specific things with -# labeled data, while others do not. +# labeled data, others behave differently with grouped data, and others +# do not use this information. # -# To begin, we'll use a dataset in which each datapoint belongs to the same -# group (or, put another way, there are no groups). +# To begin, create an visualize our data +# Generate the label/group data n_points = 100 X = np.random.randn(100, 10) -y = np.ones(100) -y[50:] = 0 -groups = np.ones(n_points) +percentiles_labels = [.1, .2, .3, .4] +y = np.hstack([[ii] * int(100 * perc) + for ii, perc in enumerate(percentiles_labels)]) -def visualize_groups(groups, name): +# Evenly spaced groups repeated once +groups = np.hstack([[ii] * 10 for ii in range(5)]) +groups = np.hstack([groups, groups]) + + +def visualize_groups(labels, groups, name): # Visualize dataset groups fig, ax = plt.subplots() ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_', - lw=50, cmap='rainbow') - ax.set(ylim=[-3, 4], title="Data groups (color = label)\n{}".format(name), - xlabel="Sample index") - plt.setp(ax.get_yticklabels() + ax.yaxis.majorTicks, visible=False) + lw=50, cmap=cmap_data) + ax.scatter(range(len(groups)), [3.5] * len(groups), c=labels, marker='_', + lw=50, cmap=cmap_data) + ax.set(ylim=[-1, 5], yticks=[.5, 3.5], + yticklabels=['Data groups', 'Data label'], xlabel="Sample index") + -visualize_groups(groups, 'no groups') +visualize_groups(y, groups, 'no groups') ############################################################################### # Define a function to visualize cross-validation behavior @@ -59,6 +70,7 @@ def visualize_groups(groups, name): # split, we'll visualize the indices chosen for the training set # (in blue) and the test set (in red). + def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): """Create a sample plot for indices of a cross-validation object.""" @@ -71,19 +83,21 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): # Visualize the results ax.scatter(range(len(indices)), [ii + .5] * len(indices), - c=indices, marker='_', lw=lw, cmap=plt.cm.coolwarm, + c=indices, marker='_', lw=lw, cmap=cmap_cv, vmin=-.2, vmax=1.2) - - # Plot the data at the end + # Plot the data labels and groups at the end ax.scatter(range(len(X)), [ii + 1.5] * len(X), - c=group, marker='_', lw=lw, cmap=plt.cm.rainbow) + c=y, marker='_', lw=lw, cmap=cmap_data) + + ax.scatter(range(len(X)), [ii + 2.5] * len(X), + c=group, marker='_', lw=lw, cmap=cmap_data) # Formatting - yticklabels = list(range(n_splits)) + ['groups'] - ax.set(yticks=np.arange(n_splits+1) + .5, yticklabels=yticklabels, + yticklabels = list(range(n_splits)) + ['labels', 'groups'] + ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels, xlabel='Sample index', ylabel="CV iteration", - ylim=[n_splits+1.2, -.2], xlim=[0, 100], + ylim=[n_splits+2.2, -.2], xlim=[0, 100], title='{}'.format(type(cv).__name__)) return ax @@ -92,11 +106,24 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): # Let's see how it looks for the `KFold` cross-validation object: fig, ax = plt.subplots() -n_splits = 5 +n_splits = 4 cv = KFold(n_splits) plot_cv_indices(cv, X, y, groups, ax, n_splits) ############################################################################### +# As you can see, by default the KFold cross-validation iterator does not +# take either datapoint label or group into consideration. We can change this +# by using the ``StratifiedKFold`` like so. + +fig, ax = plt.subplots() +cv = StratifiedKFold(n_splits) +plot_cv_indices(cv, X, y, groups, ax, n_splits) + +############################################################################### +# In this case, the cross-validation retained the same ratio of labels across +# each CV split. Next we'll visualize this behavior for a number of CV +# iterators. +# # Visualize cross-validation indices for many CV objects # ------------------------------------------------------ # @@ -104,90 +131,19 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): # scikit-learn cross-validation objects. Below we will loop through several # common cross-validation objects, visualizing the behavior of each. # -# In this case, there is only a single group, so grouping behavior won't -# really apply. +# Note how some use the group/label information while others do not. -cvs = [ShuffleSplit(n_splits=5), KFold(n_splits=5), - TimeSeriesSplit(n_splits=5)] +cvs = [KFold(n_splits=n_splits), GroupKFold(n_splits=n_splits), + ShuffleSplit(n_splits=n_splits), StratifiedKFold(n_splits=n_splits), + GroupShuffleSplit(n_splits=n_splits), + TimeSeriesSplit(n_splits=n_splits)] -fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) -for ax, cv in zip(axs, cvs): +for cv in cvs: + fig, ax = plt.subplots(figsize=(6, 3)) plot_cv_indices(cv, X, y, groups, ax, n_splits) -cmap = plt.cm.coolwarm -axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], - ['Testing set', 'Training set'], loc=(.7, .8)) -plt.setp([ax for ax in axs[1:-1]], xlabel='') -plt.tight_layout() - -############################################################################### -# Using data with balanced groups -# ------------------------------- -# -# Next we'll take a look at some data that has several groups, each with the -# same number of members. Here's what the data look like. - -groups_even = np.hstack([[ii] * 10 for ii in range(5)]) -groups_even = np.hstack([groups_even, groups_even]) -y = groups_even -visualize_groups(groups_even, 'balanced groups') - -############################################################################### -# We'll visualize these groups with several cross-validation objects that -# are relevant to grouped data. -# -# Note that some keep groups together, while others ignore -# label identity completely. Some have overlapping test sets between CV -# splits, while others do not. - - -cvs = [ShuffleSplit(n_splits=5), GroupShuffleSplit(n_splits=5), - KFold(n_splits=5), GroupKFold(n_splits=5), StratifiedKFold(n_splits=5), - TimeSeriesSplit(n_splits=5)] - - -fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) -for ax, cv in zip(axs, cvs): - plot_cv_indices(cv, X, y, groups_even, ax, n_splits) - -axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], - ['Testing set', 'Training set'], loc=(.7, .8)) -plt.setp([ax for ax in axs[1:-1]], xlabel='') -plt.tight_layout() - -############################################################################### -# Using data in with imbalanced groups -# ------------------------------------ -# -# Finally, let's see how these cross-validation objects behave with imbalanced -# groups. - -percentiles = [.05, .1, .15, .2, .5] -groups_imbalanced = np.hstack([[ii] * int(100 * perc) - for ii, perc in enumerate(percentiles)]) -y = groups_imbalanced -visualize_groups(groups_imbalanced, 'imbalanced groups') - -############################################################################### -# We'll visualize these groups with several cross-validation objects that -# are relevant to grouped data with imbalanced groups. -# -# Several scikit-learn CV objects take special consideration to maintain -# the ratio of group membership present in the data. - - -cvs = [ShuffleSplit(n_splits=5), GroupShuffleSplit(n_splits=5), - KFold(n_splits=5), GroupKFold(n_splits=5), StratifiedKFold(n_splits=5), - TimeSeriesSplit(n_splits=5)] - - -fig, axs = plt.subplots(len(cvs), 1, figsize=(6, 3*len(cvs)), sharex=True) -for ax, cv in zip(axs, cvs): - plot_cv_indices(cv, X, y, groups_imbalanced, ax, n_splits) - -axs[-1].legend([Patch(color=cmap(.8)), Patch(color=cmap(.2))], - ['Testing set', 'Training set'], loc=(.7, .8)) -plt.setp([ax for ax in axs[1:-1]], xlabel='') -plt.tight_layout() + ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.2))], + ['Testing set', 'Training set'], loc=(1.02, .8)) + plt.tight_layout() plt.show() From 5a1bd3bd253abff103dc9ea5054ccb07a633ec95 Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Tue, 17 Jul 2018 22:29:26 -0700 Subject: [PATCH 04/10] new colormap and fewer classes in example --- examples/model_selection/plot_cv_indices.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index a34c7e6e8eaf6..d4ca2c256b7a3 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -13,12 +13,12 @@ from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit, StratifiedKFold, GroupShuffleSplit, - GroupKFold) + GroupKFold, StratifiedShuffleSplit) import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Patch np.random.seed(1337) -cmap_data = plt.cm.tab10 +cmap_data = plt.cm.Paired cmap_cv = plt.cm.coolwarm ############################################################################### @@ -39,13 +39,12 @@ n_points = 100 X = np.random.randn(100, 10) -percentiles_labels = [.1, .2, .3, .4] +percentiles_labels = [.1, .3, .6] y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_labels)]) # Evenly spaced groups repeated once -groups = np.hstack([[ii] * 10 for ii in range(5)]) -groups = np.hstack([groups, groups]) +groups = np.hstack([[ii] * 10 for ii in range(10)]) def visualize_groups(labels, groups, name): @@ -133,15 +132,14 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): # # Note how some use the group/label information while others do not. -cvs = [KFold(n_splits=n_splits), GroupKFold(n_splits=n_splits), - ShuffleSplit(n_splits=n_splits), StratifiedKFold(n_splits=n_splits), - GroupShuffleSplit(n_splits=n_splits), - TimeSeriesSplit(n_splits=n_splits)] +cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold, + GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit] for cv in cvs: + this_cv = cv(n_splits=n_splits) fig, ax = plt.subplots(figsize=(6, 3)) - plot_cv_indices(cv, X, y, groups, ax, n_splits) + plot_cv_indices(this_cv, X, y, groups, ax, n_splits) ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.2))], ['Testing set', 'Training set'], loc=(1.02, .8)) From 81ab5c87801d48e5030bcd5bbc88d6bb8097e69d Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Thu, 19 Jul 2018 08:44:39 -0700 Subject: [PATCH 05/10] small updates --- examples/model_selection/plot_cv_indices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index d4ca2c256b7a3..8da83da305e5c 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -26,14 +26,14 @@ # ------------------ # # First, we must understand the structure of our data. It has 100 randomly -# generated input datapoints, 5 labels split unevenly across datapoints, -# and five "groups" split unevenly across datapoints. +# generated input datapoints, 3 labels split unevenly across datapoints, +# and 10 "groups" split unevenly across datapoints. # # As we'll see, some cross-validation objects do specific things with # labeled data, others behave differently with grouped data, and others # do not use this information. # -# To begin, create an visualize our data +# To begin, we'll visualize our data. # Generate the label/group data n_points = 100 From df19254812803706c59c2c564c0f875edd927ee1 Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Thu, 19 Jul 2018 10:03:25 -0700 Subject: [PATCH 06/10] small comments --- examples/model_selection/plot_cv_indices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index 8da83da305e5c..aeb9c7e33794b 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -20,6 +20,7 @@ np.random.seed(1337) cmap_data = plt.cm.Paired cmap_cv = plt.cm.coolwarm +n_splits = 4 ############################################################################### # Visualize our data @@ -27,7 +28,7 @@ # # First, we must understand the structure of our data. It has 100 randomly # generated input datapoints, 3 labels split unevenly across datapoints, -# and 10 "groups" split unevenly across datapoints. +# and 10 "groups" split evenly across datapoints. # # As we'll see, some cross-validation objects do specific things with # labeled data, others behave differently with grouped data, and others @@ -65,7 +66,7 @@ def visualize_groups(labels, groups, name): # -------------------------------------------------------- # # We'll define a function that lets us visualize the behavior of each -# cross-validation object. We'll perform 5 splits of the data. On each +# cross-validation object. We'll perform 4 splits of the data. On each # split, we'll visualize the indices chosen for the training set # (in blue) and the test set (in red). @@ -105,7 +106,6 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): # Let's see how it looks for the `KFold` cross-validation object: fig, ax = plt.subplots() -n_splits = 4 cv = KFold(n_splits) plot_cv_indices(cv, X, y, groups, ax, n_splits) From c903d9ffd046ff8cfec27705d4fc8f31e330dba1 Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Sun, 22 Jul 2018 09:52:07 -0400 Subject: [PATCH 07/10] better RNG and fixing labeling --- examples/model_selection/plot_cv_indices.py | 30 ++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index aeb9c7e33794b..bb3a5c4b8c507 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -17,7 +17,7 @@ import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Patch -np.random.seed(1337) +np.random.seed(1338) cmap_data = plt.cm.Paired cmap_cv = plt.cm.coolwarm n_splits = 4 @@ -27,7 +27,7 @@ # ------------------ # # First, we must understand the structure of our data. It has 100 randomly -# generated input datapoints, 3 labels split unevenly across datapoints, +# generated input datapoints, 3 classes split unevenly across datapoints, # and 10 "groups" split evenly across datapoints. # # As we'll see, some cross-validation objects do specific things with @@ -36,27 +36,27 @@ # # To begin, we'll visualize our data. -# Generate the label/group data +# Generate the class/group data n_points = 100 X = np.random.randn(100, 10) -percentiles_labels = [.1, .3, .6] +percentiles_classes = [.1, .3, .6] y = np.hstack([[ii] * int(100 * perc) - for ii, perc in enumerate(percentiles_labels)]) + for ii, perc in enumerate(percentiles_classes)]) # Evenly spaced groups repeated once groups = np.hstack([[ii] * 10 for ii in range(10)]) -def visualize_groups(labels, groups, name): +def visualize_groups(classes, groups, name): # Visualize dataset groups fig, ax = plt.subplots() ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_', lw=50, cmap=cmap_data) - ax.scatter(range(len(groups)), [3.5] * len(groups), c=labels, marker='_', + ax.scatter(range(len(groups)), [3.5] * len(groups), c=classes, marker='_', lw=50, cmap=cmap_data) ax.set(ylim=[-1, 5], yticks=[.5, 3.5], - yticklabels=['Data groups', 'Data label'], xlabel="Sample index") + yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index") visualize_groups(y, groups, 'no groups') @@ -86,7 +86,7 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): c=indices, marker='_', lw=lw, cmap=cmap_cv, vmin=-.2, vmax=1.2) - # Plot the data labels and groups at the end + # Plot the data classes and groups at the end ax.scatter(range(len(X)), [ii + 1.5] * len(X), c=y, marker='_', lw=lw, cmap=cmap_data) @@ -94,11 +94,11 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): c=group, marker='_', lw=lw, cmap=cmap_data) # Formatting - yticklabels = list(range(n_splits)) + ['labels', 'groups'] + yticklabels = list(range(n_splits)) + ['class', 'group'] ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels, xlabel='Sample index', ylabel="CV iteration", - ylim=[n_splits+2.2, -.2], xlim=[0, 100], - title='{}'.format(type(cv).__name__)) + ylim=[n_splits+2.2, -.2], xlim=[0, 100]) + ax.set_title('{}'.format(type(cv).__name__), fontsize=15) return ax @@ -111,7 +111,7 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): ############################################################################### # As you can see, by default the KFold cross-validation iterator does not -# take either datapoint label or group into consideration. We can change this +# take either datapoint class or group into consideration. We can change this # by using the ``StratifiedKFold`` like so. fig, ax = plt.subplots() @@ -119,7 +119,7 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): plot_cv_indices(cv, X, y, groups, ax, n_splits) ############################################################################### -# In this case, the cross-validation retained the same ratio of labels across +# In this case, the cross-validation retained the same ratio of classes across # each CV split. Next we'll visualize this behavior for a number of CV # iterators. # @@ -130,7 +130,7 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): # scikit-learn cross-validation objects. Below we will loop through several # common cross-validation objects, visualizing the behavior of each. # -# Note how some use the group/label information while others do not. +# Note how some use the group/class information while others do not. cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold, GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit] From 686f466bc189e2f12ce083a6bd40d0a3fba3467b Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Sun, 22 Jul 2018 13:44:27 -0400 Subject: [PATCH 08/10] test adding cv indices images to docs --- doc/modules/cross_validation.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index 328270b086ed3..f58e69e834765 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -323,6 +323,14 @@ Example of 2-fold cross-validation on a dataset with 4 samples:: [2 3] [0 1] [0 1] [2 3] +Visualization of cross-validation with both grouped data and +multiple classes: + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_004.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% + Each fold is constituted by two arrays: the first one is related to the *training set*, and the second one to the *test set*. Thus, one can create the training/test sets using numpy indexing:: From 793ee4d5c11cc2266fc94ab04bc1eb6d6ed3fb3a Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Mon, 30 Jul 2018 14:43:00 -0700 Subject: [PATCH 09/10] adding CV viz images to the docs --- doc/modules/cross_validation.rst | 51 ++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index f58e69e834765..2d05e4b81c69d 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -323,13 +323,13 @@ Example of 2-fold cross-validation on a dataset with 4 samples:: [2 3] [0 1] [0 1] [2 3] -Visualization of cross-validation with both grouped data and -multiple classes: +Here is a visualization of the cross-validation behavior. Note that +:class:`KFold` is not affected by classes or groups. .. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_004.png - :target: ../auto_examples/model_selection/plot_cv_indices.html - :align: center - :scale: 75% + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% Each fold is constituted by two arrays: the first one is related to the *training set*, and the second one to the *test set*. @@ -479,6 +479,14 @@ Here is a usage example:: [2 7 5 8 0 3 4] [6 1 9] [4 1 0 6 8 9 3] [5 2 7] +Here is a visualization of the cross-validation behavior. Note that +:class:`ShuffleSplit` is not affected by classes or groups. + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_006.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% + :class:`ShuffleSplit` is thus a good alternative to :class:`KFold` cross validation that allows a finer control on the number of iterations and the proportion of samples on each side of the train / test split. @@ -514,6 +522,13 @@ two slightly unbalanced classes:: [0 1 3 4 5 8 9] [2 6 7] [0 1 2 4 5 6 7] [3 8 9] +Here is a visualization of the cross-validation behavior. + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_007.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% + :class:`RepeatedStratifiedKFold` can be used to repeat Stratified K-Fold n times with different randomization in each repetition. @@ -525,6 +540,13 @@ Stratified Shuffle Split stratified splits, *i.e* which creates splits by preserving the same percentage for each target class as in the complete set. +Here is a visualization of the cross-validation behavior. + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_009.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% + .. _group_cv: Cross-validation iterators for grouped data. @@ -577,6 +599,12 @@ Each subject is in a different testing fold, and the same subject is never in both testing and training. Notice that the folds do not have exactly the same size due to the imbalance in the data. +Here is a visualization of the cross-validation behavior. + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_005.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% Leave One Group Out ^^^^^^^^^^^^^^^^^^^ @@ -653,6 +681,13 @@ Here is a usage example:: [2 3 4 5] [0 1 6 7] [4 5 6 7] [0 1 2 3] +Here is a visualization of the cross-validation behavior. + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_008.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% + This class is useful when the behavior of :class:`LeavePGroupsOut` is desired, but the number of groups is large enough that generating all possible partitions with :math:`P` groups withheld would be prohibitively @@ -717,6 +752,12 @@ Example of 3-split time series cross-validation on a dataset with 6 samples:: [0 1 2 3] [4] [0 1 2 3 4] [5] +Here is a visualization of the cross-validation behavior. + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_010.png + :target: ../auto_examples/model_selection/plot_cv_indices.html + :align: center + :scale: 75% A note on shuffling =================== From 900566b143a2272fcfb4adba17b66c3244497843 Mon Sep 17 00:00:00 2001 From: Chris Holdgraf Date: Mon, 30 Jul 2018 17:03:19 -0700 Subject: [PATCH 10/10] get the legend to fit --- examples/model_selection/plot_cv_indices.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/model_selection/plot_cv_indices.py b/examples/model_selection/plot_cv_indices.py index bb3a5c4b8c507..078c3f5e54d19 100644 --- a/examples/model_selection/plot_cv_indices.py +++ b/examples/model_selection/plot_cv_indices.py @@ -141,7 +141,9 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): fig, ax = plt.subplots(figsize=(6, 3)) plot_cv_indices(this_cv, X, y, groups, ax, n_splits) - ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.2))], + ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))], ['Testing set', 'Training set'], loc=(1.02, .8)) + # Make the legend fit plt.tight_layout() + fig.subplots_adjust(right=.7) plt.show()