Load toy dataset · scikit-learn/scikit-learn@204dc95 · GitHub
[go: up one dir, main page]

Skip to content

Commit 204dc95

Browse files
committed
Load toy dataset
1 parent 9aec439 commit 204dc95

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

examples/linear_model/plot_sparse_logistic_regression_mnist.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,19 @@
2727
import numpy as np
2828

2929
from sklearn.datasets import fetch_openml
30+
from sklearn.datasets import load_digits
3031
from sklearn.linear_model import LogisticRegression
3132
from sklearn.model_selection import train_test_split
3233
from sklearn.preprocessing import StandardScaler
3334
from sklearn.utils import check_random_state
3435

3536
t0 = time.time()
3637

37-
# Load data from https://www.openml.org/d/554
38-
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
38+
# Load toy dataset
39+
X, y = load_digits(return_X_y=True, as_frame=False)
40+
41+
# Alternatively, load larger MNIST data set from OpenML, https://www.openml.org/d/554
42+
# X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
3943

4044
random_state = check_random_state(0)
4145
permutation = random_state.permutation(X.shape[0])
@@ -52,7 +56,7 @@
5256
X_train = scaler.fit_transform(X_train)
5357
X_test = scaler.transform(X_test)
5458

55-
clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1, random_state=random_state)
59+
clf = LogisticRegression(C=20.0 / train_samples, penalty="l1", solver="saga", tol=0.1, random_state=random_state)
5660
clf.fit(X_train, y_train)
5761
sparsity = np.mean(clf.coef_ == 0) * 100
5862
score = clf.score(X_test, y_test)

0 commit comments

Comments
 (0)
0