DOC improve plot_grid_search_refit_callable.py and add links (#30990) · scikit-learn/scikit-learn@18cdea7 · GitHub

Commit 18cdea7

adrinjalali, lorentzenchr, and ogrisel authored

DOC improve plot_grid_search_refit_callable.py and add links (#30990)

Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>

1 parent 9b40cbc commit 18cdea7

File tree

5 files changed, +322 -35 lines changed

doc/whats_new/v0.20.rst (1 addition, 1 deletion)

@@ -445,7 +445,7 @@ Miscellaneous
 
 - |API| Removed all mentions of ``sklearn.externals.joblib``, and deprecated
   joblib methods exposed in ``sklearn.utils``, except for
-  :func:`utils.parallel_backend` and :func:`utils.register_parallel_backend`,
+  `utils.parallel_backend` and `utils.register_parallel_backend`,
   which allow users to configure parallel computation in scikit-learn.
   Other functionalities are part of `joblib <https://joblib.readthedocs.io/>`_.
   package and should be used directly, by installing it.

doc/whats_new/v1.5.rst (1 addition, 1 deletion)

@@ -656,7 +656,7 @@ Changelog
 - |API| :func:`utils.tosequence` is deprecated and will be removed in version 1.7.
   :pr:`28763` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
-- |API| :class:`utils.parallel_backend` and :func:`utils.register_parallel_backend` are
+- |API| `utils.parallel_backend` and `utils.register_parallel_backend` are
  deprecated and will be removed in version 1.7. Use `joblib.parallel_backend` and
   `joblib.register_parallel_backend` instead.
   :pr:`28847` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
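
Editor's note: both changelog entries above steer users from the removed `sklearn.utils` aliases to joblib's own API. For context, a minimal sketch of the replacement usage the v1.5 entry recommends; the estimator and dataset are illustrative choices, not part of this commit:

import joblib
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

X, y = load_digits(return_X_y=True)

# Configure the parallel backend through joblib directly, as the
# deprecation note recommends, instead of the removed sklearn.utils
# aliases. Estimators whose n_jobs is left unset pick up the backend's
# n_jobs from this context.
with joblib.parallel_backend("loky", n_jobs=2):
    clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)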

examples/model_selection/plot_grid_search_refit_callable.py (280 additions, 27 deletions)
@@ -3,32 +3,54 @@
 Balance model complexity and cross-validated score
 ==================================================
 
-This example balances model complexity and cross-validated score by
-finding a decent accuracy within 1 standard deviation of the best accuracy
-score while minimising the number of PCA components [1].
+This example demonstrates how to balance model complexity and cross-validated score by
+finding a decent accuracy within 1 standard deviation of the best accuracy score while
+minimising the number of :class:`~sklearn.decomposition.PCA` components [1]. It uses
+:class:`~sklearn.model_selection.GridSearchCV` with a custom refit callable to select
+the optimal model.
 
 The figure shows the trade-off between cross-validated score and the number
-of PCA components. The balanced case is when n_components=10 and accuracy=0.88,
+of PCA components. The balanced case is when `n_components=10` and `accuracy=0.88`,
 which falls into the range within 1 standard deviation of the best accuracy
 score.
 
 [1] Hastie, T., Tibshirani, R.,, Friedman, J. (2001). Model Assessment and
 Selection. The Elements of Statistical Learning (pp. 219-260). New York,
 NY, USA: Springer New York Inc..
-
 """
 
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
 import matplotlib.pyplot as plt
 import numpy as np
+import polars as pl
 
 from sklearn.datasets import load_digits
 from sklearn.decomposition import PCA
-from sklearn.model_selection import GridSearchCV
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import GridSearchCV, ShuffleSplit
 from sklearn.pipeline import Pipeline
-from sklearn.svm import LinearSVC
+
+# %%
+# Introduction
+# ------------
+#
+# When tuning hyperparameters, we often want to balance model complexity and
+# performance. The "one-standard-error" rule is a common approach: select the simplest
+# model whose performance is within one standard error of the best model's performance.
+# This helps to avoid overfitting by preferring simpler models when their performance is
+# statistically comparable to more complex ones.
+
+# %%
+# Helper functions
+# ----------------
+#
+# We define two helper functions:
+# 1. `lower_bound`: Calculates the threshold for acceptable performance
+#    (best score - 1 std)
+# 2. `best_low_complexity`: Selects the model with the fewest PCA components that
+#    exceeds this threshold
 
 
 def lower_bound(cv_results):
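
Editor's note: both hunks leave the helper bodies unchanged, so the diff elides them (old-side lines 35-78). For readers without the file at hand, a sketch consistent with the helper descriptions added above — not necessarily the exact upstream code — looks like:

import numpy as np  # as imported at the top of the example


def lower_bound(cv_results):
    # Acceptability threshold: best mean test score minus one standard
    # deviation of that candidate's test scores.
    best_score_idx = np.argmax(cv_results["mean_test_score"])
    return (
        cv_results["mean_test_score"][best_score_idx]
        - cv_results["std_test_score"][best_score_idx]
    )


def best_low_complexity(cv_results):
    # Among candidates scoring at or above the threshold, return the
    # index of the one using the fewest PCA components.
    threshold = lower_bound(cv_results)
    candidate_idx = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    return candidate_idx[
        cv_results["param_reduce_dim__n_components"][candidate_idx].argmin()
    ]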
@@ -79,49 +101,280 @@ def best_low_complexity(cv_results):
     return best_idx
 
 
+# %%
+# Set up the pipeline and parameter grid
+# --------------------------------------
+#
+# We create a pipeline with two steps:
+# 1. Dimensionality reduction using PCA
+# 2. Classification using LogisticRegression
+#
+# We'll search over different numbers of PCA components to find the optimal complexity.
+
 pipe = Pipeline(
     [
         ("reduce_dim", PCA(random_state=42)),
-        ("classify", LinearSVC(random_state=42, C=0.01)),
+        ("classify", LogisticRegression(random_state=42, C=0.01, max_iter=1000)),
     ]
 )
 
-param_grid = {"reduce_dim__n_components": [6, 8, 10, 12, 14]}
+param_grid = {"reduce_dim__n_components": [6, 8, 10, 15, 20, 25, 35, 45, 55]}
+
+# %%
+# Perform the search with GridSearchCV
+# ------------------------------------
+#
+# We use `GridSearchCV` with our custom `best_low_complexity` function as the refit
+# parameter. This function will select the model with the fewest PCA components that
+# still performs within one standard deviation of the best model.
 
 grid = GridSearchCV(
     pipe,
-    cv=10,
-    n_jobs=1,
+    # Use a non-stratified CV strategy to make sure that the inter-fold
+    # standard deviation of the test scores is informative.
+    cv=ShuffleSplit(n_splits=30, random_state=0),
+    n_jobs=1,  # increase this on your machine to use more physical cores
     param_grid=param_grid,
     scoring="accuracy",
     refit=best_low_complexity,
+    return_train_score=True,
 )
+
+# %%
+# Load the digits dataset and fit the model
+# -----------------------------------------
+
 X, y = load_digits(return_X_y=True)
 grid.fit(X, y)
 
+# %%
+# Visualize the results
+# ---------------------
+#
+# We'll create a bar chart showing the test scores for different numbers of PCA
+# components, along with horizontal lines indicating the best score and the
+# one-standard-deviation threshold.
+
 n_components = grid.cv_results_["param_reduce_dim__n_components"]
 test_scores = grid.cv_results_["mean_test_score"]
 
-plt.figure()
-plt.bar(n_components, test_scores, width=1.3, color="b")
+# Create a polars DataFrame for better data manipulation and visualization
+results_df = pl.DataFrame(
+    {
+        "n_components": n_components,
+        "mean_test_score": test_scores,
+        "std_test_score": grid.cv_results_["std_test_score"],
+        "mean_train_score": grid.cv_results_["mean_train_score"],
+        "std_train_score": grid.cv_results_["std_train_score"],
+        "mean_fit_time": grid.cv_results_["mean_fit_time"],
+        "rank_test_score": grid.cv_results_["rank_test_score"],
+    }
+)
 
-lower = lower_bound(grid.cv_results_)
-plt.axhline(np.max(test_scores), linestyle="--", color="y", label="Best score")
-plt.axhline(lower, linestyle="--", color=".5", label="Best score - 1 std")
+# Sort by number of components
+results_df = results_df.sort("n_components")
 
-plt.title("Balance model complexity and cross-validated score")
-plt.xlabel("Number of PCA components used")
-plt.ylabel("Digit classification accuracy")
-plt.xticks(n_components.tolist())
-plt.ylim((0, 1.0))
-plt.legend(loc="upper left")
+# Calculate the lower bound threshold
+lower = lower_bound(grid.cv_results_)
 
+# Get the best model information
 best_index_ = grid.best_index_
+best_components = n_components[best_index_]
+best_score = grid.cv_results_["mean_test_score"][best_index_]
+
+# Add a column to mark the selected model
+results_df = results_df.with_columns(
+    pl.when(pl.col("n_components") == best_components)
+    .then(pl.lit("Selected"))
+    .otherwise(pl.lit("Regular"))
+    .alias("model_type")
+)
+
+# Get the number of CV splits from the results
+n_splits = sum(
+    1
+    for key in grid.cv_results_.keys()
+    if key.startswith("split") and key.endswith("test_score")
+)
+
+# Extract individual scores for each split
+test_scores = np.array(
+    [
+        [grid.cv_results_[f"split{i}_test_score"][j] for i in range(n_splits)]
+        for j in range(len(n_components))
+    ]
+)
+train_scores = np.array(
+    [
+        [grid.cv_results_[f"split{i}_train_score"][j] for i in range(n_splits)]
+        for j in range(len(n_components))
+    ]
+)
+
+# Calculate mean and std of test scores
+mean_test_scores = np.mean(test_scores, axis=1)
+std_test_scores = np.std(test_scores, axis=1)
+
+# Find best score and threshold
+best_mean_score = np.max(mean_test_scores)
+threshold = best_mean_score - std_test_scores[np.argmax(mean_test_scores)]
+
+# Create a single figure for visualization
+fig, ax = plt.subplots(figsize=(12, 8))
 
-print("The best_index_ is %d" % best_index_)
-print("The n_components selected is %d" % n_components[best_index_])
-print(
-    "The corresponding accuracy score is %.2f"
-    % grid.cv_results_["mean_test_score"][best_index_]
+# Plot individual points
+for i, comp in enumerate(n_components):
+    # Plot individual test points
+    plt.scatter(
+        [comp] * n_splits,
+        test_scores[i],
+        alpha=0.2,
+        color="blue",
+        s=20,
+        label="Individual test scores" if i == 0 else "",
+    )
+    # Plot individual train points
+    plt.scatter(
+        [comp] * n_splits,
+        train_scores[i],
+        alpha=0.2,
+        color="green",
+        s=20,
+        label="Individual train scores" if i == 0 else "",
+    )
+
+# Plot mean lines with error bands
+plt.plot(
+    n_components,
+    np.mean(test_scores, axis=1),
+    "-",
+    color="blue",
+    linewidth=2,
+    label="Mean test score",
+)
+plt.fill_between(
+    n_components,
+    np.mean(test_scores, axis=1) - np.std(test_scores, axis=1),
+    np.mean(test_scores, axis=1) + np.std(test_scores, axis=1),
+    alpha=0.15,
+    color="blue",
+)
+
+plt.plot(
+    n_components,
+    np.mean(train_scores, axis=1),
+    "-",
+    color="green",
+    linewidth=2,
+    label="Mean train score",
+)
+plt.fill_between(
+    n_components,
+    np.mean(train_scores, axis=1) - np.std(train_scores, axis=1),
+    np.mean(train_scores, axis=1) + np.std(train_scores, axis=1),
+    alpha=0.15,
+    color="green",
 )
+
+# Add threshold lines
+plt.axhline(
+    best_mean_score,
+    color="#9b59b6",  # Purple
+    linestyle="--",
+    label="Best score",
+    linewidth=2,
+)
+plt.axhline(
+    threshold,
+    color="#e67e22",  # Orange
+    linestyle="--",
+    label="Best score - 1 std",
+    linewidth=2,
+)
+
+# Highlight selected model
+plt.axvline(
+    best_components,
+    color="#9b59b6",  # Purple
+    alpha=0.2,
+    linewidth=8,
+    label="Selected model",
+)
+
+# Set titles and labels
+plt.xlabel("Number of PCA components", fontsize=12)
+plt.ylabel("Score", fontsize=12)
+plt.title("Model Selection: Balancing Complexity and Performance", fontsize=14)
+plt.grid(True, linestyle="--", alpha=0.7)
+plt.legend(
+    bbox_to_anchor=(1.02, 1),
+    loc="upper left",
+    borderaxespad=0,
+)
+
+# Set axis properties
+plt.xticks(n_components)
+plt.ylim((0.85, 1.0))
+
+# Adjust layout
+plt.tight_layout()
+
+# %%
+# Print the results
+# -----------------
+#
+# We print information about the selected model, including its complexity and
+# performance. We also show a summary table of all models using polars.
+
+print("Best model selected by the one-standard-error rule:")
+print(f"Number of PCA components: {best_components}")
+print(f"Accuracy score: {best_score:.4f}")
+print(f"Best possible accuracy: {np.max(test_scores):.4f}")
+print(f"Accuracy threshold (best - 1 std): {lower:.4f}")
+
+# Create a summary table with polars
+summary_df = results_df.select(
+    pl.col("n_components"),
+    pl.col("mean_test_score").round(4).alias("test_score"),
+    pl.col("std_test_score").round(4).alias("test_std"),
+    pl.col("mean_train_score").round(4).alias("train_score"),
+    pl.col("std_train_score").round(4).alias("train_std"),
+    pl.col("mean_fit_time").round(3).alias("fit_time"),
+    pl.col("rank_test_score").alias("rank"),
+)
+
+# Add a column to mark the selected model
+summary_df = summary_df.with_columns(
+    pl.when(pl.col("n_components") == best_components)
+    .then(pl.lit("*"))
+    .otherwise(pl.lit(""))
+    .alias("selected")
+)
+
+print("\nModel comparison table:")
+print(summary_df)
+
+# %%
+# Conclusion
+# ----------
+#
+# The one-standard-error rule helps us select a simpler model (fewer PCA components)
+# while maintaining performance statistically comparable to the best model.
+# This approach can help prevent overfitting and improve model interpretability
+# and efficiency.
+#
+# In this example, we've seen how to implement this rule using a custom refit
+# callable with :class:`~sklearn.model_selection.GridSearchCV`.
+#
+# Key takeaways:
+# 1. The one-standard-error rule provides a good rule of thumb to select simpler models
+# 2. Custom refit callables in :class:`~sklearn.model_selection.GridSearchCV` allow for
+#    flexible model selection strategies
+# 3. Visualizing both train and test scores helps identify potential overfitting
+#
+# This approach can be applied to other model selection scenarios where balancing
+# complexity and performance is important, or in cases where a use-case specific
+# selection of the "best" model is desired.
+
+# Display the figure
 plt.show()
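
Editor's note on the mechanism this example relies on: when `refit` is a callable, :class:`~sklearn.model_selection.GridSearchCV` passes it the `cv_results_` dict and uses the returned integer as `best_index_`, refitting that candidate on the full dataset (`best_score_` is not set in that case). A minimal self-contained sketch of the contract, with a toy estimator and grid chosen only for brevity:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC


def simplest_within_one_std(cv_results):
    # The callable receives cv_results_ and must return an int index;
    # GridSearchCV stores it as best_index_ and refits that candidate.
    means = cv_results["mean_test_score"]
    best = np.argmax(means)
    threshold = means[best] - cv_results["std_test_score"][best]
    candidates = np.flatnonzero(means >= threshold)
    # Prefer the smallest C (the most regularized, i.e. simplest, model
    # here) among the acceptable candidates.
    return candidates[cv_results["param_C"][candidates].argmin()]


search = GridSearchCV(
    SVC(kernel="linear"),
    param_grid={"C": [0.01, 0.1, 1, 10]},
    scoring="accuracy",
    refit=simplest_within_one_std,
)
X, y = load_iris(return_X_y=True)
search.fit(X, y)
print(search.best_index_, search.best_params_)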
