diff --git a/examples/linear_model/plot_sgd_early_stopping.py b/examples/linear_model/plot_sgd_early_stopping.py index c4b99335dff5f..17d905a8fec78 100644 --- a/examples/linear_model/plot_sgd_early_stopping.py +++ b/examples/linear_model/plot_sgd_early_stopping.py @@ -48,7 +48,7 @@ import numpy as np import matplotlib.pyplot as plt -from sklearn import linear_model +from sklearn.linear_model import SGDClassifier from sklearn.datasets import fetch_openml from sklearn.model_selection import train_test_split from sklearn.utils._testing import ignore_warnings @@ -89,17 +89,15 @@ def fit_and_score(estimator, max_iter, X_train, X_test, y_train, y_test): # Define the estimators to compare estimator_dict = { - "No stopping criterion": linear_model.SGDClassifier(n_iter_no_change=3), - "Training loss": linear_model.SGDClassifier( - early_stopping=False, n_iter_no_change=3, tol=0.1 - ), - "Validation score": linear_model.SGDClassifier( + "No stopping criterion": SGDClassifier(n_iter_no_change=3), + "Training loss": SGDClassifier(early_stopping=False, n_iter_no_change=3, tol=0.1), + "Validation score": SGDClassifier( early_stopping=True, n_iter_no_change=3, tol=0.0001, validation_fraction=0.2 ), } # Load the dataset -X, y = load_mnist(n_samples=10000) +X, y = load_mnist(n_samples=5000) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0) results = []