3 | 3 | Balance model complexity and cross-validated score
4 | 4 | ==================================================
5 | 5 |
6 | | -This example balances model complexity and cross-validated score by
7 | | -finding a decent accuracy within 1 standard deviation of the best accuracy
8 | | -score while minimising the number of PCA components [1].
| 6 | +This example demonstrates how to balance model complexity and cross-validated score by
| 7 | +finding a decent accuracy within 1 standard deviation of the best accuracy score while
| 8 | +minimising the number of :class:`~sklearn.decomposition.PCA` components [1]. It uses
| 9 | +:class:`~sklearn.model_selection.GridSearchCV` with a custom refit callable to select
| 10 | +the optimal model.
9 | 11 |
10 | 12 | The figure shows the trade-off between cross-validated score and the number
11 | | -of PCA components. The balanced case is when n_components=10 and accuracy=0.88,
| 13 | +of PCA components. The balanced case is when `n_components=10` and `accuracy=0.88`,
12 | 14 | which falls into the range within 1 standard deviation of the best accuracy
13 | 15 | score.
14 | 16 |
15 | 17 | [1] Hastie, T., Tibshirani, R., & Friedman, J. (2001). Model Assessment and
16 | 18 | Selection. The Elements of Statistical Learning (pp. 219-260). New York,
17 | 19 | NY, USA: Springer New York Inc.
18 | | -
19 | 20 | """
20 | 21 |
21 | 22 | # Authors: The scikit-learn developers
22 | 23 | # SPDX-License-Identifier: BSD-3-Clause
23 | 24 |
24 | 25 | import matplotlib.pyplot as plt
25 | 26 | import numpy as np
| 27 | +import polars as pl
26 | 28 |
27 | 29 | from sklearn.datasets import load_digits
28 | 30 | from sklearn.decomposition import PCA
29 | | -from sklearn.model_selection import GridSearchCV
| 31 | +from sklearn.linear_model import LogisticRegression
| 32 | +from sklearn.model_selection import GridSearchCV, ShuffleSplit
30 | 33 | from sklearn.pipeline import Pipeline
31 | | -from sklearn.svm import LinearSVC
| 34 | +
| 35 | +# %%
| 36 | +# Introduction
| 37 | +# ------------
| 38 | +#
| 39 | +# When tuning hyperparameters, we often want to balance model complexity and
| 40 | +# performance. The "one-standard-error" rule is a common approach: select the simplest
| 41 | +# model whose performance is within one standard error of the best model's performance.
| 42 | +# This helps to avoid overfitting by preferring simpler models when their performance is
| 43 | +# statistically comparable to more complex ones.
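As a toy illustration of the rule (hypothetical numbers, independent of the code below): keep every candidate scoring at least `best mean - 1 std`, then pick the simplest one.

```python
import numpy as np

# Hypothetical candidates ordered by complexity (e.g. number of components).
complexity = np.array([2, 5, 10, 20])
mean_score = np.array([0.80, 0.86, 0.88, 0.89])
std_score = np.array([0.02, 0.02, 0.01, 0.01])

best = np.argmax(mean_score)                    # index 3, score 0.89
threshold = mean_score[best] - std_score[best]  # 0.89 - 0.01 = 0.88
candidates = np.flatnonzero(mean_score >= threshold)
simplest = candidates[np.argmin(complexity[candidates])]
print(complexity[simplest])  # 10: simplest model within one std of the best
```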
| 44 | +
| 45 | +# %%
| 46 | +# Helper functions
| 47 | +# ----------------
| 48 | +#
| 49 | +# We define two helper functions:
| 50 | +# 1. `lower_bound`: Calculates the threshold for acceptable performance
| 51 | +#    (best score - 1 std)
| 52 | +# 2. `best_low_complexity`: Selects the model with the fewest PCA components whose
| 53 | +#    score still exceeds this threshold
32 | 54 |
33 | 55 |
34 | 56 | def lower_bound(cv_results):
@@ -79,49 +101,280 @@ def best_low_complexity(cv_results):
79 | 101 |     return best_idx
80 | 102 |
81 | 103 |
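The bodies of the two helpers fall between the hunks above, so the diff does not show them. A minimal sketch of how they can be implemented, consistent with the description (not necessarily the exact code in this example):

```python
import numpy as np

def lower_bound(cv_results):
    # Best mean test score minus one standard deviation, taken at the best index.
    best_idx = np.argmax(cv_results["mean_test_score"])
    return (
        cv_results["mean_test_score"][best_idx]
        - cv_results["std_test_score"][best_idx]
    )

def best_low_complexity(cv_results):
    # Among candidates above the threshold, pick the fewest PCA components.
    threshold = lower_bound(cv_results)
    candidates = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    n_components = np.asarray(cv_results["param_reduce_dim__n_components"])
    return candidates[np.argmin(n_components[candidates])]
```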
| 104 | +# %%
| 105 | +# Set up the pipeline and parameter grid
| 106 | +# --------------------------------------
| 107 | +#
| 108 | +# We create a pipeline with two steps:
| 109 | +# 1. Dimensionality reduction using PCA
| 110 | +# 2. Classification using LogisticRegression
| 111 | +#
| 112 | +# We'll search over different numbers of PCA components to find the optimal complexity.
| 113 | +
82 | 114 | pipe = Pipeline(
83 | 115 |     [
84 | 116 |         ("reduce_dim", PCA(random_state=42)),
85 | | -        ("classify", LinearSVC(random_state=42, C=0.01)),
| 117 | +        ("classify", LogisticRegression(random_state=42, C=0.01, max_iter=1000)),
86 | 118 |     ]
87 | 119 | )
88 | 120 |
89 | | -param_grid = {"reduce_dim__n_components": [6, 8, 10, 12, 14]}
| 121 | +param_grid = {"reduce_dim__n_components": [6, 8, 10, 15, 20, 25, 35, 45, 55]}
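The `reduce_dim__n_components` key follows the pipeline's `<step>__<parameter>` naming convention; the valid keys can be listed from the pipeline itself:

```python
# Discover tunable parameter names (includes 'reduce_dim__n_components').
print(sorted(pipe.get_params().keys()))
```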
| 122 | +
| 123 | +# %%
| 124 | +# Perform the search with GridSearchCV
| 125 | +# ------------------------------------
| 126 | +#
| 127 | +# We use `GridSearchCV` with our custom `best_low_complexity` function as the refit
| 128 | +# parameter. This function will select the model with the fewest PCA components that
| 129 | +# still performs within one standard deviation of the best model.
90 | 130 |
91 | 131 | grid = GridSearchCV(
92 | 132 |     pipe,
93 | | -    cv=10,
94 | | -    n_jobs=1,
| 133 | +    # Use a non-stratified CV strategy to make sure that the inter-fold
| 134 | +    # standard deviation of the test scores is informative.
| 135 | +    cv=ShuffleSplit(n_splits=30, random_state=0),
| 136 | +    n_jobs=1,  # increase this on your machine to use more physical cores
95 | 137 |     param_grid=param_grid,
96 | 138 |     scoring="accuracy",
97 | 139 |     refit=best_low_complexity,
| 140 | +    return_train_score=True,
98 | 141 | )
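`refit` accepts any callable that takes the `cv_results_` dict and returns an integer index into the candidates; `best_low_complexity` is one such callable. For comparison, the default "pick the best mean score" behaviour could be spelled out as a refit callable like this (a sketch, not part of the example):

```python
def best_mean_score(cv_results):
    # Index of the candidate with the highest mean test score.
    return int(np.argmax(cv_results["mean_test_score"]))
```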
| 142 | +
| 143 | +# %%
| 144 | +# Load the digits dataset and fit the model
| 145 | +# -----------------------------------------
| 146 | +
99 | 147 | X, y = load_digits(return_X_y=True)
100 | 148 | grid.fit(X, y)
101 | 149 |
| 150 | +# %%
| 151 | +# Visualize the results
| 152 | +# ---------------------
| 153 | +#
| 154 | +# We'll plot the individual and mean train/test scores for different numbers of PCA
| 155 | +# components, along with horizontal lines indicating the best score and the
| 156 | +# one-standard-deviation threshold.
| 157 | +
102 | 158 | n_components = grid.cv_results_["param_reduce_dim__n_components"]
103 | 159 | test_scores = grid.cv_results_["mean_test_score"]
104 | 160 |
105 | | -plt.figure()
106 | | -plt.bar(n_components, test_scores, width=1.3, color="b")
| 161 | +# Create a polars DataFrame for better data manipulation and visualization
| 162 | +results_df = pl.DataFrame(
| 163 | +    {
| 164 | +        "n_components": n_components,
| 165 | +        "mean_test_score": test_scores,
| 166 | +        "std_test_score": grid.cv_results_["std_test_score"],
| 167 | +        "mean_train_score": grid.cv_results_["mean_train_score"],
| 168 | +        "std_train_score": grid.cv_results_["std_train_score"],
| 169 | +        "mean_fit_time": grid.cv_results_["mean_fit_time"],
| 170 | +        "rank_test_score": grid.cv_results_["rank_test_score"],
| 171 | +    }
| 172 | +)
107 | 173 |
108 | | -lower = lower_bound(grid.cv_results_)
109 | | -plt.axhline(np.max(test_scores), linestyle="--", color="y", label="Best score")
110 | | -plt.axhline(lower, linestyle="--", color=".5", label="Best score - 1 std")
| 174 | +# Sort by number of components
| 175 | +results_df = results_df.sort("n_components")
111 | 176 |
112 | | -plt.title("Balance model complexity and cross-validated score")
113 | | -plt.xlabel("Number of PCA components used")
114 | | -plt.ylabel("Digit classification accuracy")
115 | | -plt.xticks(n_components.tolist())
116 | | -plt.ylim((0, 1.0))
117 | | -plt.legend(loc="upper left")
| 177 | +# Calculate the lower bound threshold
| 178 | +lower = lower_bound(grid.cv_results_)
118 | 179 |
| 180 | +# Get the best model information
119 | 181 | best_index_ = grid.best_index_
| 182 | +best_components = n_components[best_index_]
| 183 | +best_score = grid.cv_results_["mean_test_score"][best_index_]
| 184 | +
| 185 | +# Add a column to mark the selected model
| 186 | +results_df = results_df.with_columns(
| 187 | +    pl.when(pl.col("n_components") == best_components)
| 188 | +    .then(pl.lit("Selected"))
| 189 | +    .otherwise(pl.lit("Regular"))
| 190 | +    .alias("model_type")
| 191 | +)
| 192 | +
| 193 | +# Get the number of CV splits from the results
| 194 | +n_splits = sum(
| 195 | +    1
| 196 | +    for key in grid.cv_results_.keys()
| 197 | +    if key.startswith("split") and key.endswith("test_score")
| 198 | +)
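Counting the `split*_test_score` keys works, but the fitted search also records this value directly in its documented `n_splits_` attribute, so the following is equivalent:

```python
n_splits = grid.n_splits_  # number of CV splits actually used
```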
| 199 | +
| 200 | +# Extract individual scores for each split
| 201 | +test_scores = np.array(
| 202 | +    [
| 203 | +        [grid.cv_results_[f"split{i}_test_score"][j] for i in range(n_splits)]
| 204 | +        for j in range(len(n_components))
| 205 | +    ]
| 206 | +)
| 207 | +train_scores = np.array(
| 208 | +    [
| 209 | +        [grid.cv_results_[f"split{i}_train_score"][j] for i in range(n_splits)]
| 210 | +        for j in range(len(n_components))
| 211 | +    ]
| 212 | +)
| 213 | +
| 214 | +# Calculate mean and std of test scores
| 215 | +mean_test_scores = np.mean(test_scores, axis=1)
| 216 | +std_test_scores = np.std(test_scores, axis=1)
| 217 | +
| 218 | +# Find best score and threshold
| 219 | +best_mean_score = np.max(mean_test_scores)
| 220 | +threshold = best_mean_score - std_test_scores[np.argmax(mean_test_scores)]
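This recomputes the same quantity as `lower` from `lower_bound(grid.cv_results_)` above (assuming `lower_bound` implements best mean score minus one standard deviation at the best index, as described earlier), so a quick sanity check is possible:

```python
# Sanity check: the recomputed threshold should match lower_bound's value.
assert np.isclose(threshold, lower)
```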
| 221 | +
| 222 | +# Create a single figure for visualization
| 223 | +fig, ax = plt.subplots(figsize=(12, 8))
120 | 224 |
121 | | -print("The best_index_ is %d" % best_index_)
122 | | -print("The n_components selected is %d" % n_components[best_index_])
123 | | -print(
124 | | -    "The corresponding accuracy score is %.2f"
125 | | -    % grid.cv_results_["mean_test_score"][best_index_]
| 225 | +# Plot individual points
| 226 | +for i, comp in enumerate(n_components):
| 227 | +    # Plot individual test points
| 228 | +    plt.scatter(
| 229 | +        [comp] * n_splits,
| 230 | +        test_scores[i],
| 231 | +        alpha=0.2,
| 232 | +        color="blue",
| 233 | +        s=20,
| 234 | +        label="Individual test scores" if i == 0 else "",
| 235 | +    )
| 236 | +    # Plot individual train points
| 237 | +    plt.scatter(
| 238 | +        [comp] * n_splits,
| 239 | +        train_scores[i],
| 240 | +        alpha=0.2,
| 241 | +        color="green",
| 242 | +        s=20,
| 243 | +        label="Individual train scores" if i == 0 else "",
| 244 | +    )
| 245 | +
| 246 | +# Plot mean lines with error bands
| 247 | +plt.plot(
| 248 | +    n_components,
| 249 | +    np.mean(test_scores, axis=1),
| 250 | +    "-",
| 251 | +    color="blue",
| 252 | +    linewidth=2,
| 253 | +    label="Mean test score",
| 254 | +)
| 255 | +plt.fill_between(
| 256 | +    n_components,
| 257 | +    np.mean(test_scores, axis=1) - np.std(test_scores, axis=1),
| 258 | +    np.mean(test_scores, axis=1) + np.std(test_scores, axis=1),
| 259 | +    alpha=0.15,
| 260 | +    color="blue",
| 261 | +)
| 262 | +
| 263 | +plt.plot(
| 264 | +    n_components,
| 265 | +    np.mean(train_scores, axis=1),
| 266 | +    "-",
| 267 | +    color="green",
| 268 | +    linewidth=2,
| 269 | +    label="Mean train score",
| 270 | +)
| 271 | +plt.fill_between(
| 272 | +    n_components,
| 273 | +    np.mean(train_scores, axis=1) - np.std(train_scores, axis=1),
| 274 | +    np.mean(train_scores, axis=1) + np.std(train_scores, axis=1),
| 275 | +    alpha=0.15,
| 276 | +    color="green",
126 | 277 | )
| 278 | +
| 279 | +# Add threshold lines
| 280 | +plt.axhline(
| 281 | +    best_mean_score,
| 282 | +    color="#9b59b6",  # Purple
| 283 | +    linestyle="--",
| 284 | +    label="Best score",
| 285 | +    linewidth=2,
| 286 | +)
| 287 | +plt.axhline(
| 288 | +    threshold,
| 289 | +    color="#e67e22",  # Orange
| 290 | +    linestyle="--",
| 291 | +    label="Best score - 1 std",
| 292 | +    linewidth=2,
| 293 | +)
| 294 | +
| 295 | +# Highlight selected model
| 296 | +plt.axvline(
| 297 | +    best_components,
| 298 | +    color="#9b59b6",  # Purple
| 299 | +    alpha=0.2,
| 300 | +    linewidth=8,
| 301 | +    label="Selected model",
| 302 | +)
| 303 | +
| 304 | +# Set titles and labels
| 305 | +plt.xlabel("Number of PCA components", fontsize=12)
| 306 | +plt.ylabel("Score", fontsize=12)
| 307 | +plt.title("Model Selection: Balancing Complexity and Performance", fontsize=14)
| 308 | +plt.grid(True, linestyle="--", alpha=0.7)
| 309 | +plt.legend(
| 310 | +    bbox_to_anchor=(1.02, 1),
| 311 | +    loc="upper left",
| 312 | +    borderaxespad=0,
| 313 | +)
| 314 | +
| 315 | +# Set axis properties
| 316 | +plt.xticks(n_components)
| 317 | +plt.ylim((0.85, 1.0))
| 318 | +
| 319 | +# Adjust layout
| 320 | +plt.tight_layout()
| 321 | +
| 322 | +# %%
| 323 | +# Print the results
| 324 | +# -----------------
| 325 | +#
| 326 | +# We print information about the selected model, including its complexity and
| 327 | +# performance. We also show a summary table of all models using polars.
| 328 | +
| 329 | +print("Best model selected by the one-standard-error rule:")
| 330 | +print(f"Number of PCA components: {best_components}")
| 331 | +print(f"Accuracy score: {best_score:.4f}")
| 332 | +print(f"Best possible accuracy: {best_mean_score:.4f}")
| 333 | +print(f"Accuracy threshold (best - 1 std): {lower:.4f}")
| 334 | +
| 335 | +# Create a summary table with polars
| 336 | +summary_df = results_df.select(
| 337 | +    pl.col("n_components"),
| 338 | +    pl.col("mean_test_score").round(4).alias("test_score"),
| 339 | +    pl.col("std_test_score").round(4).alias("test_std"),
| 340 | +    pl.col("mean_train_score").round(4).alias("train_score"),
| 341 | +    pl.col("std_train_score").round(4).alias("train_std"),
| 342 | +    pl.col("mean_fit_time").round(3).alias("fit_time"),
| 343 | +    pl.col("rank_test_score").alias("rank"),
| 344 | +)
| 345 | +
| 346 | +# Add a column to mark the selected model
| 347 | +summary_df = summary_df.with_columns(
| 348 | +    pl.when(pl.col("n_components") == best_components)
| 349 | +    .then(pl.lit("*"))
| 350 | +    .otherwise(pl.lit(""))
| 351 | +    .alias("selected")
| 352 | +)
| 353 | +
| 354 | +print("\nModel comparison table:")
| 355 | +print(summary_df)
| 356 | +
| 357 | +# %%
| 358 | +# Conclusion
| 359 | +# ----------
| 360 | +#
| 361 | +# The one-standard-error rule helps us select a simpler model (fewer PCA components)
| 362 | +# while maintaining performance statistically comparable to the best model.
| 363 | +# This approach can help prevent overfitting and improve model interpretability
| 364 | +# and efficiency.
| 365 | +#
| 366 | +# In this example, we've seen how to implement this rule using a custom refit
| 367 | +# callable with :class:`~sklearn.model_selection.GridSearchCV`.
| 368 | +#
| 369 | +# Key takeaways:
| 370 | +# 1. The one-standard-error rule provides a good rule of thumb to select simpler models
| 371 | +# 2. Custom refit callables in :class:`~sklearn.model_selection.GridSearchCV` allow for
| 372 | +#    flexible model selection strategies
| 373 | +# 3. Visualizing both train and test scores helps identify potential overfitting
| 374 | +#
| 375 | +# This approach can be applied to other model selection scenarios where balancing
| 376 | +# complexity and performance is important, or in cases where a use-case specific
| 377 | +# selection of the "best" model is desired.
| 378 | +
| 379 | +# Display the figure
127 | 380 | plt.show()