10000 DOC Increase execution speed of plot_sgd_comparison (#21610) · scikit-learn/scikit-learn@2017d99 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2017d99

Browse files
Sven Eschlbeckthomasjpfan
authored andcommitted
DOC Increase execution speed of plot_sgd_comparison (#21610)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent e067a4e commit 2017d99

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

examples/linear_model/plot_sgd_comparison.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
==================================
33
Comparing various online solvers
44
==================================
5-
65
An example showing how different online solvers perform
76
on the hand-written digits dataset.
8-
97
"""
108

119
# Author: Rob Zinkov <rob at zinkov dot com>
@@ -21,22 +19,28 @@
2119
from sklearn.linear_model import LogisticRegression
2220

2321
heldout = [0.95, 0.90, 0.75, 0.50, 0.01]
24-
rounds = 20
22+
# Number of rounds to fit and evaluate an estimator.
23+
rounds = 10
2524
X, y = datasets.load_digits(return_X_y=True)
2625

2726
classifiers = [
28-
("SGD", SGDClassifier(max_iter=100)),
29-
("ASGD", SGDClassifier(average=True)),
30-
("Perceptron", Perceptron()),
27+
("SGD", SGDClassifier(max_iter=110)),
28+
("ASGD", SGDClassifier(max_iter=110, average=True)),
29+
("Perceptron", Perceptron(max_iter=110)),
3130
(
3231
"Passive-Aggressive I",
33-
PassiveAggressiveClassifier(loss="hinge", C=1.0, tol=1e-4),
32+
PassiveAggressiveClassifier(max_iter=110, loss="hinge", C=1.0, tol=1e-4),
3433
),
3534
(
3635
"Passive-Aggressive II",
37-
PassiveAggressiveClassifier(loss="squared_hinge", C=1.0, tol=1e-4),
36+
PassiveAggressiveClassifier(
37+
max_iter=110, loss="squared_hinge", C=1.0, tol=1e-4
38+
),
39+
),
40+
(
41+
"SAG",
42+
LogisticRegression(max_iter=110, solver="sag", tol=1e-1, C=1.0e4 / X.shape[0]),
3843
),
39-
("SAG", LogisticRegression(solver="sag", tol=1e-1, C=1.0e4 / X.shape[0])),
4044
]
4145

4246
xx = 1.0 - np.array(heldout)

0 commit comments

Comments
 (0)
0